naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17    ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18    TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21    arena::{Handle, HandleSet},
22    back::{
23        self, get_entry_points,
24        msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25        Baked,
26    },
27    common,
28    proc::{
29        self, concrete_int_scalars,
30        index::{self, BoundsCheck},
31        ExternalTextureNameKey, NameKey, TypeResolution,
32    },
33    valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// More general handling of pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Names of helper functions referenced by the generated MSL.
// NOTE(review): the definitions of these helpers are emitted by code outside
// this excerpt of the writer — confirm there before renaming any of them.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you will have noticed that this should be exactly the same to the compiler, and you're
/// correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
/// Name of the cooperative-matrix load helper.
/// NOTE(review): presumably the helper keyed by `WrappedFunction::CooperativeLoad`.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
/// Name of the cooperative-matrix multiply-add helper.
/// NOTE(review): presumably the helper keyed by `WrappedFunction::CooperativeMultiplyAdd`.
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
77///
78/// The `sizes` slice determines whether this function writes a
79/// scalar, vector, or matrix type:
80///
81/// - An empty slice produces a scalar type.
82/// - A one-element slice produces a vector type.
83/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
84fn put_numeric_type(
85    out: &mut impl Write,
86    scalar: crate::Scalar,
87    sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89    match (scalar, sizes) {
90        (scalar, &[]) => {
91            write!(out, "{}", scalar.to_msl_name())
92        }
93        (scalar, &[rows]) => {
94            write!(
95                out,
96                "{}::{}{}",
97                NAMESPACE,
98                scalar.to_msl_name(),
99                common::vector_size_str(rows)
100            )
101        }
102        (scalar, &[rows, columns]) => {
103            write!(
104                out,
105                "{}::{}{}x{}",
106                NAMESPACE,
107                scalar.to_msl_name(),
108                common::vector_size_str(columns),
109                common::vector_size_str(rows)
110            )
111        }
112        (_, _) => Ok(()), // not meaningful
113    }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117    use crate::ScalarKind::*;
118    match scalar.kind {
119        Sint | Uint | AbstractInt | Bool => true,
120        Float | AbstractFloat => false,
121    }
122}
123
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// Combined with the expression handle by `ClampedLod`'s `Display` impl.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
///
/// Followed by the target type name in `Reinterpreted`'s `Display` impl.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130/// Wrapper for identifier names for clamped level-of-detail values
131///
132/// Values of this type implement [`core::fmt::Display`], formatting as
133/// the name of the variable used to hold the cached clamped
134/// level-of-detail value for an `ImageLoad` expression.
135struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140    }
141}
142
143/// Wrapper for generating `struct _mslBufferSizes` member names for
144/// runtime-sized array lengths.
145///
146/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
147/// as an argument to the entry point. This argument's type in the MSL is
148/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
149/// each global variable containing a runtime-sized array.
150///
151/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
152/// runtime-sized array, then the value `ArraySize(global)` implements
153/// [`core::fmt::Display`], formatting as the name of the struct member carrying
154/// the number of elements in that runtime-sized array.
155///
156/// [`GlobalVariable`]: crate::GlobalVariable
157struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161        self.0.write_prefixed(f, "size")
162    }
163}
164
165/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
166///
167/// Implements [`core::fmt::Display`], formatting as a name derived from
168/// `target_type` and the variable name of `orig`.
169#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171    target_type: &'a str,
172    orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177        Self { target_type, orig }
178    }
179}
180
181impl Display for Reinterpreted<'_> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183        f.write_str(REINTERPRET_PREFIX)?;
184        f.write_str(self.target_type)?;
185        self.orig.write_prefixed(f, "_e")
186    }
187}
188
189pub(super) struct TypeContext<'a> {
190    pub handle: Handle<crate::Type>,
191    pub gctx: proc::GlobalCtx<'a>,
192    pub names: &'a FastHashMap<NameKey, String>,
193    pub access: crate::StorageAccess,
194    pub first_time: bool,
195}
196
197impl TypeContext<'_> {
198    fn scalar(&self) -> Option<crate::Scalar> {
199        let ty = &self.gctx.types[self.handle];
200        ty.inner.scalar()
201    }
202
203    fn vector_size(&self) -> Option<crate::VectorSize> {
204        let ty = &self.gctx.types[self.handle];
205        match ty.inner {
206            crate::TypeInner::Vector { size, .. } => Some(size),
207            _ => None,
208        }
209    }
210
211    fn unwrap_array(self) -> Self {
212        match self.gctx.types[self.handle].inner {
213            crate::TypeInner::Array { base, .. } => Self {
214                handle: base,
215                ..self
216            },
217            _ => self,
218        }
219    }
220}
221
/// Writes the MSL spelling of the type `self.handle`.
///
/// Types for which `needs_alias` is true are written by their assigned name
/// from `self.names` unless `self.first_time` is set. Some arms deliberately
/// write nothing (pointers into an address space with no MSL name).
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referenced by name unless `first_time` is set
        // (presumably while the alias definition itself is being written).
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Requires Metal-2.3
            // Written as `simdgroup_<scalar><columns>x<rows>`; the role is not
            // part of the MSL type name.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            // Written as `<space> <pointee>&`; the pointee is always written
            // by alias if it has one (`first_time: false`).
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Address spaces without an MSL name (handles) produce no output.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Written as `<space> <scalar or vector>&`.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Array lengths go at the end of the type definition,
                // so just print the element type here.
                write!(out, "{sub}")
            }
            // `needs_alias` is always true for structs, so they are only
            // reachable here when `first_time` is set.
            // NOTE(review): callers presumably never format a struct with
            // `first_time` set — confirm before relying on this arm.
            crate::TypeInner::Struct { .. } => unreachable!(),
            // Written as
            // `<ns>::<texture|depth><dim>[_ms][_array]<scalar, <ns>::access::<mode>>`.
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                let (texture_str, msaa_str, scalar, access) = match class {
                    // Multisampled textures use `read` access; others use `sample`.
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    // Depth textures always use a 32-bit float scalar.
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // Storage textures take their access mode from `self.access`
                    // and their scalar type from the texel format.
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            // Neither LOAD nor STORE: treated as an invalid module.
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are passed as a wrapper struct instead.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            // Comparison and non-comparison samplers share one MSL type.
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{}", super::ray::metal_intersector_ty())
            }
            // Binding arrays become a `constant` pointer to the argument-buffer
            // wrapper struct around the element type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
/// A global variable together with what is needed to write its MSL declaration.
pub(super) struct TypedGlobalVariable<'a> {
    /// The module that owns the global.
    pub module: &'a crate::Module,
    /// Assigned names; the global's own name is looked up here.
    pub names: &'a FastHashMap<NameKey, String>,
    /// The global being declared.
    pub handle: Handle<crate::GlobalVariable>,
    /// How the global is used; a global that is never written may be declared `const`.
    pub usage: valid::GlobalUse,
    /// When `true`, the declaration carries its address space and a trailing `&`.
    /// Workgroup globals keep their address space either way.
    pub reference: bool,
}

/// A global variable declaration split into its type text and its name.
struct TypedGlobalVariableParts {
    /// Complete type text, including any `coherent`, address space, `const` and `&`.
    ty_name: String,
    /// The global's assigned name.
    var_name: String,
}
403
404impl TypedGlobalVariable<'_> {
405    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406        let var = &self.module.global_variables[self.handle];
407        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409        let storage_access = match var.space {
410            crate::AddressSpace::Storage { access } => access,
411            _ => match self.module.types[var.ty].inner {
412                crate::TypeInner::Image {
413                    class: crate::ImageClass::Storage { access, .. },
414                    ..
415                } => access,
416                crate::TypeInner::BindingArray { base, .. } => {
417                    match self.module.types[base].inner {
418                        crate::TypeInner::Image {
419                            class: crate::ImageClass::Storage { access, .. },
420                            ..
421                        } => access,
422                        _ => crate::StorageAccess::default(),
423                    }
424                }
425                _ => crate::StorageAccess::default(),
426            },
427        };
428        let ty_name = TypeContext {
429            handle: var.ty,
430            gctx: self.module.to_ctx(),
431            names: self.names,
432            access: storage_access,
433            first_time: false,
434        };
435
436        let access = if var.space.needs_access_qualifier()
437            && !self.usage.intersects(valid::GlobalUse::WRITE)
438        {
439            "const"
440        } else {
441            ""
442        };
443        let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444            (Some(space), crate::AddressSpace::WorkGroup) => {
445                ("", space, access, if self.reference { "&" } else { "" })
446            }
447            (Some(space), _) if self.reference => {
448                let coherent = if var
449                    .memory_decorations
450                    .contains(crate::MemoryDecorations::COHERENT)
451                {
452                    "coherent "
453                } else {
454                    ""
455                };
456                (coherent, space, access, "&")
457            }
458            _ => ("", "", "", ""),
459        };
460
461        let ty = format!(
462            "{coherent}{space}{}{ty_name}{}{access}{reference}",
463            if space.is_empty() { "" } else { " " },
464            if access.is_empty() { "" } else { " " },
465        );
466
467        Ok(TypedGlobalVariableParts {
468            ty_name: ty,
469            var_name: name.clone(),
470        })
471    }
472    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473        let parts = self.to_parts()?;
474
475        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476    }
477}
478
/// Key describing a helper ("wrapped") function the writer may need to emit.
///
/// NOTE(review): `Writer::wrapped_functions` keeps these in a set, presumably so
/// each helper is generated at most once per module — confirm at the use sites.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Helper for unary operator `op` on an operand of type `ty`
    /// (optional vector size, scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for binary operator `op` on operands of the given types
    /// (optional vector size, scalar).
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for math function `fun` on an argument of type `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` (or a vector of it) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Helper for loading from an image of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Helper for sampling an image of the given class, optionally with
    /// clamp-to-edge behavior.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Helper for querying the size of an image of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Helper for a cooperative-matrix load from the address space named `space_name`.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Helper for a cooperative-matrix multiply-add.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Helper for reading a ray query's committed or candidate intersection.
    RayQueryGetIntersection {
        committed: bool,
    },
}
526
/// The MSL backend writer; generated Metal Shading Language is written to `out`.
pub struct Writer<W> {
    /// Destination for the generated MSL text.
    pub(super) out: W,
    /// Names assigned to module items (types, globals, locals, ...).
    pub(super) names: FastHashMap<NameKey, String>,
    /// NOTE(review): presumably the named expressions of the function currently
    /// being written — confirm at the use sites.
    pub(super) named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    /// NOTE(review): presumably generates the unique identifiers stored in `names`.
    pub(super) namer: proc::Namer,
    /// Helper functions already accounted for; see [`WrappedFunction`].
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// NOTE(review): presumably forwarded to `ExpressionContext::emit_int_div_checks`,
    /// which controls safety checks for integer division/modulo.
    emit_int_div_checks: bool,
    /// Test-only bookkeeping of raw pointers.
    /// NOTE(review): presumably stack addresses recorded to measure recursion depth.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only bookkeeping of raw pointers (see above field for the same caveat).
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// NOTE(review): presumably records that object (task) shader memory barriers
    /// must be emitted — confirm at the use sites.
    needs_object_memory_barriers: bool,
}
545
546impl crate::Scalar {
547    pub(super) fn to_msl_name(self) -> &'static str {
548        use crate::ScalarKind as Sk;
549        match self {
550            Self {
551                kind: Sk::Float,
552                width: 4,
553            } => "float",
554            Self {
555                kind: Sk::Float,
556                width: 2,
557            } => "half",
558            Self {
559                kind: Sk::Sint,
560                width: 2,
561            } => "short",
562            Self {
563                kind: Sk::Uint,
564                width: 2,
565            } => "ushort",
566            Self {
567                kind: Sk::Sint,
568                width: 4,
569            } => "int",
570            Self {
571                kind: Sk::Uint,
572                width: 4,
573            } => "uint",
574            Self {
575                kind: Sk::Sint,
576                width: 8,
577            } => "long",
578            Self {
579                kind: Sk::Uint,
580                width: 8,
581            } => "ulong",
582            Self {
583                kind: Sk::Bool,
584                width: _,
585            } => "bool",
586            Self {
587                kind: Sk::AbstractInt | Sk::AbstractFloat,
588                width: _,
589            } => unreachable!("Found Abstract scalar kind"),
590            _ => unreachable!("Unsupported scalar kind: {:?}", self),
591        }
592    }
593}
594
/// Separator to emit between list items: `","` when `need_separator` is set,
/// the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
602
603fn should_pack_struct_member(
604    members: &[crate::StructMember],
605    span: u32,
606    index: usize,
607    module: &crate::Module,
608) -> Option<crate::Scalar> {
609    let member = &members[index];
610
611    let ty_inner = &module.types[member.ty].inner;
612    let last_offset = member.offset + ty_inner.size(module.to_ctx());
613    let next_offset = match members.get(index + 1) {
614        Some(next) => next.offset,
615        None => span,
616    };
617    let is_tight = next_offset == last_offset;
618
619    match *ty_inner {
620        crate::TypeInner::Vector {
621            size: crate::VectorSize::Tri,
622            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
623        } if is_tight => Some(scalar),
624        _ => None,
625    }
626}
627
628fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
629    match arena[ty].inner {
630        crate::TypeInner::Struct { ref members, .. } => {
631            if let Some(member) = members.last() {
632                if let crate::TypeInner::Array {
633                    size: crate::ArraySize::Dynamic,
634                    ..
635                } = arena[member.ty].inner
636                {
637                    return true;
638                }
639            }
640            false
641        }
642        crate::TypeInner::Array {
643            size: crate::ArraySize::Dynamic,
644            ..
645        } => true,
646        _ => false,
647    }
648}
649
650impl crate::AddressSpace {
651    /// Returns true if global variables in this address space are
652    /// passed in function arguments. These arguments need to be
653    /// passed through any functions called from the entry point.
654    const fn needs_pass_through(&self) -> bool {
655        match *self {
656            Self::Uniform
657            | Self::Storage { .. }
658            | Self::Private
659            | Self::WorkGroup
660            | Self::Immediate
661            | Self::Handle
662            | Self::TaskPayload => true,
663            Self::Function => false,
664            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
665        }
666    }
667
668    /// Returns true if the address space may need a "const" qualifier.
669    const fn needs_access_qualifier(&self) -> bool {
670        match *self {
671            //Note: we are ignoring the storage access here, and instead
672            // rely on the actual use of a global by functions. This means we
673            // may end up with "const" even if the binding is read-write,
674            // and that should be OK.
675            Self::Storage { .. } => true,
676            Self::TaskPayload => true,
677            Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
678            // These should always be read-write.
679            Self::Private | Self::WorkGroup => false,
680            // These translate to `constant` address space, no need for qualifiers.
681            Self::Uniform | Self::Immediate => false,
682            // Not applicable.
683            Self::Handle | Self::Function => false,
684        }
685    }
686
687    const fn to_msl_name(self) -> Option<&'static str> {
688        match self {
689            Self::Handle => None,
690            Self::Uniform | Self::Immediate => Some("constant"),
691            Self::Storage { .. } => Some("device"),
692            // note for `RayPayload`, this probably needs to be emulated as a
693            // private variable, as metal has essentially an inout input
694            // for where it is passed.
695            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
696            Self::WorkGroup => Some("threadgroup"),
697            Self::TaskPayload => Some("object_data"),
698            Self::IncomingRayPayload => Some("ray_data"),
699        }
700    }
701}
702
703impl crate::Type {
704    // Returns `true` if we need to emit an alias for this type.
705    const fn needs_alias(&self) -> bool {
706        use crate::TypeInner as Ti;
707
708        match self.inner {
709            // value types are concise enough, we only alias them if they are named
710            Ti::Scalar(_)
711            | Ti::Vector { .. }
712            | Ti::Matrix { .. }
713            | Ti::CooperativeMatrix { .. }
714            | Ti::Atomic(_)
715            | Ti::Pointer { .. }
716            | Ti::ValuePointer { .. } => self.name.is_some(),
717            // composite types are better to be aliased, regardless of the name
718            Ti::Struct { .. } | Ti::Array { .. } => true,
719            // handle types may be different, depending on the global var access, so we always inline them
720            Ti::Image { .. }
721            | Ti::Sampler { .. }
722            | Ti::AccelerationStructure { .. }
723            | Ti::RayQuery { .. }
724            | Ti::BindingArray { .. } => false,
725        }
726    }
727}
728
/// Identifies the function whose body is being written: a regular function
/// (by handle) or an entry point (by index). `NameKeyExt` uses it to pick the
/// matching `NameKey` variant for that function's locals.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function in the module's function arena.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
734
735trait NameKeyExt {
736    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
737        match origin {
738            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
739            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
740        }
741    }
742
743    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
744    /// policy when it needs to produce a pointer-typed result for an OOB access. These
745    /// are unique per accessed type, so the second argument is a type handle. See docs
746    /// for [`crate::back::msl`].
747    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
748        match origin {
749            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
750            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
751        }
752    }
753}
754
755impl NameKeyExt for NameKey {}
756
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// The level-of-detail expression itself is used as-is.
    Direct(Handle<crate::Expression>),
    /// The clamped level of detail is read from the temporary named after this
    /// handle (see `ClampedLod`).
    /// NOTE(review): presumably the handle of the `ImageLoad` expression, per
    /// the enum docs — confirm at the construction site.
    Restricted(Handle<crate::Expression>),
}
771
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Texel coordinate expression.
    coordinate: Handle<crate::Expression>,
    /// Array layer expression, if any (presumably only for arrayed images).
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index expression, if any (presumably only for multisampled images).
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any; see [`LevelOfDetail`].
    level: Option<LevelOfDetail>,
}
787
/// Per-function state needed while writing expressions and statements.
pub(super) struct ExpressionContext<'a> {
    /// The function whose body is being written.
    pub(super) function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; expression types are resolved through it.
    pub(super) info: &'a valid::FunctionInfo,
    /// The module being translated.
    pub(super) module: &'a crate::Module,
    /// Validation info for the whole module.
    pub(super) mod_info: &'a valid::ModuleInfo,
    /// Options for the pipeline being generated.
    pub(super) pipeline_options: &'a PipelineOptions,
    /// Target MSL language version, presumably as (major, minor).
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies used by `choose_bounds_check_policy`.
    pub(super) policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    pub(super) force_loop_bounding: bool,
    /// Whether to emit safety checks for integer division/modulo.
    emit_int_div_checks: bool,
}
807
808impl<'a> ExpressionContext<'a> {
809    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
810        self.info[handle].ty.inner_with(&self.module.types)
811    }
812
813    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
814    ///
815    /// Only mipmapped images need to specify a level of detail. Since 1D
816    /// textures cannot have mipmaps, MSL requires that the level argument to
817    /// texture1d queries and accesses must be a constexpr 0. It's easiest
818    /// just to omit the level entirely for 1D textures.
819    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
820        let image_ty = self.resolve_type(image);
821        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
822            class.is_mipmapped() && dim != crate::ImageDimension::D1
823        } else {
824            false
825        }
826    }
827
828    fn choose_bounds_check_policy(
829        &self,
830        pointer: Handle<crate::Expression>,
831    ) -> index::BoundsCheckPolicy {
832        self.policies
833            .choose_policy(pointer, &self.module.types, self.info)
834    }
835
836    /// See docs for [`proc::index::access_needs_check`].
837    fn access_needs_check(
838        &self,
839        base: Handle<crate::Expression>,
840        index: index::GuardedIndex,
841    ) -> Option<index::IndexableLength> {
842        index::access_needs_check(
843            base,
844            index,
845            self.module,
846            &self.function.expressions,
847            self.info,
848        )
849    }
850
851    /// See docs for [`proc::index::bounds_check_iter`].
852    fn bounds_check_iter(
853        &self,
854        chain: Handle<crate::Expression>,
855    ) -> impl Iterator<Item = BoundsCheck> + '_ {
856        index::bounds_check_iter(chain, self.module, self.function, self.info)
857    }
858
859    /// See docs for [`proc::index::oob_local_types`].
860    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
861        index::oob_local_types(self.module, self.function, self.info, self.policies)
862    }
863
864    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
865        match self.function.expressions[expr_handle] {
866            crate::Expression::AccessIndex { base, index } => {
867                let ty = match *self.resolve_type(base) {
868                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
869                    ref ty => ty,
870                };
871                match *ty {
872                    crate::TypeInner::Struct {
873                        ref members, span, ..
874                    } => should_pack_struct_member(members, span, index as usize, self.module),
875                    _ => None,
876                }
877            }
878            _ => None,
879        }
880    }
881}
882
pub(super) struct StatementContext<'a> {
    /// Context used to write the expressions that statements refer to.
    pub(super) expression: ExpressionContext<'a>,
    /// Name of a result struct, if one applies to the statements being
    /// written. NOTE(review): presumably the entry point's output struct;
    /// confirm against callers.
    pub(super) result_struct: Option<&'a str>,
}
887
888impl<W: Write> Writer<W> {
889    /// Creates a new `Writer` instance.
890    pub fn new(out: W) -> Self {
891        Writer {
892            out,
893            names: FastHashMap::default(),
894            named_expressions: Default::default(),
895            need_bake_expressions: Default::default(),
896            namer: proc::Namer::default(),
897            wrapped_functions: FastHashSet::default(),
898            emit_int_div_checks: true,
899            #[cfg(test)]
900            put_expression_stack_pointers: Default::default(),
901            #[cfg(test)]
902            put_block_stack_pointers: Default::default(),
903            struct_member_pads: FastHashSet::default(),
904            needs_object_memory_barriers: false,
905        }
906    }
907
    /// Finishes writing and returns the output.
    ///
    /// Consumes the writer and hands back the `out` sink it was writing into.
    // See https://github.com/rust-lang/rust-clippy/issues/4979.
    pub fn finish(self) -> W {
        self.out
    }
913
914    /// Generates statements to be inserted immediately before and at the very
915    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
916    /// The 0th item of the returned tuple should be inserted immediately prior
917    /// to the loop and the 1st item should be inserted at the very start of
918    /// the loop body.
919    ///
920    /// # What is this trying to solve?
921    ///
922    /// In Metal Shading Language, an infinite loop has undefined behavior.
923    /// (This rule is inherited from C++14.) This means that, if the MSL
924    /// compiler determines that a given loop will never exit, it may assume
925    /// that it is never reached. It may thus assume that any conditions
926    /// sufficient to cause the loop to be reached must be false. Like many
927    /// optimizing compilers, MSL uses this kind of analysis to establish limits
928    /// on the range of values variables involved in those conditions might
929    /// hold.
930    ///
931    /// For example, suppose the MSL compiler sees the code:
932    ///
933    /// ```ignore
934    /// if (i >= 10) {
935    ///     while (true) { }
936    /// }
937    /// ```
938    ///
939    /// It will recognize that the `while` loop will never terminate, conclude
940    /// that it must be unreachable, and thus infer that, if this code is
941    /// reached, then `i < 10` at that point.
942    ///
943    /// Now suppose that, at some point where `i` has the same value as above,
944    /// the compiler sees the code:
945    ///
946    /// ```ignore
947    /// if (i < 10) {
948    ///     a[i] = 1;
949    /// }
950    /// ```
951    ///
952    /// Because the compiler is confident that `i < 10`, it will make the
953    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
954    ///
955    /// ```ignore
956    /// a[i] = 1;
957    /// ```
958    ///
959    /// If that `if` condition was injected by Naga to implement a bounds check,
960    /// the MSL compiler's optimizations could allow out-of-bounds array
961    /// accesses to occur.
962    ///
963    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
964    /// that a loop is infinite, so an attacker could craft a Naga module
965    /// containing an infinite loop protected by conditions that cause the Metal
966    /// compiler to remove bounds checks that Naga injected elsewhere in the
967    /// function.
968    ///
969    /// This rewrite could occur even if the conditional assignment appears
970    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
971    /// reached. This would allow the attacker to save the results of
972    /// unauthorized reads somewhere accessible before entering the infinite
973    /// loop. But even worse, the MSL compiler has been observed to simply
974    /// delete the infinite loop entirely, so that even code dominated by the
975    /// loop becomes reachable. This would make the attack even more flexible,
976    /// since shaders that would appear to never terminate would actually exit
977    /// nicely, after having stolen data from elsewhere in the GPU address
978    /// space.
979    ///
980    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
981    /// generates is infinite. One approach would be to add inline assembly to
982    /// each loop that is annotated as potentially branching out of the loop,
983    /// but which in fact generates no instructions. Unfortunately, inline
984    /// assembly is not handled correctly by some Metal device drivers.
985    ///
986    /// A previously used approach was to add the following code to the bottom
987    /// of every loop:
988    ///
989    /// ```ignore
990    /// if (volatile bool unpredictable = false; unpredictable)
991    ///     break;
992    /// ```
993    ///
994    /// Although the `if` condition will always be false in any real execution,
995    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
996    /// it must assume that the `break` might be reached, and hence that the
997    /// loop is not unbounded. This prevents the range analysis impact described
998    /// above. Unfortunately this prevented the compiler from making important,
999    /// and safe, optimizations such as loop unrolling and was observed to
1000    /// significantly hurt performance.
1001    ///
1002    /// Our current approach declares a counter before every loop and
1003    /// increments it every iteration, breaking after 2^64 iterations:
1004    ///
1005    /// ```ignore
1006    /// uint2 loop_bound = uint2(0);
1007    /// while (true) {
1008    ///   if (metal::all(loop_bound == uint2(4294967295))) { break; }
1009    ///   loop_bound += uint2(loop_bound.y == 4294967295, 1);
1010    /// }
1011    /// ```
1012    ///
1013    /// This convinces the compiler that the loop is finite and therefore may
1014    /// execute, whilst at the same time allowing optimizations such as loop
1015    /// unrolling. Furthermore the 64-bit counter is large enough it seems
1016    /// implausible that it would affect the execution of any shader.
1017    ///
1018    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
1019    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
1020    fn gen_force_bounded_loop_statements(
1021        &mut self,
1022        level: back::Level,
1023        context: &StatementContext,
1024    ) -> Option<(String, String)> {
1025        if !context.expression.force_loop_bounding {
1026            return None;
1027        }
1028
1029        let loop_bound_name = self.namer.call("loop_bound");
1030        // Count down from u32::MAX rather than up from 0 to avoid hang on
1031        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
1032        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1033        let level = level.next();
1034        let break_and_inc = format!(
1035            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1036{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1037        );
1038
1039        Some((decl, break_and_inc))
1040    }
1041
1042    fn put_call_parameters(
1043        &mut self,
1044        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1045        context: &ExpressionContext,
1046    ) -> BackendResult {
1047        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1048            writer.put_expression(expr, context, true)
1049        })
1050    }
1051
1052    fn put_call_parameters_impl<C, E>(
1053        &mut self,
1054        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1055        ctx: &C,
1056        put_expression: E,
1057    ) -> BackendResult
1058    where
1059        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1060    {
1061        write!(self.out, "(")?;
1062        for (i, handle) in parameters.enumerate() {
1063            if i != 0 {
1064                write!(self.out, ", ")?;
1065            }
1066            put_expression(self, ctx, handle)?;
1067        }
1068        write!(self.out, ")")?;
1069        Ok(())
1070    }
1071
1072    /// Writes the local variables of the given function, as well as any extra
1073    /// out-of-bounds locals that are needed.
1074    ///
1075    /// The names of the OOB locals are also added to `self.names` at the same
1076    /// time.
1077    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1078        let oob_local_types = context.oob_local_types();
1079        for &ty in oob_local_types.iter() {
1080            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1081            self.names.insert(name_key, self.namer.call("oob"));
1082        }
1083
1084        for (name_key, ty, init) in context
1085            .function
1086            .local_variables
1087            .iter()
1088            .map(|(local_handle, local)| {
1089                let name_key = NameKey::local(context.origin, local_handle);
1090                (name_key, local.ty, local.init)
1091            })
1092            .chain(oob_local_types.iter().map(|&ty| {
1093                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1094                (name_key, ty, None)
1095            }))
1096        {
1097            let ty_name = TypeContext {
1098                handle: ty,
1099                gctx: context.module.to_ctx(),
1100                names: &self.names,
1101                access: crate::StorageAccess::empty(),
1102                first_time: false,
1103            };
1104            write!(
1105                self.out,
1106                "{}{} {}",
1107                back::INDENT,
1108                ty_name,
1109                self.names[&name_key]
1110            )?;
1111            match init {
1112                Some(value) => {
1113                    write!(self.out, " = ")?;
1114                    self.put_expression(value, context, true)?;
1115                }
1116                None => {
1117                    write!(self.out, " = {{}}")?;
1118                }
1119            };
1120            writeln!(self.out, ";")?;
1121        }
1122        Ok(())
1123    }
1124
1125    fn put_level_of_detail(
1126        &mut self,
1127        level: LevelOfDetail,
1128        context: &ExpressionContext,
1129    ) -> BackendResult {
1130        match level {
1131            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1132            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1133        }
1134        Ok(())
1135    }
1136
1137    fn put_image_query(
1138        &mut self,
1139        image: Handle<crate::Expression>,
1140        query: &str,
1141        level: Option<LevelOfDetail>,
1142        context: &ExpressionContext,
1143    ) -> BackendResult {
1144        self.put_expression(image, context, false)?;
1145        write!(self.out, ".get_{query}(")?;
1146        if let Some(level) = level {
1147            self.put_level_of_detail(level, context)?;
1148        }
1149        write!(self.out, ")")?;
1150        Ok(())
1151    }
1152
1153    fn put_image_size_query(
1154        &mut self,
1155        image: Handle<crate::Expression>,
1156        level: Option<LevelOfDetail>,
1157        kind: crate::ScalarKind,
1158        context: &ExpressionContext,
1159    ) -> BackendResult {
1160        if let crate::TypeInner::Image {
1161            class: crate::ImageClass::External,
1162            ..
1163        } = *context.resolve_type(image)
1164        {
1165            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1166            self.put_expression(image, context, true)?;
1167            write!(self.out, ")")?;
1168            return Ok(());
1169        }
1170
1171        //Note: MSL only has separate width/height/depth queries,
1172        // so compose the result of them.
1173        let dim = match *context.resolve_type(image) {
1174            crate::TypeInner::Image { dim, .. } => dim,
1175            ref other => unreachable!("Unexpected type {:?}", other),
1176        };
1177        let scalar = crate::Scalar { kind, width: 4 };
1178        let coordinate_type = scalar.to_msl_name();
1179        match dim {
1180            crate::ImageDimension::D1 => {
1181                // Since 1D textures never have mipmaps, MSL requires that the
1182                // `level` argument be a constexpr 0. It's simplest for us just
1183                // to pass `None` and omit the level entirely.
1184                if kind == crate::ScalarKind::Uint {
1185                    // No need to construct a vector. No cast needed.
1186                    self.put_image_query(image, "width", None, context)?;
1187                } else {
1188                    // There's no definition for `int` in the `metal` namespace.
1189                    write!(self.out, "int(")?;
1190                    self.put_image_query(image, "width", None, context)?;
1191                    write!(self.out, ")")?;
1192                }
1193            }
1194            crate::ImageDimension::D2 => {
1195                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1196                self.put_image_query(image, "width", level, context)?;
1197                write!(self.out, ", ")?;
1198                self.put_image_query(image, "height", level, context)?;
1199                write!(self.out, ")")?;
1200            }
1201            crate::ImageDimension::D3 => {
1202                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1203                self.put_image_query(image, "width", level, context)?;
1204                write!(self.out, ", ")?;
1205                self.put_image_query(image, "height", level, context)?;
1206                write!(self.out, ", ")?;
1207                self.put_image_query(image, "depth", level, context)?;
1208                write!(self.out, ")")?;
1209            }
1210            crate::ImageDimension::Cube => {
1211                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1212                self.put_image_query(image, "width", level, context)?;
1213                write!(self.out, ")")?;
1214            }
1215        }
1216        Ok(())
1217    }
1218
1219    fn put_cast_to_uint_scalar_or_vector(
1220        &mut self,
1221        expr: Handle<crate::Expression>,
1222        context: &ExpressionContext,
1223    ) -> BackendResult {
1224        // coordinates in IR are int, but Metal expects uint
1225        match *context.resolve_type(expr) {
1226            crate::TypeInner::Scalar(_) => {
1227                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1228            }
1229            crate::TypeInner::Vector { size, .. } => {
1230                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1231            }
1232            _ => {
1233                return Err(Error::GenericValidation(
1234                    "Invalid type for image coordinate".into(),
1235                ))
1236            }
1237        };
1238
1239        write!(self.out, "(")?;
1240        self.put_expression(expr, context, true)?;
1241        write!(self.out, ")")?;
1242        Ok(())
1243    }
1244
1245    fn put_image_sample_level(
1246        &mut self,
1247        image: Handle<crate::Expression>,
1248        level: crate::SampleLevel,
1249        context: &ExpressionContext,
1250    ) -> BackendResult {
1251        let has_levels = context.image_needs_lod(image);
1252        match level {
1253            crate::SampleLevel::Auto => {}
1254            crate::SampleLevel::Zero => {
1255                //TODO: do we support Zero on `Sampled` image classes?
1256            }
1257            _ if !has_levels => {
1258                log::warn!("1D image can't be sampled with level {level:?}");
1259            }
1260            crate::SampleLevel::Exact(h) => {
1261                write!(self.out, ", {NAMESPACE}::level(")?;
1262                self.put_expression(h, context, true)?;
1263                write!(self.out, ")")?;
1264            }
1265            crate::SampleLevel::Bias(h) => {
1266                write!(self.out, ", {NAMESPACE}::bias(")?;
1267                self.put_expression(h, context, true)?;
1268                write!(self.out, ")")?;
1269            }
1270            crate::SampleLevel::Gradient { x, y } => {
1271                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1272                self.put_expression(x, context, true)?;
1273                write!(self.out, ", ")?;
1274                self.put_expression(y, context, true)?;
1275                write!(self.out, ")")?;
1276            }
1277        }
1278        Ok(())
1279    }
1280
1281    fn put_image_coordinate_limits(
1282        &mut self,
1283        image: Handle<crate::Expression>,
1284        level: Option<LevelOfDetail>,
1285        context: &ExpressionContext,
1286    ) -> BackendResult {
1287        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1288        write!(self.out, " - 1")?;
1289        Ok(())
1290    }
1291
1292    /// General function for writing restricted image indexes.
1293    ///
1294    /// This is used to produce restricted mip levels, array indices, and sample
1295    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1296    /// [`Restrict`] bounds check policy.
1297    ///
1298    /// This function writes an expression of the form:
1299    ///
1300    /// ```ignore
1301    ///
1302    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1303    ///
1304    /// ```
1305    ///
1306    /// [`ImageLoad`]: crate::Expression::ImageLoad
1307    /// [`ImageStore`]: crate::Statement::ImageStore
1308    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1309    fn put_restricted_scalar_image_index(
1310        &mut self,
1311        image: Handle<crate::Expression>,
1312        index: Handle<crate::Expression>,
1313        limit_method: &str,
1314        context: &ExpressionContext,
1315    ) -> BackendResult {
1316        write!(self.out, "{NAMESPACE}::min(uint(")?;
1317        self.put_expression(index, context, true)?;
1318        write!(self.out, "), ")?;
1319        self.put_expression(image, context, false)?;
1320        write!(self.out, ".{limit_method}() - 1)")?;
1321        Ok(())
1322    }
1323
1324    fn put_restricted_texel_address(
1325        &mut self,
1326        image: Handle<crate::Expression>,
1327        address: &TexelAddress,
1328        context: &ExpressionContext,
1329    ) -> BackendResult {
1330        // Write the coordinate.
1331        write!(self.out, "{NAMESPACE}::min(")?;
1332        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1333        write!(self.out, ", ")?;
1334        self.put_image_coordinate_limits(image, address.level, context)?;
1335        write!(self.out, ")")?;
1336
1337        // Write the array index, if present.
1338        if let Some(array_index) = address.array_index {
1339            write!(self.out, ", ")?;
1340            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1341        }
1342
1343        // Write the sample index, if present.
1344        if let Some(sample) = address.sample {
1345            write!(self.out, ", ")?;
1346            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1347        }
1348
1349        // The level of detail should be clamped and cached by
1350        // `put_cache_restricted_level`, so we don't need to clamp it here.
1351        if let Some(level) = address.level {
1352            write!(self.out, ", ")?;
1353            self.put_level_of_detail(level, context)?;
1354        }
1355
1356        Ok(())
1357    }
1358
1359    /// Write an expression that is true if the given image access is in bounds.
1360    fn put_image_access_bounds_check(
1361        &mut self,
1362        image: Handle<crate::Expression>,
1363        address: &TexelAddress,
1364        context: &ExpressionContext,
1365    ) -> BackendResult {
1366        let mut conjunction = "";
1367
1368        // First, check the level of detail. Only if that is in bounds can we
1369        // use it to find the appropriate bounds for the coordinates.
1370        let level = if let Some(level) = address.level {
1371            write!(self.out, "uint(")?;
1372            self.put_level_of_detail(level, context)?;
1373            write!(self.out, ") < ")?;
1374            self.put_expression(image, context, true)?;
1375            write!(self.out, ".get_num_mip_levels()")?;
1376            conjunction = " && ";
1377            Some(level)
1378        } else {
1379            None
1380        };
1381
1382        // Check sample index, if present.
1383        if let Some(sample) = address.sample {
1384            write!(self.out, "uint(")?;
1385            self.put_expression(sample, context, true)?;
1386            write!(self.out, ") < ")?;
1387            self.put_expression(image, context, true)?;
1388            write!(self.out, ".get_num_samples()")?;
1389            conjunction = " && ";
1390        }
1391
1392        // Check array index, if present.
1393        if let Some(array_index) = address.array_index {
1394            write!(self.out, "{conjunction}uint(")?;
1395            self.put_expression(array_index, context, true)?;
1396            write!(self.out, ") < ")?;
1397            self.put_expression(image, context, true)?;
1398            write!(self.out, ".get_array_size()")?;
1399            conjunction = " && ";
1400        }
1401
1402        // Finally, check if the coordinates are within bounds.
1403        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1404            crate::TypeInner::Vector { .. } => true,
1405            _ => false,
1406        };
1407        write!(self.out, "{conjunction}")?;
1408        if coord_is_vector {
1409            write!(self.out, "{NAMESPACE}::all(")?;
1410        }
1411        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1412        write!(self.out, " < ")?;
1413        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1414        if coord_is_vector {
1415            write!(self.out, ")")?;
1416        }
1417
1418        Ok(())
1419    }
1420
1421    fn put_image_load(
1422        &mut self,
1423        load: Handle<crate::Expression>,
1424        image: Handle<crate::Expression>,
1425        mut address: TexelAddress,
1426        context: &ExpressionContext,
1427    ) -> BackendResult {
1428        if let crate::TypeInner::Image {
1429            class: crate::ImageClass::External,
1430            ..
1431        } = *context.resolve_type(image)
1432        {
1433            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1434            self.put_expression(image, context, true)?;
1435            write!(self.out, ", ")?;
1436            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1437            write!(self.out, ")")?;
1438            return Ok(());
1439        }
1440
1441        match context.policies.image_load {
1442            proc::BoundsCheckPolicy::Restrict => {
1443                // Use the cached restricted level of detail, if any. Omit the
1444                // level altogether for 1D textures.
1445                if address.level.is_some() {
1446                    address.level = if context.image_needs_lod(image) {
1447                        Some(LevelOfDetail::Restricted(load))
1448                    } else {
1449                        None
1450                    }
1451                }
1452
1453                self.put_expression(image, context, false)?;
1454                write!(self.out, ".read(")?;
1455                self.put_restricted_texel_address(image, &address, context)?;
1456                write!(self.out, ")")?;
1457            }
1458            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1459                write!(self.out, "(")?;
1460                self.put_image_access_bounds_check(image, &address, context)?;
1461                write!(self.out, " ? ")?;
1462                self.put_unchecked_image_load(image, &address, context)?;
1463                write!(self.out, ": DefaultConstructible())")?;
1464            }
1465            proc::BoundsCheckPolicy::Unchecked => {
1466                self.put_unchecked_image_load(image, &address, context)?;
1467            }
1468        }
1469
1470        Ok(())
1471    }
1472
1473    fn put_unchecked_image_load(
1474        &mut self,
1475        image: Handle<crate::Expression>,
1476        address: &TexelAddress,
1477        context: &ExpressionContext,
1478    ) -> BackendResult {
1479        self.put_expression(image, context, false)?;
1480        write!(self.out, ".read(")?;
1481        // coordinates in IR are int, but Metal expects uint
1482        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1483        if let Some(expr) = address.array_index {
1484            write!(self.out, ", ")?;
1485            self.put_expression(expr, context, true)?;
1486        }
1487        if let Some(sample) = address.sample {
1488            write!(self.out, ", ")?;
1489            self.put_expression(sample, context, true)?;
1490        }
1491        if let Some(level) = address.level {
1492            if context.image_needs_lod(image) {
1493                write!(self.out, ", ")?;
1494                self.put_level_of_detail(level, context)?;
1495            }
1496        }
1497        write!(self.out, ")")?;
1498
1499        Ok(())
1500    }
1501
1502    fn put_image_atomic(
1503        &mut self,
1504        level: back::Level,
1505        image: Handle<crate::Expression>,
1506        address: &TexelAddress,
1507        fun: crate::AtomicFunction,
1508        value: Handle<crate::Expression>,
1509        context: &StatementContext,
1510    ) -> BackendResult {
1511        write!(self.out, "{level}")?;
1512        self.put_expression(image, &context.expression, false)?;
1513        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1514            fun.to_msl_64_bit()?
1515        } else {
1516            fun.to_msl()
1517        };
1518        write!(self.out, ".atomic_{op}(")?;
1519        // coordinates in IR are int, but Metal expects uint
1520        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1521        write!(self.out, ", ")?;
1522        self.put_expression(value, &context.expression, true)?;
1523        writeln!(self.out, ");")?;
1524
1525        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
1526        // texture writes randomly drop unless followed by a standard texture
1527        // write. Insert a dead-code write behind an unprovable condition so
1528        // the compiler emits proper memory safety barriers.
1529        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
1530        let value_ty = context.expression.resolve_type(value);
1531        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1532            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1533            (_, Some(8)) => "ulong4(0uL)",
1534            _ => "uint4(0u)",
1535        };
1536        let coord_ty = context.expression.resolve_type(address.coordinate);
1537        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1538            ""
1539        } else {
1540            ".x"
1541        };
1542        write!(self.out, "{level}if (")?;
1543        self.put_expression(address.coordinate, &context.expression, true)?;
1544        write!(self.out, "{x} == -99999) {{ ")?;
1545        self.put_expression(image, &context.expression, false)?;
1546        write!(self.out, ".write({zero_value}, ")?;
1547        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1548        if let Some(array_index) = address.array_index {
1549            write!(self.out, ", ")?;
1550            self.put_expression(array_index, &context.expression, true)?;
1551        }
1552        writeln!(self.out, "); }}")?;
1553
1554        Ok(())
1555    }
1556
1557    fn put_image_store(
1558        &mut self,
1559        level: back::Level,
1560        image: Handle<crate::Expression>,
1561        address: &TexelAddress,
1562        value: Handle<crate::Expression>,
1563        context: &StatementContext,
1564    ) -> BackendResult {
1565        write!(self.out, "{level}")?;
1566        self.put_expression(image, &context.expression, false)?;
1567        write!(self.out, ".write(")?;
1568        self.put_expression(value, &context.expression, true)?;
1569        write!(self.out, ", ")?;
1570        // coordinates in IR are int, but Metal expects uint
1571        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1572        if let Some(expr) = address.array_index {
1573            write!(self.out, ", ")?;
1574            self.put_expression(expr, &context.expression, true)?;
1575        }
1576        writeln!(self.out, ");")?;
1577
1578        Ok(())
1579    }
1580
1581    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1582    ///
1583    /// The 'maximum valid index' is simply one less than the array's length.
1584    ///
1585    /// This emits an expression of the form `a / b`, so the caller must
1586    /// parenthesize its output if it will be applying operators of higher
1587    /// precedence.
1588    ///
1589    /// `handle` must be the handle of a global variable whose final member is a
1590    /// dynamically sized array.
1591    fn put_dynamic_array_max_index(
1592        &mut self,
1593        handle: Handle<crate::GlobalVariable>,
1594        context: &ExpressionContext,
1595    ) -> BackendResult {
1596        let global = &context.module.global_variables[handle];
1597        let (offset, array_ty) = match context.module.types[global.ty].inner {
1598            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1599                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1600                None => return Err(Error::GenericValidation("Struct has no members".into())),
1601            },
1602            crate::TypeInner::Array {
1603                size: crate::ArraySize::Dynamic,
1604                ..
1605            } => (0, global.ty),
1606            ref ty => {
1607                return Err(Error::GenericValidation(format!(
1608                    "Expected type with dynamic array, got {ty:?}"
1609                )))
1610            }
1611        };
1612
1613        let (size, stride) = match context.module.types[array_ty].inner {
1614            crate::TypeInner::Array { base, stride, .. } => (
1615                context.module.types[base]
1616                    .inner
1617                    .size(context.module.to_ctx()),
1618                stride,
1619            ),
1620            ref ty => {
1621                return Err(Error::GenericValidation(format!(
1622                    "Expected array type, got {ty:?}"
1623                )))
1624            }
1625        };
1626
1627        // When the stride length is larger than the size, the final element's stride of
1628        // bytes would have padding following the value. But the buffer size in
1629        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1630        // enough to hold the actual values' bytes.
1631        //
1632        // So subtract off the size to get a byte size that falls at the start or within
1633        // the final element. Then divide by the stride size, to get one less than the
1634        // length, and then add one. This works even if the buffer size does include the
1635        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1636        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1637        // rules, together with draw-time validation when `minBindingSize` is zero,
1638        // prevent that.
1639        write!(
1640            self.out,
1641            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1642            member = ArraySizeMember(handle),
1643            offset = offset,
1644            size = size,
1645            stride = stride,
1646        )?;
1647        Ok(())
1648    }
1649
1650    /// Emit code for the arithmetic expression of the dot product.
1651    ///
1652    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1653    /// an index. It writes out the expression for the vector component at that index.
1654    fn put_dot_product<T: Copy>(
1655        &mut self,
1656        arg: T,
1657        arg1: T,
1658        size: usize,
1659        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1660    ) -> BackendResult {
1661        // Write parentheses around the dot product expression to prevent operators
1662        // with different precedences from applying earlier.
1663        write!(self.out, "(")?;
1664
1665        // Cycle through all the components of the vector
1666        for index in 0..size {
1667            // Write the addition to the previous product
1668            // This will print an extra '+' at the beginning but that is fine in msl
1669            write!(self.out, " + ")?;
1670            extractor(self, arg, index)?;
1671            write!(self.out, " * ")?;
1672            extractor(self, arg1, index)?;
1673        }
1674
1675        write!(self.out, ")")?;
1676        Ok(())
1677    }
1678
1679    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1680    fn put_pack4x8(
1681        &mut self,
1682        arg: Handle<crate::Expression>,
1683        context: &ExpressionContext<'_>,
1684        was_signed: bool,
1685        clamp_bounds: Option<(&str, &str)>,
1686    ) -> Result<(), Error> {
1687        let write_arg = |this: &mut Self| -> BackendResult {
1688            if let Some((min, max)) = clamp_bounds {
1689                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1690                write!(this.out, "{NAMESPACE}::clamp(")?;
1691                this.put_expression(arg, context, true)?;
1692                write!(this.out, ", {min}, {max})")?;
1693            } else {
1694                this.put_expression(arg, context, true)?;
1695            }
1696            Ok(())
1697        };
1698
1699        if context.lang_version >= (2, 1) {
1700            let packed_type = if was_signed {
1701                "packed_char4"
1702            } else {
1703                "packed_uchar4"
1704            };
1705            // Metal uses little endian byte order, which matches what WGSL expects here.
1706            write!(self.out, "as_type<uint>({packed_type}(")?;
1707            write_arg(self)?;
1708            write!(self.out, "))")?;
1709        } else {
1710            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1711            if was_signed {
1712                write!(self.out, "uint(")?;
1713            }
1714            write!(self.out, "(")?;
1715            write_arg(self)?;
1716            write!(self.out, "[0] & 0xFF) | ((")?;
1717            write_arg(self)?;
1718            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1719            write_arg(self)?;
1720            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1721            write_arg(self)?;
1722            write!(self.out, "[3] & 0xFF) << 24)")?;
1723            if was_signed {
1724                write!(self.out, ")")?;
1725            }
1726        }
1727
1728        Ok(())
1729    }
1730
1731    /// Emit code for the isign expression.
1732    ///
1733    fn put_isign(
1734        &mut self,
1735        arg: Handle<crate::Expression>,
1736        context: &ExpressionContext,
1737    ) -> BackendResult {
1738        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1739        let scalar = context
1740            .resolve_type(arg)
1741            .scalar()
1742            .expect("put_isign should only be called for args which have an integer scalar type")
1743            .to_msl_name();
1744        match context.resolve_type(arg) {
1745            &crate::TypeInner::Vector { size, .. } => {
1746                let size = common::vector_size_str(size);
1747                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1748            }
1749            _ => {
1750                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1751            }
1752        }
1753        write!(self.out, ", (")?;
1754        self.put_expression(arg, context, true)?;
1755        write!(self.out, " > 0)), {scalar}(0), (")?;
1756        self.put_expression(arg, context, true)?;
1757        write!(self.out, " == 0))")?;
1758        Ok(())
1759    }
1760
1761    pub(super) fn put_const_expression(
1762        &mut self,
1763        expr_handle: Handle<crate::Expression>,
1764        module: &crate::Module,
1765        mod_info: &valid::ModuleInfo,
1766        arena: &crate::Arena<crate::Expression>,
1767    ) -> BackendResult {
1768        self.put_possibly_const_expression(
1769            expr_handle,
1770            arena,
1771            module,
1772            mod_info,
1773            &(module, mod_info),
1774            |&(_, mod_info), expr| &mod_info[expr],
1775            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1776        )
1777    }
1778
1779    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1780        match literal {
1781            crate::Literal::F64(_) => {
1782                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1783            }
1784            crate::Literal::F16(value) => {
1785                if value.is_infinite() {
1786                    let sign = if value.is_sign_negative() { "-" } else { "" };
1787                    write!(self.out, "{sign}INFINITY")?;
1788                } else if value.is_nan() {
1789                    write!(self.out, "NAN")?;
1790                } else {
1791                    let suffix = if value.fract() == f16::from_f32(0.0) {
1792                        ".0h"
1793                    } else {
1794                        "h"
1795                    };
1796                    write!(self.out, "{value}{suffix}")?;
1797                }
1798            }
1799            crate::Literal::F32(value) => {
1800                if value.is_infinite() {
1801                    let sign = if value.is_sign_negative() { "-" } else { "" };
1802                    write!(self.out, "{sign}INFINITY")?;
1803                } else if value.is_nan() {
1804                    write!(self.out, "NAN")?;
1805                } else {
1806                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1807                    write!(self.out, "{value}{suffix}")?;
1808                }
1809            }
1810            crate::Literal::U16(value) => {
1811                write!(self.out, "static_cast<ushort>({value})")?;
1812            }
1813            crate::Literal::I16(value) => {
1814                write!(self.out, "static_cast<short>({value})")?;
1815            }
1816            crate::Literal::U32(value) => {
1817                write!(self.out, "{value}u")?;
1818            }
1819            crate::Literal::I32(value) => {
1820                // `-2147483648` is parsed as unary negation of positive 2147483648.
1821                // 2147483648 is too large for int32_t meaning the expression gets
1822                // promoted to a int64_t which is not our intention. Avoid this by instead
1823                // using `-2147483647 - 1`.
1824                if value == i32::MIN {
1825                    write!(self.out, "({} - 1)", value + 1)?;
1826                } else {
1827                    write!(self.out, "{value}")?;
1828                }
1829            }
1830            crate::Literal::U64(value) => {
1831                write!(self.out, "{value}uL")?;
1832            }
1833            crate::Literal::I64(value) => {
1834                // `-9223372036854775808` is parsed as unary negation of positive
1835                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1836                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1837                // value to `-9223372036854775808`. Which would then be negated, possibly
1838                // causing undefined behaviour. Avoid this by instead using
1839                // `-9223372036854775808L - 1L`.
1840                if value == i64::MIN {
1841                    write!(self.out, "({}L - 1L)", value + 1)?;
1842                } else {
1843                    write!(self.out, "{value}L")?;
1844                }
1845            }
1846            crate::Literal::Bool(value) => {
1847                write!(self.out, "{value}")?;
1848            }
1849            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1850                return Err(Error::GenericValidation(
1851                    "Unsupported abstract literal".into(),
1852                ));
1853            }
1854        }
1855        Ok(())
1856    }
1857
1858    #[allow(clippy::too_many_arguments)]
1859    fn put_possibly_const_expression<C, I, E>(
1860        &mut self,
1861        expr_handle: Handle<crate::Expression>,
1862        expressions: &crate::Arena<crate::Expression>,
1863        module: &crate::Module,
1864        mod_info: &valid::ModuleInfo,
1865        ctx: &C,
1866        get_expr_ty: I,
1867        put_expression: E,
1868    ) -> BackendResult
1869    where
1870        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1871        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1872    {
1873        match expressions[expr_handle] {
1874            crate::Expression::Literal(literal) => {
1875                self.put_literal(literal)?;
1876            }
1877            crate::Expression::Constant(handle) => {
1878                let constant = &module.constants[handle];
1879                if constant.name.is_some() {
1880                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1881                } else {
1882                    self.put_const_expression(
1883                        constant.init,
1884                        module,
1885                        mod_info,
1886                        &module.global_expressions,
1887                    )?;
1888                }
1889            }
1890            crate::Expression::ZeroValue(ty) => {
1891                let ty_name = TypeContext {
1892                    handle: ty,
1893                    gctx: module.to_ctx(),
1894                    names: &self.names,
1895                    access: crate::StorageAccess::empty(),
1896                    first_time: false,
1897                };
1898                write!(self.out, "{ty_name} {{}}")?;
1899            }
1900            crate::Expression::Compose { ty, ref components } => {
1901                let ty_name = TypeContext {
1902                    handle: ty,
1903                    gctx: module.to_ctx(),
1904                    names: &self.names,
1905                    access: crate::StorageAccess::empty(),
1906                    first_time: false,
1907                };
1908                write!(self.out, "{ty_name}")?;
1909                match module.types[ty].inner {
1910                    crate::TypeInner::Scalar(_)
1911                    | crate::TypeInner::Vector { .. }
1912                    | crate::TypeInner::Matrix { .. } => {
1913                        self.put_call_parameters_impl(
1914                            components.iter().copied(),
1915                            ctx,
1916                            put_expression,
1917                        )?;
1918                    }
1919                    crate::TypeInner::Array { .. } => {
1920                        // Naga Arrays are Metal arrays wrapped in structs, so
1921                        // we need two levels of braces.
1922                        write!(self.out, " {{{{")?;
1923                        for (index, &component) in components.iter().enumerate() {
1924                            if index != 0 {
1925                                write!(self.out, ", ")?;
1926                            }
1927                            put_expression(self, ctx, component)?;
1928                        }
1929                        write!(self.out, "}}}}")?;
1930                    }
1931                    crate::TypeInner::Struct { .. } => {
1932                        write!(self.out, " {{")?;
1933                        for (index, &component) in components.iter().enumerate() {
1934                            if index != 0 {
1935                                write!(self.out, ", ")?;
1936                            }
1937                            // insert padding initialization, if needed
1938                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1939                                write!(self.out, "{{}}, ")?;
1940                            }
1941                            put_expression(self, ctx, component)?;
1942                        }
1943                        write!(self.out, "}}")?;
1944                    }
1945                    _ => return Err(Error::UnsupportedCompose(ty)),
1946                }
1947            }
1948            crate::Expression::Splat { size, value } => {
1949                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1950                    crate::TypeInner::Scalar(scalar) => scalar,
1951                    ref ty => {
1952                        return Err(Error::GenericValidation(format!(
1953                            "Expected splat value type must be a scalar, got {ty:?}",
1954                        )))
1955                    }
1956                };
1957                put_numeric_type(&mut self.out, scalar, &[size])?;
1958                write!(self.out, "(")?;
1959                put_expression(self, ctx, value)?;
1960                write!(self.out, ")")?;
1961            }
1962            _ => {
1963                return Err(Error::Override);
1964            }
1965        }
1966
1967        Ok(())
1968    }
1969
1970    /// Emit code for the expression `expr_handle`.
1971    ///
1972    /// The `is_scoped` argument is true if the surrounding operators have the
1973    /// precedence of the comma operator, or lower. So, for example:
1974    ///
1975    /// - Pass `true` for `is_scoped` when writing function arguments, an
1976    ///   expression statement, an initializer expression, or anything already
1977    ///   wrapped in parenthesis.
1978    ///
1979    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1980    ///   almost anything else.
1981    pub(super) fn put_expression(
1982        &mut self,
1983        expr_handle: Handle<crate::Expression>,
1984        context: &ExpressionContext,
1985        is_scoped: bool,
1986    ) -> BackendResult {
1987        // Add to the set in order to track the stack size.
1988        #[cfg(test)]
1989        self.put_expression_stack_pointers
1990            .insert(ptr::from_ref(&expr_handle).cast());
1991
1992        if let Some(name) = self.named_expressions.get(&expr_handle) {
1993            write!(self.out, "{name}")?;
1994            return Ok(());
1995        }
1996
1997        let expression = &context.function.expressions[expr_handle];
1998        match *expression {
1999            crate::Expression::Literal(_)
2000            | crate::Expression::Constant(_)
2001            | crate::Expression::ZeroValue(_)
2002            | crate::Expression::Compose { .. }
2003            | crate::Expression::Splat { .. } => {
2004                self.put_possibly_const_expression(
2005                    expr_handle,
2006                    &context.function.expressions,
2007                    context.module,
2008                    context.mod_info,
2009                    context,
2010                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2011                    |writer, context, expr| writer.put_expression(expr, context, true),
2012                )?;
2013            }
2014            crate::Expression::Override(_) => return Err(Error::Override),
2015            crate::Expression::Access { base, .. }
2016            | crate::Expression::AccessIndex { base, .. } => {
2017                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
2018                // Since `put_bounds_checks` and `put_access_chain` handle an entire
2019                // access chain at a time, recursing back through `put_expression` only
2020                // for index expressions and the base object, we will never see intermediate
2021                // `Access` or `AccessIndex` expressions here.
2022                let policy = context.choose_bounds_check_policy(base);
2023                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2024                    && self.put_bounds_checks(
2025                        expr_handle,
2026                        context,
2027                        back::Level(0),
2028                        if is_scoped { "" } else { "(" },
2029                    )?
2030                {
2031                    write!(self.out, " ? ")?;
2032                    self.put_access_chain(expr_handle, policy, context)?;
2033                    write!(self.out, " : ")?;
2034
2035                    if context.resolve_type(base).pointer_space().is_some() {
2036                        // We can't just use `DefaultConstructible` if this is a pointer.
2037                        // Instead, we create a dummy local variable to serve as pointer
2038                        // target if the access is out of bounds.
2039                        let result_ty = context.info[expr_handle]
2040                            .ty
2041                            .inner_with(&context.module.types)
2042                            .pointer_base_type();
2043                        let result_ty_handle = match result_ty {
2044                            Some(TypeResolution::Handle(handle)) => handle,
2045                            Some(TypeResolution::Value(_)) => {
2046                                // As long as the result of a pointer access expression is
2047                                // passed to a function or stored in a let binding, the
2048                                // type will be in the arena. If additional uses of
2049                                // pointers become valid, this assumption might no longer
2050                                // hold. Note that the LHS of a load or store doesn't
2051                                // take this path -- there is dedicated code in `put_load`
2052                                // and `put_store`.
2053                                unreachable!(
2054                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2055                                );
2056                            }
2057                            None => {
2058                                unreachable!(
2059                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2060                                )
2061                            }
2062                        };
2063                        let name_key =
2064                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
2065                        self.out.write_str(&self.names[&name_key])?;
2066                    } else {
2067                        write!(self.out, "DefaultConstructible()")?;
2068                    }
2069
2070                    if !is_scoped {
2071                        write!(self.out, ")")?;
2072                    }
2073                } else {
2074                    self.put_access_chain(expr_handle, policy, context)?;
2075                }
2076            }
2077            crate::Expression::Swizzle {
2078                size,
2079                vector,
2080                pattern,
2081            } => {
2082                self.put_wrapped_expression_for_packed_vec3_access(
2083                    vector,
2084                    context,
2085                    false,
2086                    &Self::put_expression,
2087                )?;
2088                write!(self.out, ".")?;
2089                for &sc in pattern[..size as usize].iter() {
2090                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2091                }
2092            }
2093            crate::Expression::FunctionArgument(index) => {
2094                let name_key = match context.origin {
2095                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2096                    FunctionOrigin::EntryPoint(ep_index) => {
2097                        NameKey::EntryPointArgument(ep_index, index)
2098                    }
2099                };
2100                let name = &self.names[&name_key];
2101                write!(self.out, "{name}")?;
2102            }
2103            crate::Expression::GlobalVariable(handle) => {
2104                let name = &self.names[&NameKey::GlobalVariable(handle)];
2105                write!(self.out, "{name}")?;
2106            }
2107            crate::Expression::LocalVariable(handle) => {
2108                let name_key = NameKey::local(context.origin, handle);
2109                let name = &self.names[&name_key];
2110                write!(self.out, "{name}")?;
2111            }
2112            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2113            crate::Expression::ImageSample {
2114                coordinate,
2115                image,
2116                sampler,
2117                clamp_to_edge: true,
2118                gather: None,
2119                array_index: None,
2120                offset: None,
2121                level: crate::SampleLevel::Zero,
2122                depth_ref: None,
2123            } => {
2124                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2125                self.put_expression(image, context, true)?;
2126                write!(self.out, ", ")?;
2127                self.put_expression(sampler, context, true)?;
2128                write!(self.out, ", ")?;
2129                self.put_expression(coordinate, context, true)?;
2130                write!(self.out, ")")?;
2131            }
2132            crate::Expression::ImageSample {
2133                image,
2134                sampler,
2135                gather,
2136                coordinate,
2137                array_index,
2138                offset,
2139                level,
2140                depth_ref,
2141                clamp_to_edge,
2142            } => {
2143                if clamp_to_edge {
2144                    return Err(Error::GenericValidation(
2145                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2146                    ));
2147                }
2148
2149                let main_op = match gather {
2150                    Some(_) => "gather",
2151                    None => "sample",
2152                };
2153                let comparison_op = match depth_ref {
2154                    Some(_) => "_compare",
2155                    None => "",
2156                };
2157                self.put_expression(image, context, false)?;
2158                write!(self.out, ".{main_op}{comparison_op}(")?;
2159                self.put_expression(sampler, context, true)?;
2160                write!(self.out, ", ")?;
2161                self.put_expression(coordinate, context, true)?;
2162                if let Some(expr) = array_index {
2163                    write!(self.out, ", ")?;
2164                    self.put_expression(expr, context, true)?;
2165                }
2166                if let Some(dref) = depth_ref {
2167                    write!(self.out, ", ")?;
2168                    self.put_expression(dref, context, true)?;
2169                }
2170
2171                self.put_image_sample_level(image, level, context)?;
2172
2173                if let Some(offset) = offset {
2174                    write!(self.out, ", ")?;
2175                    self.put_expression(offset, context, true)?;
2176                }
2177
2178                match gather {
2179                    None | Some(crate::SwizzleComponent::X) => {}
2180                    Some(component) => {
2181                        let is_cube_map = match *context.resolve_type(image) {
2182                            crate::TypeInner::Image {
2183                                dim: crate::ImageDimension::Cube,
2184                                ..
2185                            } => true,
2186                            _ => false,
2187                        };
2188                        // Offset always comes before the gather, except
2189                        // in cube maps where it's not applicable
2190                        if offset.is_none() && !is_cube_map {
2191                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2192                        }
2193                        let letter = back::COMPONENTS[component as usize];
2194                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2195                    }
2196                }
2197                write!(self.out, ")")?;
2198            }
2199            crate::Expression::ImageLoad {
2200                image,
2201                coordinate,
2202                array_index,
2203                sample,
2204                level,
2205            } => {
2206                let address = TexelAddress {
2207                    coordinate,
2208                    array_index,
2209                    sample,
2210                    level: level.map(LevelOfDetail::Direct),
2211                };
2212                self.put_image_load(expr_handle, image, address, context)?;
2213            }
2214            //Note: for all the queries, the signed integers are expected,
2215            // so a conversion is needed.
2216            crate::Expression::ImageQuery { image, query } => match query {
2217                crate::ImageQuery::Size { level } => {
2218                    self.put_image_size_query(
2219                        image,
2220                        level.map(LevelOfDetail::Direct),
2221                        crate::ScalarKind::Uint,
2222                        context,
2223                    )?;
2224                }
2225                crate::ImageQuery::NumLevels => {
2226                    self.put_expression(image, context, false)?;
2227                    write!(self.out, ".get_num_mip_levels()")?;
2228                }
2229                crate::ImageQuery::NumLayers => {
2230                    self.put_expression(image, context, false)?;
2231                    write!(self.out, ".get_array_size()")?;
2232                }
2233                crate::ImageQuery::NumSamples => {
2234                    self.put_expression(image, context, false)?;
2235                    write!(self.out, ".get_num_samples()")?;
2236                }
2237            },
2238            crate::Expression::Unary { op, expr } => {
2239                let op_str = match op {
2240                    crate::UnaryOperator::Negate => {
2241                        match context.resolve_type(expr).scalar_kind() {
2242                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2243                            _ => "-",
2244                        }
2245                    }
2246                    crate::UnaryOperator::LogicalNot => "!",
2247                    crate::UnaryOperator::BitwiseNot => "~",
2248                };
2249                write!(self.out, "{op_str}(")?;
2250                self.put_expression(expr, context, false)?;
2251                write!(self.out, ")")?;
2252            }
2253            crate::Expression::Binary { op, left, right } => {
2254                let kind = context
2255                    .resolve_type(left)
2256                    .scalar_kind()
2257                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2258
2259                if op == crate::BinaryOperator::Divide
2260                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2261                    && context.emit_int_div_checks
2262                {
2263                    write!(self.out, "{DIV_FUNCTION}(")?;
2264                    self.put_expression(left, context, true)?;
2265                    write!(self.out, ", ")?;
2266                    self.put_expression(right, context, true)?;
2267                    write!(self.out, ")")?;
2268                } else if op == crate::BinaryOperator::Modulo
2269                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2270                    && context.emit_int_div_checks
2271                {
2272                    write!(self.out, "{MOD_FUNCTION}(")?;
2273                    self.put_expression(left, context, true)?;
2274                    write!(self.out, ", ")?;
2275                    self.put_expression(right, context, true)?;
2276                    write!(self.out, ")")?;
2277                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2278                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2279                    //
2280                    // float:
2281                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2282                    write!(self.out, "{NAMESPACE}::fmod(")?;
2283                    self.put_expression(left, context, true)?;
2284                    write!(self.out, ", ")?;
2285                    self.put_expression(right, context, true)?;
2286                    write!(self.out, ")")?;
2287                } else if (op == crate::BinaryOperator::Add
2288                    || op == crate::BinaryOperator::Subtract
2289                    || op == crate::BinaryOperator::Multiply)
2290                    && kind == crate::ScalarKind::Sint
2291                {
2292                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2293                        crate::TypeInner::Scalar(scalar) => {
2294                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2295                                kind: crate::ScalarKind::Uint,
2296                                ..scalar
2297                            }))
2298                        }
2299                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2300                            size,
2301                            scalar: crate::Scalar {
2302                                kind: crate::ScalarKind::Uint,
2303                                ..scalar
2304                            },
2305                        }),
2306                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2307                    };
2308
2309                    // Avoid undefined behaviour due to overflowing signed
2310                    // integer arithmetic. Cast the operands to unsigned prior
2311                    // to performing the operation, then cast the result back
2312                    // to signed.
2313                    self.put_bitcasted_expression(
2314                        context.resolve_type(expr_handle),
2315                        expr_handle,
2316                        context,
2317                        &|writer, context, is_scoped| {
2318                            writer.put_binop(
2319                                op,
2320                                left,
2321                                right,
2322                                context,
2323                                is_scoped,
2324                                &|writer, expr, context, _is_scoped| {
2325                                    writer.put_bitcasted_expression(
2326                                        &to_unsigned(context.resolve_type(expr))?,
2327                                        expr,
2328                                        context,
2329                                        &|writer, context, is_scoped| {
2330                                            writer.put_expression(expr, context, is_scoped)
2331                                        },
2332                                    )
2333                                },
2334                            )
2335                        },
2336                    )?;
2337                } else {
2338                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2339                }
2340            }
2341            crate::Expression::Select {
2342                condition,
2343                accept,
2344                reject,
2345            } => match *context.resolve_type(condition) {
2346                crate::TypeInner::Scalar(crate::Scalar {
2347                    kind: crate::ScalarKind::Bool,
2348                    ..
2349                }) => {
2350                    if !is_scoped {
2351                        write!(self.out, "(")?;
2352                    }
2353                    self.put_expression(condition, context, false)?;
2354                    write!(self.out, " ? ")?;
2355                    self.put_expression(accept, context, false)?;
2356                    write!(self.out, " : ")?;
2357                    self.put_expression(reject, context, false)?;
2358                    if !is_scoped {
2359                        write!(self.out, ")")?;
2360                    }
2361                }
2362                crate::TypeInner::Vector {
2363                    scalar:
2364                        crate::Scalar {
2365                            kind: crate::ScalarKind::Bool,
2366                            ..
2367                        },
2368                    ..
2369                } => {
2370                    write!(self.out, "{NAMESPACE}::select(")?;
2371                    self.put_expression(reject, context, true)?;
2372                    write!(self.out, ", ")?;
2373                    self.put_expression(accept, context, true)?;
2374                    write!(self.out, ", ")?;
2375                    self.put_expression(condition, context, true)?;
2376                    write!(self.out, ")")?;
2377                }
2378                ref ty => {
2379                    return Err(Error::GenericValidation(format!(
2380                        "Expected select condition to be a non-bool type, got {ty:?}",
2381                    )))
2382                }
2383            },
2384            crate::Expression::Derivative { axis, expr, .. } => {
2385                use crate::DerivativeAxis as Axis;
2386                let op = match axis {
2387                    Axis::X => "dfdx",
2388                    Axis::Y => "dfdy",
2389                    Axis::Width => "fwidth",
2390                };
2391                write!(self.out, "{NAMESPACE}::{op}")?;
2392                self.put_call_parameters(iter::once(expr), context)?;
2393            }
2394            crate::Expression::Relational { fun, argument } => {
2395                let op = match fun {
2396                    crate::RelationalFunction::Any => "any",
2397                    crate::RelationalFunction::All => "all",
2398                    crate::RelationalFunction::IsNan => "isnan",
2399                    crate::RelationalFunction::IsInf => "isinf",
2400                };
2401                write!(self.out, "{NAMESPACE}::{op}")?;
2402                self.put_call_parameters(iter::once(argument), context)?;
2403            }
2404            crate::Expression::Math {
2405                fun,
2406                arg,
2407                arg1,
2408                arg2,
2409                arg3,
2410            } => {
2411                use crate::MathFunction as Mf;
2412
2413                let arg_type = context.resolve_type(arg);
2414                let scalar_argument = match arg_type {
2415                    &crate::TypeInner::Scalar(_) => true,
2416                    _ => false,
2417                };
2418
2419                let fun_name = match fun {
2420                    // comparison
2421                    Mf::Abs => "abs",
2422                    Mf::Min => "min",
2423                    Mf::Max => "max",
2424                    Mf::Clamp => "clamp",
2425                    Mf::Saturate => "saturate",
2426                    // trigonometry
2427                    Mf::Cos => "cos",
2428                    Mf::Cosh => "cosh",
2429                    Mf::Sin => "sin",
2430                    Mf::Sinh => "sinh",
2431                    Mf::Tan => "tan",
2432                    Mf::Tanh => "tanh",
2433                    Mf::Acos => "acos",
2434                    Mf::Asin => "asin",
2435                    Mf::Atan => "atan",
2436                    Mf::Atan2 => "atan2",
2437                    Mf::Asinh => "asinh",
2438                    Mf::Acosh => "acosh",
2439                    Mf::Atanh => "atanh",
2440                    Mf::Radians => "",
2441                    Mf::Degrees => "",
2442                    // decomposition
2443                    Mf::Ceil => "ceil",
2444                    Mf::Floor => "floor",
2445                    Mf::Round => "rint",
2446                    Mf::Fract => "fract",
2447                    Mf::Trunc => "trunc",
2448                    Mf::Modf => MODF_FUNCTION,
2449                    Mf::Frexp => FREXP_FUNCTION,
2450                    Mf::Ldexp => "ldexp",
2451                    // exponent
2452                    Mf::Exp => "exp",
2453                    Mf::Exp2 => "exp2",
2454                    Mf::Log => "log",
2455                    Mf::Log2 => "log2",
2456                    Mf::Pow => "pow",
2457                    // geometry
2458                    Mf::Dot => match *context.resolve_type(arg) {
2459                        crate::TypeInner::Vector {
2460                            scalar:
2461                                crate::Scalar {
2462                                    // Resolve float values to MSL's builtin dot function.
2463                                    kind: crate::ScalarKind::Float,
2464                                    ..
2465                                },
2466                            ..
2467                        } => "dot",
2468                        crate::TypeInner::Vector {
2469                            size,
2470                            scalar:
2471                                scalar @ crate::Scalar {
2472                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2473                                    ..
2474                                },
2475                        } => {
2476                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2477                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2478                            write!(self.out, "{fun_name}(")?;
2479                            self.put_expression(arg, context, true)?;
2480                            write!(self.out, ", ")?;
2481                            self.put_expression(arg1.unwrap(), context, true)?;
2482                            write!(self.out, ")")?;
2483                            return Ok(());
2484                        }
2485                        _ => unreachable!(
2486                            "Correct TypeInner for dot product should be already validated"
2487                        ),
2488                    },
2489                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2490                        if context.lang_version >= (2, 1) {
2491                            // Write potentially optimizable code using `packed_(u?)char4`.
2492                            // The two function arguments were already reinterpreted as packed (signed
2493                            // or unsigned) chars in `Self::put_block`.
2494                            let packed_type = match fun {
2495                                Mf::Dot4I8Packed => "packed_char4",
2496                                Mf::Dot4U8Packed => "packed_uchar4",
2497                                _ => unreachable!(),
2498                            };
2499
2500                            return self.put_dot_product(
2501                                Reinterpreted::new(packed_type, arg),
2502                                Reinterpreted::new(packed_type, arg1.unwrap()),
2503                                4,
2504                                |writer, arg, index| {
2505                                    // MSL implicitly promotes these (signed or unsigned) chars to
2506                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2507                                    write!(writer.out, "{arg}[{index}]")?;
2508                                    Ok(())
2509                                },
2510                            );
2511                        } else {
2512                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2513                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2514                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2515                            let conversion = match fun {
2516                                Mf::Dot4I8Packed => "int",
2517                                Mf::Dot4U8Packed => "",
2518                                _ => unreachable!(),
2519                            };
2520
2521                            return self.put_dot_product(
2522                                arg,
2523                                arg1.unwrap(),
2524                                4,
2525                                |writer, arg, index| {
2526                                    write!(writer.out, "({conversion}(")?;
2527                                    writer.put_expression(arg, context, true)?;
2528                                    if index == 3 {
2529                                        write!(writer.out, ") >> 24)")?;
2530                                    } else {
2531                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2532                                    }
2533                                    Ok(())
2534                                },
2535                            );
2536                        }
2537                    }
2538                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2539                    Mf::Cross => "cross",
2540                    Mf::Distance => "distance",
2541                    Mf::Length if scalar_argument => "abs",
2542                    Mf::Length => "length",
2543                    Mf::Normalize => "normalize",
2544                    Mf::FaceForward => "faceforward",
2545                    Mf::Reflect => "reflect",
2546                    Mf::Refract => "refract",
2547                    // computational
2548                    Mf::Sign => match arg_type.scalar_kind() {
2549                        Some(crate::ScalarKind::Sint) => {
2550                            return self.put_isign(arg, context);
2551                        }
2552                        _ => "sign",
2553                    },
2554                    Mf::Fma => "fma",
2555                    Mf::Mix => "mix",
2556                    Mf::Step => "step",
2557                    Mf::SmoothStep => "smoothstep",
2558                    Mf::Sqrt => "sqrt",
2559                    Mf::InverseSqrt => "rsqrt",
2560                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2561                    Mf::Transpose => "transpose",
2562                    Mf::Determinant => "determinant",
2563                    Mf::QuantizeToF16 => "",
2564                    // bits
2565                    Mf::CountTrailingZeros => "ctz",
2566                    Mf::CountLeadingZeros => "clz",
2567                    Mf::CountOneBits => "popcount",
2568                    Mf::ReverseBits => "reverse_bits",
2569                    Mf::ExtractBits => "",
2570                    Mf::InsertBits => "",
2571                    Mf::FirstTrailingBit => "",
2572                    Mf::FirstLeadingBit => "",
2573                    // data packing
2574                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2575                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2576                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2577                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2578                    Mf::Pack2x16float => "",
2579                    Mf::Pack4xI8 => "",
2580                    Mf::Pack4xU8 => "",
2581                    Mf::Pack4xI8Clamp => "",
2582                    Mf::Pack4xU8Clamp => "",
2583                    // data unpacking
2584                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2585                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2586                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2587                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2588                    Mf::Unpack2x16float => "",
2589                    Mf::Unpack4xI8 => "",
2590                    Mf::Unpack4xU8 => "",
2591                };
2592
2593                match fun {
2594                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2595                        // reverse_bits is listed as requiring MSL 2.1 but that
2596                        // is a copy/paste error. Looking at previous snapshots
2597                        // on web.archive.org it's present in MSL 1.2.
2598                        //
2599                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2600                        // also talks about MSL 1.2 adding "New integer
2601                        // functions to extract, insert, and reverse bits, as
2602                        // described in Integer Functions."
2603                        if context.lang_version < (1, 2) {
2604                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2605                        }
2606                    }
2607                    _ => {}
2608                }
2609
2610                match fun {
2611                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2612                        write!(self.out, "{ABS_FUNCTION}(")?;
2613                        self.put_expression(arg, context, true)?;
2614                        write!(self.out, ")")?;
2615                    }
2616                    Mf::Distance if scalar_argument => {
2617                        write!(self.out, "{NAMESPACE}::abs(")?;
2618                        self.put_expression(arg, context, false)?;
2619                        write!(self.out, " - ")?;
2620                        self.put_expression(arg1.unwrap(), context, false)?;
2621                        write!(self.out, ")")?;
2622                    }
2623                    Mf::FirstTrailingBit => {
2624                        let scalar = context.resolve_type(arg).scalar().unwrap();
2625                        let constant = scalar.width * 8 + 1;
2626
2627                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2628                        self.put_expression(arg, context, true)?;
2629                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2630                    }
2631                    Mf::FirstLeadingBit => {
2632                        let inner = context.resolve_type(arg);
2633                        let scalar = inner.scalar().unwrap();
2634                        let constant = scalar.width * 8 - 1;
2635
2636                        write!(
2637                            self.out,
2638                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2639                        )?;
2640
2641                        if scalar.kind == crate::ScalarKind::Sint {
2642                            write!(self.out, "{NAMESPACE}::select(")?;
2643                            self.put_expression(arg, context, true)?;
2644                            write!(self.out, ", ~")?;
2645                            self.put_expression(arg, context, true)?;
2646                            write!(self.out, ", ")?;
2647                            self.put_expression(arg, context, true)?;
2648                            write!(self.out, " < 0)")?;
2649                        } else {
2650                            self.put_expression(arg, context, true)?;
2651                        }
2652
2653                        write!(self.out, "), ")?;
2654
2655                        // or metal will complain that select is ambiguous
2656                        match *inner {
2657                            crate::TypeInner::Vector { size, scalar } => {
2658                                let size = common::vector_size_str(size);
2659                                let name = scalar.to_msl_name();
2660                                write!(self.out, "{name}{size}")?;
2661                            }
2662                            crate::TypeInner::Scalar(scalar) => {
2663                                let name = scalar.to_msl_name();
2664                                write!(self.out, "{name}")?;
2665                            }
2666                            _ => (),
2667                        }
2668
2669                        write!(self.out, "(-1), ")?;
2670                        self.put_expression(arg, context, true)?;
2671                        write!(self.out, " == 0")?;
2672                        if scalar.kind == crate::ScalarKind::Sint {
2673                            write!(self.out, " || ")?;
2674                            self.put_expression(arg, context, true)?;
2675                            write!(self.out, " == -1")?;
2676                        }
2677                        write!(self.out, ")")?;
2678                    }
2679                    Mf::Unpack2x16float => {
2680                        write!(self.out, "float2(as_type<half2>(")?;
2681                        self.put_expression(arg, context, false)?;
2682                        write!(self.out, "))")?;
2683                    }
2684                    Mf::Pack2x16float => {
2685                        write!(self.out, "as_type<uint>(half2(")?;
2686                        self.put_expression(arg, context, false)?;
2687                        write!(self.out, "))")?;
2688                    }
2689                    Mf::ExtractBits => {
2690                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2691                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2692                        // will return out-of-spec values if the extracted range is not within the bit width.
2693                        //
2694                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2695                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2696                        //
2697                        // w = sizeof(x) * 8
2698                        // o = min(offset, w)
2699                        // tmp = w - o
2700                        // c = min(count, tmp)
2701                        //
2702                        // bitfieldExtract(x, o, c)
2703                        //
2704                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2705
2706                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2707
2708                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2709                        self.put_expression(arg, context, true)?;
2710                        write!(self.out, ", {NAMESPACE}::min(")?;
2711                        self.put_expression(arg1.unwrap(), context, true)?;
2712                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2713                        self.put_expression(arg2.unwrap(), context, true)?;
2714                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2715                        self.put_expression(arg1.unwrap(), context, true)?;
2716                        write!(self.out, ", {scalar_bits}u)))")?;
2717                    }
2718                    Mf::InsertBits => {
2719                        // The behavior of InsertBits has the same issue as ExtractBits.
2720                        //
2721                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2722
2723                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2724
2725                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2726                        self.put_expression(arg, context, true)?;
2727                        write!(self.out, ", ")?;
2728                        self.put_expression(arg1.unwrap(), context, true)?;
2729                        write!(self.out, ", {NAMESPACE}::min(")?;
2730                        self.put_expression(arg2.unwrap(), context, true)?;
2731                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2732                        self.put_expression(arg3.unwrap(), context, true)?;
2733                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2734                        self.put_expression(arg2.unwrap(), context, true)?;
2735                        write!(self.out, ", {scalar_bits}u)))")?;
2736                    }
2737                    Mf::Radians => {
2738                        write!(self.out, "((")?;
2739                        self.put_expression(arg, context, false)?;
2740                        write!(self.out, ") * 0.017453292519943295474)")?;
2741                    }
2742                    Mf::Degrees => {
2743                        write!(self.out, "((")?;
2744                        self.put_expression(arg, context, false)?;
2745                        write!(self.out, ") * 57.295779513082322865)")?;
2746                    }
2747                    Mf::Modf | Mf::Frexp => {
2748                        write!(self.out, "{fun_name}")?;
2749                        self.put_call_parameters(iter::once(arg), context)?;
2750                    }
2751                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2752                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2753                    Mf::Pack4xI8Clamp => {
2754                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2755                    }
2756                    Mf::Pack4xU8Clamp => {
2757                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2758                    }
2759                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2760                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2761                            "u"
2762                        } else {
2763                            ""
2764                        };
2765
2766                        if context.lang_version >= (2, 1) {
2767                            // Metal uses little endian byte order, which matches what WGSL expects here.
2768                            write!(
2769                                self.out,
2770                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2771                            )?;
2772                            self.put_expression(arg, context, true)?;
2773                            write!(self.out, "))")?;
2774                        } else {
2775                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2776                            write!(self.out, "({sign_prefix}int4(")?;
2777                            self.put_expression(arg, context, true)?;
2778                            write!(self.out, ", ")?;
2779                            self.put_expression(arg, context, true)?;
2780                            write!(self.out, " >> 8, ")?;
2781                            self.put_expression(arg, context, true)?;
2782                            write!(self.out, " >> 16, ")?;
2783                            self.put_expression(arg, context, true)?;
2784                            write!(self.out, " >> 24) << 24 >> 24)")?;
2785                        }
2786                    }
2787                    Mf::QuantizeToF16 => {
2788                        match *context.resolve_type(arg) {
2789                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2790                            crate::TypeInner::Vector { size, .. } => write!(
2791                                self.out,
2792                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2793                                size = common::vector_size_str(size),
2794                            )?,
2795                            _ => unreachable!(
2796                                "Correct TypeInner for QuantizeToF16 should be already validated"
2797                            ),
2798                        };
2799
2800                        self.put_expression(arg, context, true)?;
2801                        write!(self.out, "))")?;
2802                    }
2803                    _ => {
2804                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2805                        self.put_call_parameters(
2806                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2807                            context,
2808                        )?;
2809                    }
2810                }
2811            }
2812            crate::Expression::As {
2813                expr,
2814                kind,
2815                convert,
2816            } => match *context.resolve_type(expr) {
2817                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2818                    if src.kind == crate::ScalarKind::Float
2819                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2820                        && convert.is_some()
2821                    {
2822                        // Use helper functions for float to int casts in order to avoid
2823                        // undefined behaviour when value is out of range for the target
2824                        // type.
2825                        let fun_name = match (kind, convert) {
2826                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2827                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2828                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2829                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2830                            _ => unreachable!(),
2831                        };
2832                        write!(self.out, "{fun_name}(")?;
2833                        self.put_expression(expr, context, true)?;
2834                        write!(self.out, ")")?;
2835                    } else {
2836                        let target_scalar = crate::Scalar {
2837                            kind,
2838                            width: convert.unwrap_or(src.width),
2839                        };
2840                        let op = match convert {
2841                            Some(_) => "static_cast",
2842                            None => "as_type",
2843                        };
2844                        write!(self.out, "{op}<")?;
2845                        match *context.resolve_type(expr) {
2846                            crate::TypeInner::Vector { size, .. } => {
2847                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2848                            }
2849                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2850                        };
2851                        write!(self.out, ">(")?;
2852                        self.put_expression(expr, context, true)?;
2853                        write!(self.out, ")")?;
2854                    }
2855                }
2856                crate::TypeInner::Matrix {
2857                    columns,
2858                    rows,
2859                    scalar,
2860                } => {
2861                    let target_scalar = crate::Scalar {
2862                        kind,
2863                        width: convert.unwrap_or(scalar.width),
2864                    };
2865                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2866                    write!(self.out, "(")?;
2867                    self.put_expression(expr, context, true)?;
2868                    write!(self.out, ")")?;
2869                }
2870                ref ty => {
2871                    return Err(Error::GenericValidation(format!(
2872                        "Unsupported type for As: {ty:?}"
2873                    )))
2874                }
2875            },
2876            // has to be a named expression
2877            crate::Expression::CallResult(_)
2878            | crate::Expression::AtomicResult { .. }
2879            | crate::Expression::WorkGroupUniformLoadResult { .. }
2880            | crate::Expression::SubgroupBallotResult
2881            | crate::Expression::SubgroupOperationResult { .. }
2882            | crate::Expression::RayQueryProceedResult => {
2883                unreachable!()
2884            }
2885            crate::Expression::ArrayLength(expr) => {
2886                // Find the global to which the array belongs.
2887                let global = match context.function.expressions[expr] {
2888                    crate::Expression::AccessIndex { base, .. } => {
2889                        match context.function.expressions[base] {
2890                            crate::Expression::GlobalVariable(handle) => handle,
2891                            ref ex => {
2892                                return Err(Error::GenericValidation(format!(
2893                                    "Expected global variable in AccessIndex, got {ex:?}"
2894                                )))
2895                            }
2896                        }
2897                    }
2898                    crate::Expression::GlobalVariable(handle) => handle,
2899                    ref ex => {
2900                        return Err(Error::GenericValidation(format!(
2901                            "Unexpected expression in ArrayLength, got {ex:?}"
2902                        )))
2903                    }
2904                };
2905
2906                if !is_scoped {
2907                    write!(self.out, "(")?;
2908                }
2909                write!(self.out, "1 + ")?;
2910                self.put_dynamic_array_max_index(global, context)?;
2911                if !is_scoped {
2912                    write!(self.out, ")")?;
2913                }
2914            }
2915            crate::Expression::RayQueryVertexPositions { .. } => {
2916                unimplemented!()
2917            }
2918            crate::Expression::RayQueryGetIntersection { query, committed } => {
2919                if context.lang_version < (2, 4) {
2920                    return Err(Error::UnsupportedRayTracing);
2921                }
2922
2923                write!(
2924                    self.out,
2925                    "{}_{committed}(",
2926                    super::ray::INTERSECTION_FUNCTION_NAME
2927                )?;
2928                self.put_expression(query, context, true)?;
2929                write!(self.out, ")")?;
2930            }
2931            crate::Expression::CooperativeLoad { ref data, .. } => {
2932                if context.lang_version < (2, 3) {
2933                    return Err(Error::UnsupportedCooperativeMatrix);
2934                }
2935                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2936                write!(self.out, "&")?;
2937                self.put_access_chain(data.pointer, context.policies.index, context)?;
2938                write!(self.out, ", ")?;
2939                self.put_expression(data.stride, context, true)?;
2940                // Metal's `simdgroup_load` treats its `transpose` flag as
2941                // "memory is transposed from the simdgroup_matrix's canonical
2942                // layout". On Apple GPUs that canonical layout is row-major,
2943                // so `transpose=false` loads from row-major memory. WGSL's
2944                // `coopLoadT` (row_major=true) = row-major memory, so it must
2945                // map to `transpose=false`. Hence the negation.
2946                write!(self.out, ", {})", !data.row_major)?;
2947            }
2948            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2949                if context.lang_version < (2, 3) {
2950                    return Err(Error::UnsupportedCooperativeMatrix);
2951                }
2952                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2953                self.put_expression(a, context, true)?;
2954                write!(self.out, ", ")?;
2955                self.put_expression(b, context, true)?;
2956                write!(self.out, ", ")?;
2957                self.put_expression(c, context, true)?;
2958                write!(self.out, ")")?;
2959            }
2960        }
2961        Ok(())
2962    }
2963
2964    /// Emits code for a binary operation, using the provided callback to emit
2965    /// the left and right operands.
2966    fn put_binop<F>(
2967        &mut self,
2968        op: crate::BinaryOperator,
2969        left: Handle<crate::Expression>,
2970        right: Handle<crate::Expression>,
2971        context: &ExpressionContext,
2972        is_scoped: bool,
2973        put_expression: &F,
2974    ) -> BackendResult
2975    where
2976        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2977    {
2978        let op_str = back::binary_operation_str(op);
2979
2980        if !is_scoped {
2981            write!(self.out, "(")?;
2982        }
2983
2984        // Cast packed vector if necessary
2985        // Packed vector - matrix multiplications are not supported in MSL
2986        if op == crate::BinaryOperator::Multiply
2987            && matches!(
2988                context.resolve_type(right),
2989                &crate::TypeInner::Matrix { .. }
2990            )
2991        {
2992            self.put_wrapped_expression_for_packed_vec3_access(
2993                left,
2994                context,
2995                false,
2996                put_expression,
2997            )?;
2998        } else {
2999            put_expression(self, left, context, false)?;
3000        }
3001
3002        write!(self.out, " {op_str} ")?;
3003
3004        // See comment above
3005        if op == crate::BinaryOperator::Multiply
3006            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3007        {
3008            self.put_wrapped_expression_for_packed_vec3_access(
3009                right,
3010                context,
3011                false,
3012                put_expression,
3013            )?;
3014        } else {
3015            put_expression(self, right, context, false)?;
3016        }
3017
3018        if !is_scoped {
3019            write!(self.out, ")")?;
3020        }
3021
3022        Ok(())
3023    }
3024
3025    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
3026    fn put_wrapped_expression_for_packed_vec3_access<F>(
3027        &mut self,
3028        expr_handle: Handle<crate::Expression>,
3029        context: &ExpressionContext,
3030        is_scoped: bool,
3031        put_expression: &F,
3032    ) -> BackendResult
3033    where
3034        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3035    {
3036        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3037            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3038            put_expression(self, expr_handle, context, is_scoped)?;
3039            write!(self.out, ")")?;
3040        } else {
3041            put_expression(self, expr_handle, context, is_scoped)?;
3042        }
3043        Ok(())
3044    }
3045
3046    /// Emits code for an expression using the provided callback, wrapping the
3047    /// result in a bitcast to the type `cast_to`.
3048    fn put_bitcasted_expression<F>(
3049        &mut self,
3050        cast_to: &crate::TypeInner,
3051        inner_expr: Handle<crate::Expression>,
3052        context: &ExpressionContext,
3053        put_expression: &F,
3054    ) -> BackendResult
3055    where
3056        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3057    {
3058        // For sub-32-bit types, C++ integer promotion can widen the inner
3059        // expression (e.g. `ushort + ushort` promotes to `int`), making a
3060        // direct `as_type<short>(int_expr)` invalid due to size mismatch.
3061        // We wrap with `static_cast` to truncate back before the bitcast.
3062        let needs_truncation = match *cast_to {
3063            crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3064            crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3065            _ => false,
3066        };
3067
3068        write!(self.out, "as_type<")?;
3069        match *cast_to {
3070            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3071            crate::TypeInner::Vector { size, scalar } => {
3072                put_numeric_type(&mut self.out, scalar, &[size])?
3073            }
3074            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3075        };
3076        write!(self.out, ">(")?;
3077
3078        if needs_truncation {
3079            write!(self.out, "static_cast<")?;
3080            // Cast to the unsigned version of the target type to truncate
3081            let unsigned_scalar = match *cast_to {
3082                crate::TypeInner::Scalar(scalar) => crate::Scalar {
3083                    kind: crate::ScalarKind::Uint,
3084                    ..scalar
3085                },
3086                crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3087                    kind: crate::ScalarKind::Uint,
3088                    ..scalar
3089                },
3090                _ => unreachable!(),
3091            };
3092            match *cast_to {
3093                crate::TypeInner::Scalar(_) => {
3094                    put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3095                }
3096                crate::TypeInner::Vector { size, .. } => {
3097                    put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3098                }
3099                _ => unreachable!(),
3100            };
3101            write!(self.out, ">(")?;
3102        }
3103
3104        // if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
3105        if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3106            put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3107            write!(self.out, "(")?;
3108            put_expression(self, context, true)?;
3109            write!(self.out, ")")?;
3110        } else {
3111            put_expression(self, context, true)?;
3112        }
3113
3114        if needs_truncation {
3115            write!(self.out, ")")?;
3116        }
3117
3118        write!(self.out, ")")?;
3119        Ok(())
3120    }
3121
3122    /// Write a `GuardedIndex` as a Metal expression.
3123    fn put_index(
3124        &mut self,
3125        index: index::GuardedIndex,
3126        context: &ExpressionContext,
3127        is_scoped: bool,
3128    ) -> BackendResult {
3129        match index {
3130            index::GuardedIndex::Expression(expr) => {
3131                self.put_expression(expr, context, is_scoped)?
3132            }
3133            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3134        }
3135        Ok(())
3136    }
3137
3138    /// Emit an index bounds check condition for `chain`, if required.
3139    ///
3140    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3141    /// operating either on a pointer to a value, or on a value directly. If we cannot
3142    /// statically determine that all indexing operations in `chain` are within
3143    /// bounds, then write a conditional expression to check them dynamically,
3144    /// and return true. All accesses in the chain are checked by the generated
3145    /// expression.
3146    ///
3147    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3148    ///
3149    /// The text written is of the form:
3150    ///
3151    /// ```ignore
3152    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3153    /// ```
3154    ///
3155    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3156    /// statements, presumably these arguments start an indented `if` statement; for
3157    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3158    /// expression. In either case, what is written is not a complete syntactic structure
3159    /// in its own right, and the caller will have to finish it off if we return `true`.
3160    ///
3161    /// If no expression is written, return false.
3162    ///
3163    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3164    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3165    /// [`Store`]: crate::Statement::Store
3166    /// [`Load`]: crate::Expression::Load
3167    fn put_bounds_checks(
3168        &mut self,
3169        chain: Handle<crate::Expression>,
3170        context: &ExpressionContext,
3171        level: back::Level,
3172        prefix: &'static str,
3173    ) -> Result<bool, Error> {
3174        let mut check_written = false;
3175
3176        // Iterate over the access chain, handling each required bounds check.
3177        for item in context.bounds_check_iter(chain) {
3178            let BoundsCheck {
3179                base,
3180                index,
3181                length,
3182            } = item;
3183
3184            if check_written {
3185                write!(self.out, " && ")?;
3186            } else {
3187                write!(self.out, "{level}{prefix}")?;
3188                check_written = true;
3189            }
3190
3191            // Check that the index falls within bounds. Do this with a single
3192            // comparison, by casting the index to `uint` first, so that negative
3193            // indices become large positive values.
3194            write!(self.out, "uint(")?;
3195            self.put_index(index, context, true)?;
3196            self.out.write_str(") < ")?;
3197            match length {
3198                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3199                index::IndexableLength::Dynamic => {
3200                    let global = context.function.originating_global(base).ok_or_else(|| {
3201                        Error::GenericValidation("Could not find originating global".into())
3202                    })?;
3203                    write!(self.out, "1 + ")?;
3204                    self.put_dynamic_array_max_index(global, context)?
3205                }
3206            }
3207        }
3208
3209        Ok(check_written)
3210    }
3211
3212    /// Write the access chain `chain`.
3213    ///
3214    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3215    /// operating either on a pointer to a value, or on a value directly.
3216    ///
3217    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3218    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3219    /// that must be handled in the caller.
3220    ///
3221    /// Handle the entire chain, recursing back into `put_expression` only for index
3222    /// expressions and the base expression that originates the pointer or composite value
3223    /// being accessed. This allows `put_expression` to assume that any `Access` or
3224    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3225    /// `ReadZeroSkipWrite` checks.
3226    ///
3227    /// [`Access`]: crate::Expression::Access
3228    /// [`AccessIndex`]: crate::Expression::AccessIndex
3229    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3230    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3231    fn put_access_chain(
3232        &mut self,
3233        chain: Handle<crate::Expression>,
3234        policy: index::BoundsCheckPolicy,
3235        context: &ExpressionContext,
3236    ) -> BackendResult {
3237        match context.function.expressions[chain] {
3238            crate::Expression::Access { base, index } => {
3239                let mut base_ty = context.resolve_type(base);
3240
3241                // Look through any pointers to see what we're really indexing.
3242                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3243                    base_ty = &context.module.types[base].inner;
3244                }
3245
3246                self.put_subscripted_access_chain(
3247                    base,
3248                    base_ty,
3249                    index::GuardedIndex::Expression(index),
3250                    policy,
3251                    context,
3252                )?;
3253            }
3254            crate::Expression::AccessIndex { base, index } => {
3255                let base_resolution = &context.info[base].ty;
3256                let mut base_ty = base_resolution.inner_with(&context.module.types);
3257                let mut base_ty_handle = base_resolution.handle();
3258
3259                // Look through any pointers to see what we're really indexing.
3260                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3261                    base_ty = &context.module.types[base].inner;
3262                    base_ty_handle = Some(base);
3263                }
3264
3265                // Handle structs and anything else that can use `.x` syntax here, so
3266                // `put_subscripted_access_chain` won't have to handle the absurd case of
3267                // indexing a struct with an expression.
3268                match *base_ty {
3269                    crate::TypeInner::Struct { .. } => {
3270                        let base_ty = base_ty_handle.unwrap();
3271                        self.put_access_chain(base, policy, context)?;
3272                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3273                        write!(self.out, ".{name}")?;
3274                    }
3275                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3276                        self.put_access_chain(base, policy, context)?;
3277                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3278                        // however array indexing is
3279                        if context.get_packed_vec_kind(base).is_some() {
3280                            write!(self.out, "[{index}]")?;
3281                        } else {
3282                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3283                        }
3284                    }
3285                    _ => {
3286                        self.put_subscripted_access_chain(
3287                            base,
3288                            base_ty,
3289                            index::GuardedIndex::Known(index),
3290                            policy,
3291                            context,
3292                        )?;
3293                    }
3294                }
3295            }
3296            _ => self.put_expression(chain, context, false)?,
3297        }
3298
3299        Ok(())
3300    }
3301
3302    /// Write a `[]`-style access of `base` by `index`.
3303    ///
3304    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3305    /// values within bounds.
3306    ///
3307    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3308    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3309    /// removed. Our callers often already have this handy.
3310    ///
3311    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3312    /// referencing vector components by name.
3313    ///
3314    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3315    /// [`Array`]: crate::TypeInner::Array
3316    /// [`Vector`]: crate::TypeInner::Vector
3317    /// [`Pointer`]: crate::TypeInner::Pointer
3318    fn put_subscripted_access_chain(
3319        &mut self,
3320        base: Handle<crate::Expression>,
3321        base_ty: &crate::TypeInner,
3322        index: index::GuardedIndex,
3323        policy: index::BoundsCheckPolicy,
3324        context: &ExpressionContext,
3325    ) -> BackendResult {
3326        let accessing_wrapped_array = match *base_ty {
3327            crate::TypeInner::Array {
3328                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3329                ..
3330            } => true,
3331            _ => false,
3332        };
3333        let accessing_wrapped_binding_array =
3334            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3335
3336        self.put_access_chain(base, policy, context)?;
3337        if accessing_wrapped_array {
3338            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3339        }
3340        write!(self.out, "[")?;
3341
3342        // Decide whether this index needs to be clamped to fall within range.
3343        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3344            context.access_needs_check(base, index)
3345        } else {
3346            None
3347        };
3348        if let Some(limit) = restriction_needed {
3349            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3350            self.put_index(index, context, true)?;
3351            write!(self.out, "), ")?;
3352            match limit {
3353                index::IndexableLength::Known(limit) => {
3354                    write!(self.out, "{}u", limit - 1)?;
3355                }
3356                index::IndexableLength::Dynamic => {
3357                    let global = context.function.originating_global(base).ok_or_else(|| {
3358                        Error::GenericValidation("Could not find originating global".into())
3359                    })?;
3360                    self.put_dynamic_array_max_index(global, context)?;
3361                }
3362            }
3363            write!(self.out, ")")?;
3364        } else {
3365            self.put_index(index, context, true)?;
3366        }
3367
3368        write!(self.out, "]")?;
3369
3370        if accessing_wrapped_binding_array {
3371            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3372        }
3373
3374        Ok(())
3375    }
3376
3377    fn put_load(
3378        &mut self,
3379        pointer: Handle<crate::Expression>,
3380        context: &ExpressionContext,
3381        is_scoped: bool,
3382    ) -> BackendResult {
3383        // Since access chains never cross between address spaces, we can just
3384        // check the index bounds check policy once at the top.
3385        let policy = context.choose_bounds_check_policy(pointer);
3386        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3387            && self.put_bounds_checks(
3388                pointer,
3389                context,
3390                back::Level(0),
3391                if is_scoped { "" } else { "(" },
3392            )?
3393        {
3394            write!(self.out, " ? ")?;
3395            self.put_unchecked_load(pointer, policy, context)?;
3396            write!(self.out, " : DefaultConstructible()")?;
3397
3398            if !is_scoped {
3399                write!(self.out, ")")?;
3400            }
3401        } else {
3402            self.put_unchecked_load(pointer, policy, context)?;
3403        }
3404
3405        Ok(())
3406    }
3407
3408    fn put_unchecked_load(
3409        &mut self,
3410        pointer: Handle<crate::Expression>,
3411        policy: index::BoundsCheckPolicy,
3412        context: &ExpressionContext,
3413    ) -> BackendResult {
3414        let is_atomic_pointer = context
3415            .resolve_type(pointer)
3416            .is_atomic_pointer(&context.module.types);
3417
3418        if is_atomic_pointer {
3419            write!(
3420                self.out,
3421                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3422            )?;
3423            self.put_access_chain(pointer, policy, context)?;
3424            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3425        } else {
3426            // We don't do any dereferencing with `*` here as pointer arguments to functions
3427            // are done by `&` references and not `*` pointers. These do not need to be
3428            // dereferenced.
3429            self.put_access_chain(pointer, policy, context)?;
3430        }
3431
3432        Ok(())
3433    }
3434
    /// Writes a `return` statement for `expr_handle` at indentation `level`.
    ///
    /// When `result_struct` is `None`, this simply emits `return <expr>;`.
    ///
    /// When `result_struct` is `Some(name)`, the value is returned as a brace
    /// initializer of the struct `name`:
    /// - If the function's result type is a struct, the expression is first
    ///   stored in a local `_tmp`, and each member is then copied out of it.
    ///   - Members of constant-size array type are expanded element by element
    ///     from their `WRAPPED_ARRAY_FIELD`.
    ///   - A `PointSize` built-in member is omitted unless
    ///     `allow_and_force_point_size` is set.
    /// - Otherwise the expression itself is written as the initializer's
    ///   element.
    ///
    /// For vertex entry points with `allow_and_force_point_size` set, a
    /// trailing `1.0` is appended when the result has no `PointSize` member.
    fn put_return_value(
        &mut self,
        level: back::Level,
        expr_handle: Handle<crate::Expression>,
        result_struct: Option<&str>,
        context: &ExpressionContext,
    ) -> BackendResult {
        match result_struct {
            Some(struct_name) => {
                // Set when the result struct declares a `PointSize` member, so
                // no extra point size is injected below.
                let mut has_point_size = false;
                // Panics if the function has no result; presumably callers only
                // pass `result_struct` for functions that return a value.
                let result_ty = context.function.result.as_ref().unwrap().ty;
                match context.module.types[result_ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        // Bind the value to a temporary so each member can be read
                        // without re-emitting the expression.
                        let tmp = "_tmp";
                        write!(self.out, "{level}const auto {tmp} = ")?;
                        self.put_expression(expr_handle, context, true)?;
                        writeln!(self.out, ";")?;
                        write!(self.out, "{level}return {struct_name} {{")?;

                        let mut is_first = true;

                        for (index, member) in members.iter().enumerate() {
                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
                                member.binding
                            {
                                has_point_size = true;
                                // Drop the point size unless the pipeline asks for it.
                                if !context.pipeline_options.allow_and_force_point_size {
                                    continue;
                                }
                            }

                            let comma = if is_first { "" } else { "," };
                            is_first = false;
                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
                            // HACK: constant-size array members are expanded element by
                            // element here (reading through the wrapper's
                            // `WRAPPED_ARRAY_FIELD`) to convert from a wrapped struct to a
                            // raw array initializer, e.g. for
                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
                            if let crate::TypeInner::Array {
                                size: crate::ArraySize::Constant(size),
                                ..
                            } = context.module.types[member.ty].inner
                            {
                                write!(self.out, "{comma} {{")?;
                                for j in 0..size.get() {
                                    if j != 0 {
                                        write!(self.out, ",")?;
                                    }
                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
                                }
                                write!(self.out, "}}")?;
                            } else {
                                write!(self.out, "{comma} {tmp}.{name}")?;
                            }
                        }
                    }
                    // Non-struct result: the expression is the initializer's element.
                    _ => {
                        write!(self.out, "{level}return {struct_name} {{ ")?;
                        self.put_expression(expr_handle, context, true)?;
                    }
                }

                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
                    let stage = context.module.entry_points[ep_index as usize].stage;
                    if context.pipeline_options.allow_and_force_point_size
                        && stage == crate::ShaderStage::Vertex
                        && !has_point_size
                    {
                        // point size was injected and comes last
                        write!(self.out, ", 1.0")?;
                    }
                }
                write!(self.out, " }}")?;
            }
            None => {
                write!(self.out, "{level}return ")?;
                self.put_expression(expr_handle, context, true)?;
            }
        }
        writeln!(self.out, ";")?;
        Ok(())
    }
3516
3517    /// Helper method used to find which expressions of a given function require baking
3518    ///
3519    /// # Notes
3520    /// This function overwrites the contents of `self.need_bake_expressions`
3521    fn update_expressions_to_bake(
3522        &mut self,
3523        func: &crate::Function,
3524        info: &valid::FunctionInfo,
3525        context: &ExpressionContext,
3526    ) {
3527        use crate::Expression;
3528        self.need_bake_expressions.clear();
3529
3530        for (expr_handle, expr) in func.expressions.iter() {
3531            // Expressions whose reference count is above the
3532            // threshold should always be stored in temporaries.
3533            let expr_info = &info[expr_handle];
3534            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3535            if min_ref_count <= expr_info.ref_count {
3536                self.need_bake_expressions.insert(expr_handle);
3537            } else {
3538                match expr_info.ty {
3539                    // force ray desc to be baked: it's used multiple times internally
3540                    TypeResolution::Handle(h)
3541                        if Some(h) == context.module.special_types.ray_desc =>
3542                    {
3543                        self.need_bake_expressions.insert(expr_handle);
3544                    }
3545                    _ => {}
3546                }
3547            }
3548
3549            if let Expression::Math {
3550                fun,
3551                arg,
3552                arg1,
3553                arg2,
3554                ..
3555            } = *expr
3556            {
3557                match fun {
3558                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3559                    // works on floating-point vectors, so we emit inline code for
3560                    // integer vector `dot` calls. But that code uses each argument `N`
3561                    // times, once for each component (see `put_dot_product`), so to
3562                    // avoid duplicated evaluation, we must bake integer operands.
3563                    // This applies both when using the polyfill (because of the duplicate
3564                    // evaluation issue) and when we don't use the polyfill (because we
3565                    // need them to be emitted before casting to packed chars -- see the
3566                    // comment at the call to `put_casting_to_packed_chars`).
3567                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3568                        self.need_bake_expressions.insert(arg);
3569                        self.need_bake_expressions.insert(arg1.unwrap());
3570                    }
3571                    crate::MathFunction::FirstLeadingBit => {
3572                        self.need_bake_expressions.insert(arg);
3573                    }
3574                    crate::MathFunction::Pack4xI8
3575                    | crate::MathFunction::Pack4xU8
3576                    | crate::MathFunction::Pack4xI8Clamp
3577                    | crate::MathFunction::Pack4xU8Clamp
3578                    | crate::MathFunction::Unpack4xI8
3579                    | crate::MathFunction::Unpack4xU8 => {
3580                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3581                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3582                        if context.lang_version < (2, 1) {
3583                            self.need_bake_expressions.insert(arg);
3584                        }
3585                    }
3586                    crate::MathFunction::ExtractBits => {
3587                        // Only argument 1 is re-used.
3588                        self.need_bake_expressions.insert(arg1.unwrap());
3589                    }
3590                    crate::MathFunction::InsertBits => {
3591                        // Only argument 2 is re-used.
3592                        self.need_bake_expressions.insert(arg2.unwrap());
3593                    }
3594                    crate::MathFunction::Sign => {
3595                        // WGSL's `sign` function works also on signed ints, but Metal's only
3596                        // works on floating points, so we emit inline code for integer `sign`
3597                        // calls. But that code uses each argument 2 times (see `put_isign`),
3598                        // so to avoid duplicated evaluation, we must bake the argument.
3599                        let inner = context.resolve_type(expr_handle);
3600                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3601                            self.need_bake_expressions.insert(arg);
3602                        }
3603                    }
3604                    _ => {}
3605                }
3606            }
3607        }
3608    }
3609
3610    pub(super) fn start_baking_expression(
3611        &mut self,
3612        handle: Handle<crate::Expression>,
3613        context: &ExpressionContext,
3614        name: &str,
3615    ) -> BackendResult {
3616        match context.info[handle].ty {
3617            TypeResolution::Handle(ty_handle) => {
3618                let ty_name = TypeContext {
3619                    handle: ty_handle,
3620                    gctx: context.module.to_ctx(),
3621                    names: &self.names,
3622                    access: crate::StorageAccess::empty(),
3623                    first_time: false,
3624                };
3625                write!(self.out, "{ty_name}")?;
3626            }
3627            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3628                put_numeric_type(&mut self.out, scalar, &[])?;
3629            }
3630            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3631                put_numeric_type(&mut self.out, scalar, &[size])?;
3632            }
3633            TypeResolution::Value(crate::TypeInner::Matrix {
3634                columns,
3635                rows,
3636                scalar,
3637            }) => {
3638                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3639            }
3640            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3641                columns,
3642                rows,
3643                scalar,
3644                role: _,
3645            }) => {
3646                write!(
3647                    self.out,
3648                    "{}::simdgroup_{}{}x{}",
3649                    NAMESPACE,
3650                    scalar.to_msl_name(),
3651                    columns as u32,
3652                    rows as u32,
3653                )?;
3654            }
3655            TypeResolution::Value(ref other) => {
3656                log::warn!("Type {other:?} isn't a known local");
3657                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3658            }
3659        }
3660
3661        //TODO: figure out the naming scheme that wouldn't collide with user names.
3662        write!(self.out, " {name} = ")?;
3663
3664        Ok(())
3665    }
3666
3667    /// Cache a clamped level of detail value, if necessary.
3668    ///
3669    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3670    /// properly clamped level of detail value both in the access itself, and
3671    /// for fetching the size of the requested MIP level, needed to clamp the
3672    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3673    /// it in a temporary variable, as part of the [`Emit`] statement covering
3674    /// the [`ImageLoad`] expression.
3675    ///
3676    /// [`ImageLoad`]: crate::Expression::ImageLoad
3677    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3678    /// [`Emit`]: crate::Statement::Emit
3679    fn put_cache_restricted_level(
3680        &mut self,
3681        load: Handle<crate::Expression>,
3682        image: Handle<crate::Expression>,
3683        mip_level: Option<Handle<crate::Expression>>,
3684        indent: back::Level,
3685        context: &StatementContext,
3686    ) -> BackendResult {
3687        // Does this image access actually require (or even permit) a
3688        // level-of-detail, and does the policy require us to restrict it?
3689        let level_of_detail = match mip_level {
3690            Some(level) => level,
3691            None => return Ok(()),
3692        };
3693
3694        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3695            || !context.expression.image_needs_lod(image)
3696        {
3697            return Ok(());
3698        }
3699
3700        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3701        self.put_restricted_scalar_image_index(
3702            image,
3703            level_of_detail,
3704            "get_num_mip_levels",
3705            &context.expression,
3706        )?;
3707        writeln!(self.out, ";")?;
3708
3709        Ok(())
3710    }
3711
3712    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3713    ///
3714    /// Caches the results in temporary variables (whose names are derived from
3715    /// the original variable names). This caching avoids the need to redo the
3716    /// casting for each vector component when emitting the dot product.
3717    fn put_casting_to_packed_chars(
3718        &mut self,
3719        fun: crate::MathFunction,
3720        arg0: Handle<crate::Expression>,
3721        arg1: Handle<crate::Expression>,
3722        indent: back::Level,
3723        context: &StatementContext<'_>,
3724    ) -> Result<(), Error> {
3725        let packed_type = match fun {
3726            crate::MathFunction::Dot4I8Packed => "packed_char4",
3727            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3728            _ => unreachable!(),
3729        };
3730
3731        for arg in [arg0, arg1] {
3732            write!(
3733                self.out,
3734                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3735                Reinterpreted::new(packed_type, arg)
3736            )?;
3737            self.put_expression(arg, &context.expression, true)?;
3738            writeln!(self.out, ");")?;
3739        }
3740
3741        Ok(())
3742    }
3743
3744    fn put_block(
3745        &mut self,
3746        level: back::Level,
3747        statements: &[crate::Statement],
3748        context: &StatementContext,
3749    ) -> BackendResult {
3750        // Add to the set in order to track the stack size.
3751        #[cfg(test)]
3752        self.put_block_stack_pointers
3753            .insert(ptr::from_ref(&level).cast());
3754
3755        for statement in statements {
3756            log::trace!("statement[{}] {:?}", level.0, statement);
3757            match *statement {
3758                crate::Statement::Emit(ref range) => {
3759                    for handle in range.clone() {
3760                        use crate::MathFunction as Mf;
3761
3762                        match context.expression.function.expressions[handle] {
3763                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3764                            // may need to cache a clamped version of their level-of-detail argument.
3765                            crate::Expression::ImageLoad {
3766                                image,
3767                                level: mip_level,
3768                                ..
3769                            } => {
3770                                self.put_cache_restricted_level(
3771                                    handle, image, mip_level, level, context,
3772                                )?;
3773                            }
3774
3775                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3776                            // 2.1+ then we introduce two intermediate variables that recast the two
3777                            // arguments as packed (signed or unsigned) chars. The actual dot product
3778                            // is implemented in `Self::put_expression`, and it uses both of these
3779                            // intermediate variables multiple times. There's no danger that the
3780                            // original arguments get modified between the definition of these
3781                            // intermediate variables and the implementation of the actual dot
3782                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3783                            crate::Expression::Math {
3784                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3785                                arg,
3786                                arg1,
3787                                ..
3788                            } if context.expression.lang_version >= (2, 1) => {
3789                                self.put_casting_to_packed_chars(
3790                                    fun,
3791                                    arg,
3792                                    arg1.unwrap(),
3793                                    level,
3794                                    context,
3795                                )?;
3796                            }
3797
3798                            _ => (),
3799                        }
3800
3801                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3802                        let expr_name = if ptr_class.is_some() {
3803                            None // don't bake pointer expressions (just yet)
3804                        } else if let Some(name) =
3805                            context.expression.function.named_expressions.get(&handle)
3806                        {
3807                            // The `crate::Function::named_expressions` table holds
3808                            // expressions that should be saved in temporaries once they
3809                            // are `Emit`ted. We only add them to `self.named_expressions`
3810                            // when we reach the `Emit` that covers them, so that we don't
3811                            // try to use their names before we've actually initialized
3812                            // the temporary that holds them.
3813                            //
3814                            // Don't assume the names in `named_expressions` are unique,
3815                            // or even valid. Use the `Namer`.
3816                            Some(self.namer.call(name))
3817                        } else {
3818                            // If this expression is an index that we're going to first compare
3819                            // against a limit, and then actually use as an index, then we may
3820                            // want to cache it in a temporary, to avoid evaluating it twice.
3821                            let bake = if context.expression.guarded_indices.contains(handle) {
3822                                true
3823                            } else {
3824                                self.need_bake_expressions.contains(&handle)
3825                            };
3826
3827                            if bake {
3828                                Some(Baked(handle).to_string())
3829                            } else {
3830                                None
3831                            }
3832                        };
3833
3834                        if let Some(name) = expr_name {
3835                            write!(self.out, "{level}")?;
3836                            self.start_baking_expression(handle, &context.expression, &name)?;
3837                            self.put_expression(handle, &context.expression, true)?;
3838                            self.named_expressions.insert(handle, name);
3839                            writeln!(self.out, ";")?;
3840                        }
3841                    }
3842                }
3843                crate::Statement::Block(ref block) => {
3844                    if !block.is_empty() {
3845                        writeln!(self.out, "{level}{{")?;
3846                        self.put_block(level.next(), block, context)?;
3847                        writeln!(self.out, "{level}}}")?;
3848                    }
3849                }
3850                crate::Statement::If {
3851                    condition,
3852                    ref accept,
3853                    ref reject,
3854                } => {
3855                    write!(self.out, "{level}if (")?;
3856                    self.put_expression(condition, &context.expression, true)?;
3857                    writeln!(self.out, ") {{")?;
3858                    self.put_block(level.next(), accept, context)?;
3859                    if !reject.is_empty() {
3860                        writeln!(self.out, "{level}}} else {{")?;
3861                        self.put_block(level.next(), reject, context)?;
3862                    }
3863                    writeln!(self.out, "{level}}}")?;
3864                }
3865                crate::Statement::Switch {
3866                    selector,
3867                    ref cases,
3868                } => {
3869                    write!(self.out, "{level}switch(")?;
3870                    self.put_expression(selector, &context.expression, true)?;
3871                    writeln!(self.out, ") {{")?;
3872                    let lcase = level.next();
3873                    for case in cases.iter() {
3874                        match case.value {
3875                            crate::SwitchValue::I32(value) => {
3876                                write!(self.out, "{lcase}case {value}:")?;
3877                            }
3878                            crate::SwitchValue::U32(value) => {
3879                                write!(self.out, "{lcase}case {value}u:")?;
3880                            }
3881                            crate::SwitchValue::Default => {
3882                                write!(self.out, "{lcase}default:")?;
3883                            }
3884                        }
3885
3886                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3887                        if write_block_braces {
3888                            writeln!(self.out, " {{")?;
3889                        } else {
3890                            writeln!(self.out)?;
3891                        }
3892
3893                        self.put_block(lcase.next(), &case.body, context)?;
3894                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3895                        {
3896                            writeln!(self.out, "{}break;", lcase.next())?;
3897                        }
3898
3899                        if write_block_braces {
3900                            writeln!(self.out, "{lcase}}}")?;
3901                        }
3902                    }
3903                    writeln!(self.out, "{level}}}")?;
3904                }
3905                crate::Statement::Loop {
3906                    ref body,
3907                    ref continuing,
3908                    break_if,
3909                } => {
3910                    let force_loop_bound_statements =
3911                        self.gen_force_bounded_loop_statements(level, context);
3912                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3913                        .then(|| self.namer.call("loop_init"));
3914
3915                    if let Some((ref decl, _)) = force_loop_bound_statements {
3916                        writeln!(self.out, "{decl}")?;
3917                    }
3918                    if let Some(ref gate_name) = gate_name {
3919                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3920                    }
3921
3922                    writeln!(self.out, "{level}while(true) {{",)?;
3923                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3924                        writeln!(self.out, "{break_and_inc}")?;
3925                    }
3926                    if let Some(ref gate_name) = gate_name {
3927                        let lif = level.next();
3928                        let lcontinuing = lif.next();
3929                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3930                        self.put_block(lcontinuing, continuing, context)?;
3931                        if let Some(condition) = break_if {
3932                            write!(self.out, "{lcontinuing}if (")?;
3933                            self.put_expression(condition, &context.expression, true)?;
3934                            writeln!(self.out, ") {{")?;
3935                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3936                            writeln!(self.out, "{lcontinuing}}}")?;
3937                        }
3938                        writeln!(self.out, "{lif}}}")?;
3939                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3940                    }
3941                    self.put_block(level.next(), body, context)?;
3942
3943                    writeln!(self.out, "{level}}}")?;
3944                }
3945                crate::Statement::Break => {
3946                    writeln!(self.out, "{level}break;")?;
3947                }
3948                crate::Statement::Continue => {
3949                    writeln!(self.out, "{level}continue;")?;
3950                }
3951                crate::Statement::Return {
3952                    value: Some(expr_handle),
3953                } => {
3954                    self.put_return_value(
3955                        level,
3956                        expr_handle,
3957                        context.result_struct,
3958                        &context.expression,
3959                    )?;
3960                }
3961                crate::Statement::Return { value: None } => {
3962                    writeln!(self.out, "{level}return;")?;
3963                }
3964                crate::Statement::Kill => {
3965                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3966                }
3967                crate::Statement::ControlBarrier(flags)
3968                | crate::Statement::MemoryBarrier(flags) => {
3969                    self.write_barrier(flags, level)?;
3970                }
3971                crate::Statement::Store { pointer, value } => {
3972                    self.put_store(pointer, value, level, context)?
3973                }
3974                crate::Statement::ImageStore {
3975                    image,
3976                    coordinate,
3977                    array_index,
3978                    value,
3979                } => {
3980                    let address = TexelAddress {
3981                        coordinate,
3982                        array_index,
3983                        sample: None,
3984                        level: None,
3985                    };
3986                    self.put_image_store(level, image, &address, value, context)?
3987                }
3988                crate::Statement::Call {
3989                    function,
3990                    ref arguments,
3991                    result,
3992                } => {
3993                    write!(self.out, "{level}")?;
3994                    if let Some(expr) = result {
3995                        let name = Baked(expr).to_string();
3996                        self.start_baking_expression(expr, &context.expression, &name)?;
3997                        self.named_expressions.insert(expr, name);
3998                    }
3999                    let fun_name = &self.names[&NameKey::Function(function)];
4000                    write!(self.out, "{fun_name}(")?;
4001                    // first, write down the actual arguments
4002                    for (i, &handle) in arguments.iter().enumerate() {
4003                        if i != 0 {
4004                            write!(self.out, ", ")?;
4005                        }
4006                        self.put_expression(handle, &context.expression, true)?;
4007                    }
4008                    // follow-up with any global resources used
4009                    let mut separate = !arguments.is_empty();
4010                    let fun_info = &context.expression.mod_info[function];
4011                    let mut needs_buffer_sizes = false;
4012                    for (handle, var) in context.expression.module.global_variables.iter() {
4013                        if fun_info[handle].is_empty() {
4014                            continue;
4015                        }
4016                        if var.space.needs_pass_through() {
4017                            let name = &self.names[&NameKey::GlobalVariable(handle)];
4018                            if separate {
4019                                write!(self.out, ", ")?;
4020                            } else {
4021                                separate = true;
4022                            }
4023                            write!(self.out, "{name}")?;
4024                        }
4025                        needs_buffer_sizes |=
4026                            needs_array_length(var.ty, &context.expression.module.types);
4027                    }
4028                    if needs_buffer_sizes {
4029                        if separate {
4030                            write!(self.out, ", ")?;
4031                        }
4032                        write!(self.out, "_buffer_sizes")?;
4033                    }
4034
4035                    // done
4036                    writeln!(self.out, ");")?;
4037                }
4038                crate::Statement::Atomic {
4039                    pointer,
4040                    ref fun,
4041                    value,
4042                    result,
4043                } => {
4044                    let context = &context.expression;
4045
4046                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
4047                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
4048                    // `Some`, we are not operating on a 64-bit value, and that if we are
4049                    // operating on a 64-bit value, `result` is `None`.
4050                    write!(self.out, "{level}")?;
4051                    let fun_key = if let Some(result) = result {
4052                        let res_name = Baked(result).to_string();
4053                        self.start_baking_expression(result, context, &res_name)?;
4054                        self.named_expressions.insert(result, res_name);
4055                        fun.to_msl()
4056                    } else if context.resolve_type(value).scalar_width() == Some(8) {
4057                        fun.to_msl_64_bit()?
4058                    } else {
4059                        fun.to_msl()
4060                    };
4061
4062                    // If the pointer we're passing to the atomic operation needs to be conditional
4063                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
4064                    // the pointer operand should be unchecked.
4065                    let policy = context.choose_bounds_check_policy(pointer);
4066                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4067                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4068
4069                    // If requested and successfully put bounds checks, continue the ternary expression.
4070                    if checked {
4071                        write!(self.out, " ? ")?;
4072                    }
4073
4074                    // Put the atomic function invocation.
4075                    match *fun {
4076                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4077                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4078                            self.put_access_chain(pointer, policy, context)?;
4079                            write!(self.out, ", ")?;
4080                            self.put_expression(cmp, context, true)?;
4081                            write!(self.out, ", ")?;
4082                            self.put_expression(value, context, true)?;
4083                            write!(self.out, ")")?;
4084                        }
4085                        _ => {
4086                            write!(
4087                                self.out,
4088                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4089                            )?;
4090                            self.put_access_chain(pointer, policy, context)?;
4091                            write!(self.out, ", ")?;
4092                            self.put_expression(value, context, true)?;
4093                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4094                        }
4095                    }
4096
4097                    // Finish the ternary expression.
4098                    if checked {
4099                        write!(self.out, " : DefaultConstructible()")?;
4100                    }
4101
4102                    // Done
4103                    writeln!(self.out, ";")?;
4104                }
4105                crate::Statement::ImageAtomic {
4106                    image,
4107                    coordinate,
4108                    array_index,
4109                    fun,
4110                    value,
4111                } => {
4112                    let address = TexelAddress {
4113                        coordinate,
4114                        array_index,
4115                        sample: None,
4116                        level: None,
4117                    };
4118                    self.put_image_atomic(level, image, &address, fun, value, context)?
4119                }
4120                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4121                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4122
4123                    write!(self.out, "{level}")?;
4124                    let name = self.namer.call("");
4125                    self.start_baking_expression(result, &context.expression, &name)?;
4126                    self.put_load(pointer, &context.expression, true)?;
4127                    self.named_expressions.insert(result, name);
4128
4129                    writeln!(self.out, ";")?;
4130                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4131                }
4132                crate::Statement::RayQuery { query, ref fun } => {
4133                    self.write_ray_query_stmt(level, context, query, fun)?;
4134                }
4135                crate::Statement::SubgroupBallot { result, predicate } => {
4136                    write!(self.out, "{level}")?;
4137                    let name = self.namer.call("");
4138                    self.start_baking_expression(result, &context.expression, &name)?;
4139                    self.named_expressions.insert(result, name);
4140                    write!(
4141                        self.out,
4142                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4143                    )?;
4144                    if let Some(predicate) = predicate {
4145                        self.put_expression(predicate, &context.expression, true)?;
4146                    } else {
4147                        write!(self.out, "true")?;
4148                    }
4149                    writeln!(self.out, "), 0, 0, 0);")?;
4150                }
4151                crate::Statement::SubgroupCollectiveOperation {
4152                    op,
4153                    collective_op,
4154                    argument,
4155                    result,
4156                } => {
4157                    write!(self.out, "{level}")?;
4158                    let name = self.namer.call("");
4159                    self.start_baking_expression(result, &context.expression, &name)?;
4160                    self.named_expressions.insert(result, name);
4161                    match (collective_op, op) {
4162                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4163                            write!(self.out, "{NAMESPACE}::simd_all(")?
4164                        }
4165                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4166                            write!(self.out, "{NAMESPACE}::simd_any(")?
4167                        }
4168                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4169                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4170                        }
4171                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4172                            write!(self.out, "{NAMESPACE}::simd_product(")?
4173                        }
4174                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4175                            write!(self.out, "{NAMESPACE}::simd_max(")?
4176                        }
4177                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4178                            write!(self.out, "{NAMESPACE}::simd_min(")?
4179                        }
4180                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4181                            write!(self.out, "{NAMESPACE}::simd_and(")?
4182                        }
4183                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4184                            write!(self.out, "{NAMESPACE}::simd_or(")?
4185                        }
4186                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4187                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4188                        }
4189                        (
4190                            crate::CollectiveOperation::ExclusiveScan,
4191                            crate::SubgroupOperation::Add,
4192                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4193                        (
4194                            crate::CollectiveOperation::ExclusiveScan,
4195                            crate::SubgroupOperation::Mul,
4196                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4197                        (
4198                            crate::CollectiveOperation::InclusiveScan,
4199                            crate::SubgroupOperation::Add,
4200                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4201                        (
4202                            crate::CollectiveOperation::InclusiveScan,
4203                            crate::SubgroupOperation::Mul,
4204                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4205                        _ => unimplemented!(),
4206                    }
4207                    self.put_expression(argument, &context.expression, true)?;
4208                    writeln!(self.out, ");")?;
4209                }
4210                crate::Statement::SubgroupGather {
4211                    mode,
4212                    argument,
4213                    result,
4214                } => {
4215                    write!(self.out, "{level}")?;
4216                    let name = self.namer.call("");
4217                    self.start_baking_expression(result, &context.expression, &name)?;
4218                    self.named_expressions.insert(result, name);
4219                    match mode {
4220                        crate::GatherMode::BroadcastFirst => {
4221                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4222                        }
4223                        crate::GatherMode::Broadcast(_) => {
4224                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4225                        }
4226                        crate::GatherMode::Shuffle(_) => {
4227                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4228                        }
4229                        crate::GatherMode::ShuffleDown(_) => {
4230                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4231                        }
4232                        crate::GatherMode::ShuffleUp(_) => {
4233                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4234                        }
4235                        crate::GatherMode::ShuffleXor(_) => {
4236                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4237                        }
4238                        crate::GatherMode::QuadBroadcast(_) => {
4239                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4240                        }
4241                        crate::GatherMode::QuadSwap(_) => {
4242                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4243                        }
4244                    }
4245                    self.put_expression(argument, &context.expression, true)?;
4246                    match mode {
4247                        crate::GatherMode::BroadcastFirst => {}
4248                        crate::GatherMode::Broadcast(index)
4249                        | crate::GatherMode::Shuffle(index)
4250                        | crate::GatherMode::ShuffleDown(index)
4251                        | crate::GatherMode::ShuffleUp(index)
4252                        | crate::GatherMode::ShuffleXor(index)
4253                        | crate::GatherMode::QuadBroadcast(index) => {
4254                            write!(self.out, ", ")?;
4255                            self.put_expression(index, &context.expression, true)?;
4256                        }
4257                        crate::GatherMode::QuadSwap(direction) => {
4258                            write!(self.out, ", ")?;
4259                            match direction {
4260                                crate::Direction::X => {
4261                                    write!(self.out, "1u")?;
4262                                }
4263                                crate::Direction::Y => {
4264                                    write!(self.out, "2u")?;
4265                                }
4266                                crate::Direction::Diagonal => {
4267                                    write!(self.out, "3u")?;
4268                                }
4269                            }
4270                        }
4271                    }
4272                    writeln!(self.out, ");")?;
4273                }
4274                crate::Statement::CooperativeStore { target, ref data } => {
4275                    write!(self.out, "{level}simdgroup_store(")?;
4276                    self.put_expression(target, &context.expression, true)?;
4277                    write!(self.out, ", &")?;
4278                    self.put_access_chain(
4279                        data.pointer,
4280                        context.expression.policies.index,
4281                        &context.expression,
4282                    )?;
4283                    write!(self.out, ", ")?;
4284                    self.put_expression(data.stride, &context.expression, true)?;
4285                    // See the comment in `CooperativeLoad` above: WGSL's
4286                    // row_major flag is negated when emitting Metal's
4287                    // `transpose` flag, so a col-major store (row_major=false)
4288                    // must use `transpose=true`.
4289                    if !data.row_major {
4290                        let matrix_origin = "0";
4291                        let transpose = true;
4292                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4293                    }
4294                    writeln!(self.out, ");")?;
4295                }
4296                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4297            }
4298        }
4299
4300        // un-emit expressions
4301        //TODO: take care of loop/continuing?
4302        for statement in statements {
4303            if let crate::Statement::Emit(ref range) = *statement {
4304                for handle in range.clone() {
4305                    self.named_expressions.shift_remove(&handle);
4306                }
4307            }
4308        }
4309        Ok(())
4310    }
4311
4312    fn put_store(
4313        &mut self,
4314        pointer: Handle<crate::Expression>,
4315        value: Handle<crate::Expression>,
4316        level: back::Level,
4317        context: &StatementContext,
4318    ) -> BackendResult {
4319        let policy = context.expression.choose_bounds_check_policy(pointer);
4320        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4321            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4322        {
4323            writeln!(self.out, ") {{")?;
4324            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4325            writeln!(self.out, "{level}}}")?;
4326        } else {
4327            self.put_unchecked_store(pointer, value, policy, level, context)?;
4328        }
4329
4330        Ok(())
4331    }
4332
4333    fn put_unchecked_store(
4334        &mut self,
4335        pointer: Handle<crate::Expression>,
4336        value: Handle<crate::Expression>,
4337        policy: index::BoundsCheckPolicy,
4338        level: back::Level,
4339        context: &StatementContext,
4340    ) -> BackendResult {
4341        let is_atomic_pointer = context
4342            .expression
4343            .resolve_type(pointer)
4344            .is_atomic_pointer(&context.expression.module.types);
4345
4346        if is_atomic_pointer {
4347            write!(
4348                self.out,
4349                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4350            )?;
4351            self.put_access_chain(pointer, policy, &context.expression)?;
4352            write!(self.out, ", ")?;
4353            self.put_expression(value, &context.expression, true)?;
4354            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4355        } else {
4356            write!(self.out, "{level}")?;
4357            self.put_access_chain(pointer, policy, &context.expression)?;
4358            write!(self.out, " = ")?;
4359            self.put_expression(value, &context.expression, true)?;
4360            writeln!(self.out, ";")?;
4361        }
4362
4363        Ok(())
4364    }
4365
4366    pub fn write(
4367        &mut self,
4368        module: &crate::Module,
4369        info: &valid::ModuleInfo,
4370        options: &Options,
4371        pipeline_options: &PipelineOptions,
4372    ) -> Result<TranslationInfo, Error> {
4373        self.emit_int_div_checks = options.emit_int_div_checks;
4374        self.names.clear();
4375        self.namer.reset(
4376            module,
4377            &super::keywords::RESERVED_SET,
4378            proc::KeywordSet::empty(),
4379            proc::CaseInsensitiveKeywordSet::empty(),
4380            &[
4381                CLAMPED_LOD_LOAD_PREFIX,
4382                super::ray::INTERSECTION_FUNCTION_NAME,
4383            ],
4384            &mut self.names,
4385        );
4386        self.wrapped_functions.clear();
4387        self.struct_member_pads.clear();
4388
4389        writeln!(
4390            self.out,
4391            "// language: metal{}.{}",
4392            options.lang_version.0, options.lang_version.1
4393        )?;
4394        writeln!(self.out, "#include <metal_stdlib>")?;
4395        writeln!(self.out, "#include <simd/simd.h>")?;
4396        writeln!(self.out)?;
4397        // Work around Metal bug where `uint` is not available by default
4398        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4399
4400        if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4401            return Err(Error::UnsupportedMeshShader);
4402        }
4403        self.needs_object_memory_barriers = module
4404            .entry_points
4405            .iter()
4406            .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4407
4408        if module.special_types.ray_desc.is_some()
4409            || module.special_types.ray_intersection.is_some()
4410        {
4411            if options.lang_version < (2, 4) {
4412                return Err(Error::UnsupportedRayTracing);
4413            }
4414        }
4415
4416        if options
4417            .bounds_check_policies
4418            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4419        {
4420            self.put_default_constructible()?;
4421        }
4422        writeln!(self.out)?;
4423
4424        {
4425            // Make a `Vec` of all the `GlobalVariable`s that contain
4426            // runtime-sized arrays.
4427            let globals: Vec<Handle<crate::GlobalVariable>> = module
4428                .global_variables
4429                .iter()
4430                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4431                .map(|(handle, _)| handle)
4432                .collect();
4433
4434            let mut buffer_indices = vec![];
4435            for vbm in &pipeline_options.vertex_buffer_mappings {
4436                buffer_indices.push(vbm.id);
4437            }
4438
4439            if !globals.is_empty() || !buffer_indices.is_empty() {
4440                writeln!(self.out, "struct _mslBufferSizes {{")?;
4441
4442                for global in globals {
4443                    writeln!(
4444                        self.out,
4445                        "{}uint {};",
4446                        back::INDENT,
4447                        ArraySizeMember(global)
4448                    )?;
4449                }
4450
4451                for idx in buffer_indices {
4452                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4453                }
4454
4455                writeln!(self.out, "}};")?;
4456                writeln!(self.out)?;
4457            }
4458        };
4459
4460        self.write_type_defs(module)?;
4461        self.write_global_constants(module, info)?;
4462        self.write_functions(module, info, options, pipeline_options)
4463    }
4464
4465    /// Write the definition for the `DefaultConstructible` class.
4466    ///
4467    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4468    /// produce 'zero' values for any type, including structs, arrays, and so
4469    /// on. We could do this by emitting default constructor applications, but
4470    /// that would entail printing the name of the type, which is more trouble
4471    /// than you'd think. Instead, we just construct this magic C++14 class that
4472    /// can be converted to any type that can be default constructed, using
4473    /// template parameter inference to detect which type is needed, so we don't
4474    /// have to figure out the name.
4475    ///
4476    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4477    fn put_default_constructible(&mut self) -> BackendResult {
4478        let tab = back::INDENT;
4479        writeln!(self.out, "struct DefaultConstructible {{")?;
4480        writeln!(self.out, "{tab}template<typename T>")?;
4481        writeln!(self.out, "{tab}operator T() && {{")?;
4482        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4483        writeln!(self.out, "{tab}}}")?;
4484        writeln!(self.out, "}};")?;
4485        Ok(())
4486    }
4487
4488    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4489        let mut generated_argument_buffer_wrapper = false;
4490        let mut generated_external_texture_wrapper = false;
4491        for (handle, ty) in module.types.iter() {
4492            match ty.inner {
4493                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4494                    writeln!(self.out, "template <typename T>")?;
4495                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4496                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4497                    writeln!(self.out, "}};")?;
4498                    generated_argument_buffer_wrapper = true;
4499                }
4500                crate::TypeInner::Image {
4501                    class: crate::ImageClass::External,
4502                    ..
4503                } if !generated_external_texture_wrapper => {
4504                    let params_ty_name = &self.names
4505                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4506                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4507                    writeln!(
4508                        self.out,
4509                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4510                        back::INDENT
4511                    )?;
4512                    writeln!(
4513                        self.out,
4514                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4515                        back::INDENT
4516                    )?;
4517                    writeln!(
4518                        self.out,
4519                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4520                        back::INDENT
4521                    )?;
4522                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4523                    writeln!(self.out, "}};")?;
4524                    generated_external_texture_wrapper = true;
4525                }
4526                _ => {}
4527            }
4528
4529            if !ty.needs_alias() {
4530                continue;
4531            }
4532            let name = &self.names[&NameKey::Type(handle)];
4533            match ty.inner {
4534                // Naga IR can pass around arrays by value, but Metal, following
4535                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4536                // on expressions of array type, so assigning the array by value
4537                // isn't possible. However, Metal *does* assign structs by
4538                // value. So in our Metal output, we wrap all array types in
4539                // synthetic struct types:
4540                //
4541                //     struct type1 {
4542                //         float inner[10]
4543                //     };
4544                //
4545                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4546                // any expression that actually wants access to the array.
4547                crate::TypeInner::Array {
4548                    base,
4549                    size,
4550                    stride: _,
4551                } => {
4552                    let base_name = TypeContext {
4553                        handle: base,
4554                        gctx: module.to_ctx(),
4555                        names: &self.names,
4556                        access: crate::StorageAccess::empty(),
4557                        first_time: false,
4558                    };
4559
4560                    match size.resolve(module.to_ctx())? {
4561                        proc::IndexableLength::Known(size) => {
4562                            writeln!(self.out, "struct {name} {{")?;
4563                            writeln!(
4564                                self.out,
4565                                "{}{} {}[{}];",
4566                                back::INDENT,
4567                                base_name,
4568                                WRAPPED_ARRAY_FIELD,
4569                                size
4570                            )?;
4571                            writeln!(self.out, "}};")?;
4572                        }
4573                        proc::IndexableLength::Dynamic => {
4574                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4575                        }
4576                    }
4577                }
4578                crate::TypeInner::Struct {
4579                    ref members, span, ..
4580                } => {
4581                    writeln!(self.out, "struct {name} {{")?;
4582                    let mut last_offset = 0;
4583                    for (index, member) in members.iter().enumerate() {
4584                        if member.offset > last_offset {
4585                            self.struct_member_pads.insert((handle, index as u32));
4586                            let pad = member.offset - last_offset;
4587                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4588                        }
4589                        let ty_inner = &module.types[member.ty].inner;
4590                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4591
4592                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4593
4594                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4595                        match should_pack_struct_member(members, span, index, module) {
4596                            Some(scalar) => {
4597                                writeln!(
4598                                    self.out,
4599                                    "{}{}::packed_{}3 {};",
4600                                    back::INDENT,
4601                                    NAMESPACE,
4602                                    scalar.to_msl_name(),
4603                                    member_name
4604                                )?;
4605                            }
4606                            None => {
4607                                let base_name = TypeContext {
4608                                    handle: member.ty,
4609                                    gctx: module.to_ctx(),
4610                                    names: &self.names,
4611                                    access: crate::StorageAccess::empty(),
4612                                    first_time: false,
4613                                };
4614                                writeln!(
4615                                    self.out,
4616                                    "{}{} {};",
4617                                    back::INDENT,
4618                                    base_name,
4619                                    member_name
4620                                )?;
4621
4622                                // for 3-component vectors, add one component
4623                                if let crate::TypeInner::Vector {
4624                                    size: crate::VectorSize::Tri,
4625                                    scalar,
4626                                } = *ty_inner
4627                                {
4628                                    last_offset += scalar.width as u32;
4629                                }
4630                            }
4631                        }
4632                    }
4633                    if last_offset < span {
4634                        let pad = span - last_offset;
4635                        writeln!(
4636                            self.out,
4637                            "{}char _pad{}[{}];",
4638                            back::INDENT,
4639                            members.len(),
4640                            pad
4641                        )?;
4642                    }
4643                    writeln!(self.out, "}};")?;
4644                }
4645                _ => {
4646                    let ty_name = TypeContext {
4647                        handle,
4648                        gctx: module.to_ctx(),
4649                        names: &self.names,
4650                        access: crate::StorageAccess::empty(),
4651                        first_time: true,
4652                    };
4653                    writeln!(self.out, "typedef {ty_name} {name};")?;
4654                }
4655            }
4656        }
4657
4658        // Write functions to create special types.
4659        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4660            match type_key {
4661                &crate::PredeclaredType::ModfResult { size, scalar }
4662                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4663                    let arg_type_name_owner;
4664                    let arg_type_name = if let Some(size) = size {
4665                        arg_type_name_owner = format!(
4666                            "{NAMESPACE}::{}{}",
4667                            if scalar.width == 8 { "double" } else { "float" },
4668                            size as u8
4669                        );
4670                        &arg_type_name_owner
4671                    } else if scalar.width == 8 {
4672                        "double"
4673                    } else {
4674                        "float"
4675                    };
4676
4677                    let other_type_name_owner;
4678                    let (defined_func_name, called_func_name, other_type_name) =
4679                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4680                            (MODF_FUNCTION, "modf", arg_type_name)
4681                        } else {
4682                            let other_type_name = if let Some(size) = size {
4683                                other_type_name_owner = format!("int{}", size as u8);
4684                                &other_type_name_owner
4685                            } else {
4686                                "int"
4687                            };
4688                            (FREXP_FUNCTION, "frexp", other_type_name)
4689                        };
4690
4691                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4692
4693                    writeln!(self.out)?;
4694                    writeln!(
4695                        self.out,
4696                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4697    {other_type_name} other;
4698    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4699    return {struct_name}{{ fract, other }};
4700}}"
4701                    )?;
4702                }
4703                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4704                    let arg_type_name = scalar.to_msl_name();
4705                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4706                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4707                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4708
4709                    writeln!(self.out)?;
4710
4711                    for address_space_name in ["device", "threadgroup"] {
4712                        writeln!(
4713                            self.out,
4714                            "\
4715template <typename A>
4716{struct_name} {defined_func_name}(
4717    {address_space_name} A *atomic_ptr,
4718    {arg_type_name} cmp,
4719    {arg_type_name} v
4720) {{
4721    bool swapped = {NAMESPACE}::{called_func_name}(
4722        atomic_ptr, &cmp, v,
4723        metal::memory_order_relaxed, metal::memory_order_relaxed
4724    );
4725    return {struct_name}{{cmp, swapped}};
4726}}"
4727                        )?;
4728                    }
4729                }
4730            }
4731        }
4732
4733        Ok(())
4734    }
4735
4736    /// Writes all named constants
4737    fn write_global_constants(
4738        &mut self,
4739        module: &crate::Module,
4740        mod_info: &valid::ModuleInfo,
4741    ) -> BackendResult {
4742        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4743
4744        for (handle, constant) in constants {
4745            let ty_name = TypeContext {
4746                handle: constant.ty,
4747                gctx: module.to_ctx(),
4748                names: &self.names,
4749                access: crate::StorageAccess::empty(),
4750                first_time: false,
4751            };
4752            let name = &self.names[&NameKey::Constant(handle)];
4753            write!(self.out, "constant {ty_name} {name} = ")?;
4754            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4755            writeln!(self.out, ";")?;
4756        }
4757
4758        Ok(())
4759    }
4760
4761    fn put_inline_sampler_properties(
4762        &mut self,
4763        level: back::Level,
4764        sampler: &sm::InlineSampler,
4765    ) -> BackendResult {
4766        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4767            writeln!(
4768                self.out,
4769                "{}{}::{}_address::{},",
4770                level,
4771                NAMESPACE,
4772                letter,
4773                address.as_str(),
4774            )?;
4775        }
4776        writeln!(
4777            self.out,
4778            "{}{}::mag_filter::{},",
4779            level,
4780            NAMESPACE,
4781            sampler.mag_filter.as_str(),
4782        )?;
4783        writeln!(
4784            self.out,
4785            "{}{}::min_filter::{},",
4786            level,
4787            NAMESPACE,
4788            sampler.min_filter.as_str(),
4789        )?;
4790        if let Some(filter) = sampler.mip_filter {
4791            writeln!(
4792                self.out,
4793                "{}{}::mip_filter::{},",
4794                level,
4795                NAMESPACE,
4796                filter.as_str(),
4797            )?;
4798        }
4799        // avoid setting it on platforms that don't support it
4800        if sampler.border_color != sm::BorderColor::TransparentBlack {
4801            writeln!(
4802                self.out,
4803                "{}{}::border_color::{},",
4804                level,
4805                NAMESPACE,
4806                sampler.border_color.as_str(),
4807            )?;
4808        }
4809        //TODO: I'm not able to feed this in a way that MSL likes:
4810        //>error: use of undeclared identifier 'lod_clamp'
4811        //>error: no member named 'max_anisotropy' in namespace 'metal'
4812        if false {
4813            if let Some(ref lod) = sampler.lod_clamp {
4814                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4815            }
4816            if let Some(aniso) = sampler.max_anisotropy {
4817                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4818            }
4819        }
4820        if sampler.compare_func != sm::CompareFunc::Never {
4821            writeln!(
4822                self.out,
4823                "{}{}::compare_func::{},",
4824                level,
4825                NAMESPACE,
4826                sampler.compare_func.as_str(),
4827            )?;
4828        }
4829        writeln!(
4830            self.out,
4831            "{}{}::coord::{}",
4832            level,
4833            NAMESPACE,
4834            sampler.coord.as_str()
4835        )?;
4836        Ok(())
4837    }
4838
4839    fn write_unpacking_function(
4840        &mut self,
4841        format: back::msl::VertexFormat,
4842    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4843        use crate::{Scalar, VectorSize};
4844        use back::msl::VertexFormat::*;
4845        match format {
4846            Uint8 => {
4847                let name = self.namer.call("unpackUint8");
4848                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4849                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4850                writeln!(self.out, "}}")?;
4851                Ok((name, 1, None, Scalar::U32))
4852            }
4853            Uint8x2 => {
4854                let name = self.namer.call("unpackUint8x2");
4855                writeln!(
4856                    self.out,
4857                    "metal::uint2 {name}(metal::uchar b0, \
4858                                         metal::uchar b1) {{"
4859                )?;
4860                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4861                writeln!(self.out, "}}")?;
4862                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4863            }
4864            Uint8x4 => {
4865                let name = self.namer.call("unpackUint8x4");
4866                writeln!(
4867                    self.out,
4868                    "metal::uint4 {name}(metal::uchar b0, \
4869                                         metal::uchar b1, \
4870                                         metal::uchar b2, \
4871                                         metal::uchar b3) {{"
4872                )?;
4873                writeln!(
4874                    self.out,
4875                    "{}return metal::uint4(b0, b1, b2, b3);",
4876                    back::INDENT
4877                )?;
4878                writeln!(self.out, "}}")?;
4879                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4880            }
4881            Sint8 => {
4882                let name = self.namer.call("unpackSint8");
4883                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4884                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4885                writeln!(self.out, "}}")?;
4886                Ok((name, 1, None, Scalar::I32))
4887            }
4888            Sint8x2 => {
4889                let name = self.namer.call("unpackSint8x2");
4890                writeln!(
4891                    self.out,
4892                    "metal::int2 {name}(metal::uchar b0, \
4893                                        metal::uchar b1) {{"
4894                )?;
4895                writeln!(
4896                    self.out,
4897                    "{}return metal::int2(as_type<char>(b0), \
4898                                          as_type<char>(b1));",
4899                    back::INDENT
4900                )?;
4901                writeln!(self.out, "}}")?;
4902                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4903            }
4904            Sint8x4 => {
4905                let name = self.namer.call("unpackSint8x4");
4906                writeln!(
4907                    self.out,
4908                    "metal::int4 {name}(metal::uchar b0, \
4909                                        metal::uchar b1, \
4910                                        metal::uchar b2, \
4911                                        metal::uchar b3) {{"
4912                )?;
4913                writeln!(
4914                    self.out,
4915                    "{}return metal::int4(as_type<char>(b0), \
4916                                          as_type<char>(b1), \
4917                                          as_type<char>(b2), \
4918                                          as_type<char>(b3));",
4919                    back::INDENT
4920                )?;
4921                writeln!(self.out, "}}")?;
4922                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4923            }
4924            Unorm8 => {
4925                let name = self.namer.call("unpackUnorm8");
4926                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4927                writeln!(
4928                    self.out,
4929                    "{}return float(float(b0) / 255.0f);",
4930                    back::INDENT
4931                )?;
4932                writeln!(self.out, "}}")?;
4933                Ok((name, 1, None, Scalar::F32))
4934            }
4935            Unorm8x2 => {
4936                let name = self.namer.call("unpackUnorm8x2");
4937                writeln!(
4938                    self.out,
4939                    "metal::float2 {name}(metal::uchar b0, \
4940                                          metal::uchar b1) {{"
4941                )?;
4942                writeln!(
4943                    self.out,
4944                    "{}return metal::float2(float(b0) / 255.0f, \
4945                                            float(b1) / 255.0f);",
4946                    back::INDENT
4947                )?;
4948                writeln!(self.out, "}}")?;
4949                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4950            }
4951            Unorm8x4 => {
4952                let name = self.namer.call("unpackUnorm8x4");
4953                writeln!(
4954                    self.out,
4955                    "metal::float4 {name}(metal::uchar b0, \
4956                                          metal::uchar b1, \
4957                                          metal::uchar b2, \
4958                                          metal::uchar b3) {{"
4959                )?;
4960                writeln!(
4961                    self.out,
4962                    "{}return metal::float4(float(b0) / 255.0f, \
4963                                            float(b1) / 255.0f, \
4964                                            float(b2) / 255.0f, \
4965                                            float(b3) / 255.0f);",
4966                    back::INDENT
4967                )?;
4968                writeln!(self.out, "}}")?;
4969                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4970            }
4971            Snorm8 => {
4972                let name = self.namer.call("unpackSnorm8");
4973                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4974                writeln!(
4975                    self.out,
4976                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4977                    back::INDENT
4978                )?;
4979                writeln!(self.out, "}}")?;
4980                Ok((name, 1, None, Scalar::F32))
4981            }
4982            Snorm8x2 => {
4983                let name = self.namer.call("unpackSnorm8x2");
4984                writeln!(
4985                    self.out,
4986                    "metal::float2 {name}(metal::uchar b0, \
4987                                          metal::uchar b1) {{"
4988                )?;
4989                writeln!(
4990                    self.out,
4991                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4992                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4993                    back::INDENT
4994                )?;
4995                writeln!(self.out, "}}")?;
4996                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4997            }
4998            Snorm8x4 => {
4999                let name = self.namer.call("unpackSnorm8x4");
5000                writeln!(
5001                    self.out,
5002                    "metal::float4 {name}(metal::uchar b0, \
5003                                          metal::uchar b1, \
5004                                          metal::uchar b2, \
5005                                          metal::uchar b3) {{"
5006                )?;
5007                writeln!(
5008                    self.out,
5009                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5010                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5011                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5012                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5013                    back::INDENT
5014                )?;
5015                writeln!(self.out, "}}")?;
5016                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5017            }
5018            Uint16 => {
5019                let name = self.namer.call("unpackUint16");
5020                writeln!(
5021                    self.out,
5022                    "metal::uint {name}(metal::uint b0, \
5023                                        metal::uint b1) {{"
5024                )?;
5025                writeln!(
5026                    self.out,
5027                    "{}return metal::uint(b1 << 8 | b0);",
5028                    back::INDENT
5029                )?;
5030                writeln!(self.out, "}}")?;
5031                Ok((name, 2, None, Scalar::U32))
5032            }
5033            Uint16x2 => {
5034                let name = self.namer.call("unpackUint16x2");
5035                writeln!(
5036                    self.out,
5037                    "metal::uint2 {name}(metal::uint b0, \
5038                                         metal::uint b1, \
5039                                         metal::uint b2, \
5040                                         metal::uint b3) {{"
5041                )?;
5042                writeln!(
5043                    self.out,
5044                    "{}return metal::uint2(b1 << 8 | b0, \
5045                                           b3 << 8 | b2);",
5046                    back::INDENT
5047                )?;
5048                writeln!(self.out, "}}")?;
5049                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5050            }
5051            Uint16x4 => {
5052                let name = self.namer.call("unpackUint16x4");
5053                writeln!(
5054                    self.out,
5055                    "metal::uint4 {name}(metal::uint b0, \
5056                                         metal::uint b1, \
5057                                         metal::uint b2, \
5058                                         metal::uint b3, \
5059                                         metal::uint b4, \
5060                                         metal::uint b5, \
5061                                         metal::uint b6, \
5062                                         metal::uint b7) {{"
5063                )?;
5064                writeln!(
5065                    self.out,
5066                    "{}return metal::uint4(b1 << 8 | b0, \
5067                                           b3 << 8 | b2, \
5068                                           b5 << 8 | b4, \
5069                                           b7 << 8 | b6);",
5070                    back::INDENT
5071                )?;
5072                writeln!(self.out, "}}")?;
5073                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5074            }
5075            Sint16 => {
5076                let name = self.namer.call("unpackSint16");
5077                writeln!(
5078                    self.out,
5079                    "int {name}(metal::ushort b0, \
5080                                metal::ushort b1) {{"
5081                )?;
5082                writeln!(
5083                    self.out,
5084                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5085                    back::INDENT
5086                )?;
5087                writeln!(self.out, "}}")?;
5088                Ok((name, 2, None, Scalar::I32))
5089            }
5090            Sint16x2 => {
5091                let name = self.namer.call("unpackSint16x2");
5092                writeln!(
5093                    self.out,
5094                    "metal::int2 {name}(metal::ushort b0, \
5095                                        metal::ushort b1, \
5096                                        metal::ushort b2, \
5097                                        metal::ushort b3) {{"
5098                )?;
5099                writeln!(
5100                    self.out,
5101                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5102                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5103                    back::INDENT
5104                )?;
5105                writeln!(self.out, "}}")?;
5106                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5107            }
5108            Sint16x4 => {
5109                let name = self.namer.call("unpackSint16x4");
5110                writeln!(
5111                    self.out,
5112                    "metal::int4 {name}(metal::ushort b0, \
5113                                        metal::ushort b1, \
5114                                        metal::ushort b2, \
5115                                        metal::ushort b3, \
5116                                        metal::ushort b4, \
5117                                        metal::ushort b5, \
5118                                        metal::ushort b6, \
5119                                        metal::ushort b7) {{"
5120                )?;
5121                writeln!(
5122                    self.out,
5123                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5124                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5125                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5126                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5127                    back::INDENT
5128                )?;
5129                writeln!(self.out, "}}")?;
5130                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5131            }
5132            Unorm16 => {
5133                let name = self.namer.call("unpackUnorm16");
5134                writeln!(
5135                    self.out,
5136                    "float {name}(metal::ushort b0, \
5137                                  metal::ushort b1) {{"
5138                )?;
5139                writeln!(
5140                    self.out,
5141                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5142                    back::INDENT
5143                )?;
5144                writeln!(self.out, "}}")?;
5145                Ok((name, 2, None, Scalar::F32))
5146            }
5147            Unorm16x2 => {
5148                let name = self.namer.call("unpackUnorm16x2");
5149                writeln!(
5150                    self.out,
5151                    "metal::float2 {name}(metal::ushort b0, \
5152                                          metal::ushort b1, \
5153                                          metal::ushort b2, \
5154                                          metal::ushort b3) {{"
5155                )?;
5156                writeln!(
5157                    self.out,
5158                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5159                                            float(b3 << 8 | b2) / 65535.0f);",
5160                    back::INDENT
5161                )?;
5162                writeln!(self.out, "}}")?;
5163                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5164            }
5165            Unorm16x4 => {
5166                let name = self.namer.call("unpackUnorm16x4");
5167                writeln!(
5168                    self.out,
5169                    "metal::float4 {name}(metal::ushort b0, \
5170                                          metal::ushort b1, \
5171                                          metal::ushort b2, \
5172                                          metal::ushort b3, \
5173                                          metal::ushort b4, \
5174                                          metal::ushort b5, \
5175                                          metal::ushort b6, \
5176                                          metal::ushort b7) {{"
5177                )?;
5178                writeln!(
5179                    self.out,
5180                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5181                                            float(b3 << 8 | b2) / 65535.0f, \
5182                                            float(b5 << 8 | b4) / 65535.0f, \
5183                                            float(b7 << 8 | b6) / 65535.0f);",
5184                    back::INDENT
5185                )?;
5186                writeln!(self.out, "}}")?;
5187                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5188            }
5189            Snorm16 => {
5190                let name = self.namer.call("unpackSnorm16");
5191                writeln!(
5192                    self.out,
5193                    "float {name}(metal::ushort b0, \
5194                                  metal::ushort b1) {{"
5195                )?;
5196                writeln!(
5197                    self.out,
5198                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5199                    back::INDENT
5200                )?;
5201                writeln!(self.out, "}}")?;
5202                Ok((name, 2, None, Scalar::F32))
5203            }
5204            Snorm16x2 => {
5205                let name = self.namer.call("unpackSnorm16x2");
5206                writeln!(
5207                    self.out,
5208                    "metal::float2 {name}(uint b0, \
5209                                          uint b1, \
5210                                          uint b2, \
5211                                          uint b3) {{"
5212                )?;
5213                writeln!(
5214                    self.out,
5215                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5216                    back::INDENT
5217                )?;
5218                writeln!(self.out, "}}")?;
5219                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5220            }
5221            Snorm16x4 => {
5222                let name = self.namer.call("unpackSnorm16x4");
5223                writeln!(
5224                    self.out,
5225                    "metal::float4 {name}(uint b0, \
5226                                          uint b1, \
5227                                          uint b2, \
5228                                          uint b3, \
5229                                          uint b4, \
5230                                          uint b5, \
5231                                          uint b6, \
5232                                          uint b7) {{"
5233                )?;
5234                writeln!(
5235                    self.out,
5236                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5237                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5238                    back::INDENT
5239                )?;
5240                writeln!(self.out, "}}")?;
5241                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5242            }
5243            Float16 => {
5244                let name = self.namer.call("unpackFloat16");
5245                writeln!(
5246                    self.out,
5247                    "float {name}(metal::ushort b0, \
5248                                  metal::ushort b1) {{"
5249                )?;
5250                writeln!(
5251                    self.out,
5252                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5253                    back::INDENT
5254                )?;
5255                writeln!(self.out, "}}")?;
5256                Ok((name, 2, None, Scalar::F32))
5257            }
5258            Float16x2 => {
5259                let name = self.namer.call("unpackFloat16x2");
5260                writeln!(
5261                    self.out,
5262                    "metal::float2 {name}(metal::ushort b0, \
5263                                          metal::ushort b1, \
5264                                          metal::ushort b2, \
5265                                          metal::ushort b3) {{"
5266                )?;
5267                writeln!(
5268                    self.out,
5269                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5270                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5271                    back::INDENT
5272                )?;
5273                writeln!(self.out, "}}")?;
5274                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5275            }
5276            Float16x4 => {
5277                let name = self.namer.call("unpackFloat16x4");
5278                writeln!(
5279                    self.out,
5280                    "metal::float4 {name}(metal::ushort b0, \
5281                                        metal::ushort b1, \
5282                                        metal::ushort b2, \
5283                                        metal::ushort b3, \
5284                                        metal::ushort b4, \
5285                                        metal::ushort b5, \
5286                                        metal::ushort b6, \
5287                                        metal::ushort b7) {{"
5288                )?;
5289                writeln!(
5290                    self.out,
5291                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5292                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5293                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5294                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5295                    back::INDENT
5296                )?;
5297                writeln!(self.out, "}}")?;
5298                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5299            }
5300            Float32 => {
5301                let name = self.namer.call("unpackFloat32");
5302                writeln!(
5303                    self.out,
5304                    "float {name}(uint b0, \
5305                                  uint b1, \
5306                                  uint b2, \
5307                                  uint b3) {{"
5308                )?;
5309                writeln!(
5310                    self.out,
5311                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5312                    back::INDENT
5313                )?;
5314                writeln!(self.out, "}}")?;
5315                Ok((name, 4, None, Scalar::F32))
5316            }
5317            Float32x2 => {
5318                let name = self.namer.call("unpackFloat32x2");
5319                writeln!(
5320                    self.out,
5321                    "metal::float2 {name}(uint b0, \
5322                                          uint b1, \
5323                                          uint b2, \
5324                                          uint b3, \
5325                                          uint b4, \
5326                                          uint b5, \
5327                                          uint b6, \
5328                                          uint b7) {{"
5329                )?;
5330                writeln!(
5331                    self.out,
5332                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5333                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5334                    back::INDENT
5335                )?;
5336                writeln!(self.out, "}}")?;
5337                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5338            }
5339            Float32x3 => {
5340                let name = self.namer.call("unpackFloat32x3");
5341                writeln!(
5342                    self.out,
5343                    "metal::float3 {name}(uint b0, \
5344                                          uint b1, \
5345                                          uint b2, \
5346                                          uint b3, \
5347                                          uint b4, \
5348                                          uint b5, \
5349                                          uint b6, \
5350                                          uint b7, \
5351                                          uint b8, \
5352                                          uint b9, \
5353                                          uint b10, \
5354                                          uint b11) {{"
5355                )?;
5356                writeln!(
5357                    self.out,
5358                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5359                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5360                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5361                    back::INDENT
5362                )?;
5363                writeln!(self.out, "}}")?;
5364                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5365            }
5366            Float32x4 => {
5367                let name = self.namer.call("unpackFloat32x4");
5368                writeln!(
5369                    self.out,
5370                    "metal::float4 {name}(uint b0, \
5371                                          uint b1, \
5372                                          uint b2, \
5373                                          uint b3, \
5374                                          uint b4, \
5375                                          uint b5, \
5376                                          uint b6, \
5377                                          uint b7, \
5378                                          uint b8, \
5379                                          uint b9, \
5380                                          uint b10, \
5381                                          uint b11, \
5382                                          uint b12, \
5383                                          uint b13, \
5384                                          uint b14, \
5385                                          uint b15) {{"
5386                )?;
5387                writeln!(
5388                    self.out,
5389                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5390                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5391                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5392                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5393                    back::INDENT
5394                )?;
5395                writeln!(self.out, "}}")?;
5396                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5397            }
5398            Uint32 => {
5399                let name = self.namer.call("unpackUint32");
5400                writeln!(
5401                    self.out,
5402                    "uint {name}(uint b0, \
5403                                 uint b1, \
5404                                 uint b2, \
5405                                 uint b3) {{"
5406                )?;
5407                writeln!(
5408                    self.out,
5409                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5410                    back::INDENT
5411                )?;
5412                writeln!(self.out, "}}")?;
5413                Ok((name, 4, None, Scalar::U32))
5414            }
5415            Uint32x2 => {
5416                let name = self.namer.call("unpackUint32x2");
5417                writeln!(
5418                    self.out,
5419                    "uint2 {name}(uint b0, \
5420                                  uint b1, \
5421                                  uint b2, \
5422                                  uint b3, \
5423                                  uint b4, \
5424                                  uint b5, \
5425                                  uint b6, \
5426                                  uint b7) {{"
5427                )?;
5428                writeln!(
5429                    self.out,
5430                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5431                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5432                    back::INDENT
5433                )?;
5434                writeln!(self.out, "}}")?;
5435                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5436            }
5437            Uint32x3 => {
5438                let name = self.namer.call("unpackUint32x3");
5439                writeln!(
5440                    self.out,
5441                    "uint3 {name}(uint b0, \
5442                                  uint b1, \
5443                                  uint b2, \
5444                                  uint b3, \
5445                                  uint b4, \
5446                                  uint b5, \
5447                                  uint b6, \
5448                                  uint b7, \
5449                                  uint b8, \
5450                                  uint b9, \
5451                                  uint b10, \
5452                                  uint b11) {{"
5453                )?;
5454                writeln!(
5455                    self.out,
5456                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5457                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5458                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5459                    back::INDENT
5460                )?;
5461                writeln!(self.out, "}}")?;
5462                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5463            }
5464            Uint32x4 => {
5465                let name = self.namer.call("unpackUint32x4");
5466                writeln!(
5467                    self.out,
5468                    "{NAMESPACE}::uint4 {name}(uint b0, \
5469                                  uint b1, \
5470                                  uint b2, \
5471                                  uint b3, \
5472                                  uint b4, \
5473                                  uint b5, \
5474                                  uint b6, \
5475                                  uint b7, \
5476                                  uint b8, \
5477                                  uint b9, \
5478                                  uint b10, \
5479                                  uint b11, \
5480                                  uint b12, \
5481                                  uint b13, \
5482                                  uint b14, \
5483                                  uint b15) {{"
5484                )?;
5485                writeln!(
5486                    self.out,
5487                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5488                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5489                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5490                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5491                    back::INDENT
5492                )?;
5493                writeln!(self.out, "}}")?;
5494                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5495            }
5496            Sint32 => {
5497                let name = self.namer.call("unpackSint32");
5498                writeln!(
5499                    self.out,
5500                    "int {name}(uint b0, \
5501                                uint b1, \
5502                                uint b2, \
5503                                uint b3) {{"
5504                )?;
5505                writeln!(
5506                    self.out,
5507                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5508                    back::INDENT
5509                )?;
5510                writeln!(self.out, "}}")?;
5511                Ok((name, 4, None, Scalar::I32))
5512            }
5513            Sint32x2 => {
5514                let name = self.namer.call("unpackSint32x2");
5515                writeln!(
5516                    self.out,
5517                    "metal::int2 {name}(uint b0, \
5518                                        uint b1, \
5519                                        uint b2, \
5520                                        uint b3, \
5521                                        uint b4, \
5522                                        uint b5, \
5523                                        uint b6, \
5524                                        uint b7) {{"
5525                )?;
5526                writeln!(
5527                    self.out,
5528                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5529                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5530                    back::INDENT
5531                )?;
5532                writeln!(self.out, "}}")?;
5533                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5534            }
5535            Sint32x3 => {
5536                let name = self.namer.call("unpackSint32x3");
5537                writeln!(
5538                    self.out,
5539                    "metal::int3 {name}(uint b0, \
5540                                        uint b1, \
5541                                        uint b2, \
5542                                        uint b3, \
5543                                        uint b4, \
5544                                        uint b5, \
5545                                        uint b6, \
5546                                        uint b7, \
5547                                        uint b8, \
5548                                        uint b9, \
5549                                        uint b10, \
5550                                        uint b11) {{"
5551                )?;
5552                writeln!(
5553                    self.out,
5554                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5555                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5556                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5557                    back::INDENT
5558                )?;
5559                writeln!(self.out, "}}")?;
5560                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5561            }
5562            Sint32x4 => {
5563                let name = self.namer.call("unpackSint32x4");
5564                writeln!(
5565                    self.out,
5566                    "metal::int4 {name}(uint b0, \
5567                                        uint b1, \
5568                                        uint b2, \
5569                                        uint b3, \
5570                                        uint b4, \
5571                                        uint b5, \
5572                                        uint b6, \
5573                                        uint b7, \
5574                                        uint b8, \
5575                                        uint b9, \
5576                                        uint b10, \
5577                                        uint b11, \
5578                                        uint b12, \
5579                                        uint b13, \
5580                                        uint b14, \
5581                                        uint b15) {{"
5582                )?;
5583                writeln!(
5584                    self.out,
5585                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5586                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5587                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5588                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5589                    back::INDENT
5590                )?;
5591                writeln!(self.out, "}}")?;
5592                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5593            }
5594            Unorm10_10_10_2 => {
5595                let name = self.namer.call("unpackUnorm10_10_10_2");
5596                writeln!(
5597                    self.out,
5598                    "metal::float4 {name}(uint b0, \
5599                                          uint b1, \
5600                                          uint b2, \
5601                                          uint b3) {{"
5602                )?;
5603                writeln!(
5604                    self.out,
5605                    // The following is correct for RGBA packing, but our format seems to
5606                    // match ABGR, which can be fed into the Metal builtin function
5607                    // unpack_unorm10a2_to_float.
5608                    /*
5609                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5610                       uint r = (v & 0xFFC00000) >> 22; \
5611                       uint g = (v & 0x003FF000) >> 12; \
5612                       uint b = (v & 0x00000FFC) >> 2; \
5613                       uint a = (v & 0x00000003); \
5614                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5615                    */
5616                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5617                    back::INDENT
5618                )?;
5619                writeln!(self.out, "}}")?;
5620                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5621            }
5622            Unorm8x4Bgra => {
5623                let name = self.namer.call("unpackUnorm8x4Bgra");
5624                writeln!(
5625                    self.out,
5626                    "metal::float4 {name}(metal::uchar b0, \
5627                                          metal::uchar b1, \
5628                                          metal::uchar b2, \
5629                                          metal::uchar b3) {{"
5630                )?;
5631                writeln!(
5632                    self.out,
5633                    "{}return metal::float4(float(b2) / 255.0f, \
5634                                            float(b1) / 255.0f, \
5635                                            float(b0) / 255.0f, \
5636                                            float(b3) / 255.0f);",
5637                    back::INDENT
5638                )?;
5639                writeln!(self.out, "}}")?;
5640                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5641            }
5642        }
5643    }
5644
5645    fn write_wrapped_unary_op(
5646        &mut self,
5647        module: &crate::Module,
5648        func_ctx: &back::FunctionCtx,
5649        op: crate::UnaryOperator,
5650        operand: Handle<crate::Expression>,
5651    ) -> BackendResult {
5652        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5653        match op {
5654            // Negating the TYPE_MIN of a two's complement signed integer
5655            // type causes overflow, which is undefined behaviour in MSL. To
5656            // avoid this we bitcast the value to unsigned and negate it,
5657            // then bitcast back to signed.
5658            // This adheres to the WGSL spec in that the negative of the
5659            // type's minimum value should equal to the minimum value.
5660            crate::UnaryOperator::Negate
5661                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5662            {
5663                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5664                    return Ok(());
5665                };
5666                let wrapped = WrappedFunction::UnaryOp {
5667                    op,
5668                    ty: (vector_size, scalar),
5669                };
5670                if !self.wrapped_functions.insert(wrapped) {
5671                    return Ok(());
5672                }
5673
5674                let unsigned_scalar = crate::Scalar {
5675                    kind: crate::ScalarKind::Uint,
5676                    ..scalar
5677                };
5678                let mut type_name = String::new();
5679                let mut unsigned_type_name = String::new();
5680                match vector_size {
5681                    None => {
5682                        put_numeric_type(&mut type_name, scalar, &[])?;
5683                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5684                    }
5685                    Some(size) => {
5686                        put_numeric_type(&mut type_name, scalar, &[size])?;
5687                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5688                    }
5689                };
5690
5691                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5692                let level = back::Level(1);
5693                // For sub-32-bit types, C++ integer promotion widens
5694                // `-as_type<ushort>(val)` to `int`, so we need static_cast
5695                // to truncate back before the outer as_type bitcast.
5696                if scalar.width < 4 {
5697                    writeln!(
5698                        self.out,
5699                        "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5700                    )?;
5701                } else {
5702                    writeln!(
5703                        self.out,
5704                        "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5705                    )?;
5706                }
5707                writeln!(self.out, "}}")?;
5708                writeln!(self.out)?;
5709            }
5710            _ => {}
5711        }
5712        Ok(())
5713    }
5714
5715    fn write_wrapped_binary_op(
5716        &mut self,
5717        module: &crate::Module,
5718        func_ctx: &back::FunctionCtx,
5719        expr: Handle<crate::Expression>,
5720        op: crate::BinaryOperator,
5721        left: Handle<crate::Expression>,
5722        right: Handle<crate::Expression>,
5723    ) -> BackendResult {
5724        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5725        let left_ty = func_ctx.resolve_type(left, &module.types);
5726        let right_ty = func_ctx.resolve_type(right, &module.types);
5727        match (op, expr_ty.scalar_kind()) {
5728            // Signed integer division of TYPE_MIN / -1, or signed or
5729            // unsigned division by zero, gives an unspecified value in MSL.
5730            // We override the divisor to 1 in these cases.
5731            // This adheres to the WGSL spec in that:
5732            // * TYPE_MIN / -1 == TYPE_MIN
5733            // * x / 0 == x
5734            (
5735                crate::BinaryOperator::Divide,
5736                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5737            ) if self.emit_int_div_checks => {
5738                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5739                    return Ok(());
5740                };
5741                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5742                    return Ok(());
5743                };
5744                let wrapped = WrappedFunction::BinaryOp {
5745                    op,
5746                    left_ty: left_wrapped_ty,
5747                    right_ty: right_wrapped_ty,
5748                };
5749                if !self.wrapped_functions.insert(wrapped) {
5750                    return Ok(());
5751                }
5752
5753                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5754                    return Ok(());
5755                };
5756                let mut type_name = String::new();
5757                match vector_size {
5758                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5759                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5760                };
5761                writeln!(
5762                    self.out,
5763                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5764                )?;
5765                let level = back::Level(1);
5766                // Sub-32-bit types need typed literal wrappers (e.g. `short(1)`)
5767                // to avoid ambiguous metal::select overloads. For >= 32-bit,
5768                // bare literals like `1`, `-1`, `0` are unambiguous.
5769                let (lp, rp) = if scalar.width < 4 {
5770                    (format!("{type_name}("), ")".to_string())
5771                } else {
5772                    (String::new(), String::new())
5773                };
5774                match scalar.kind {
5775                    crate::ScalarKind::Sint => {
5776                        let min_val = match scalar.width {
5777                            2 => crate::Literal::I16(i16::MIN),
5778                            4 => crate::Literal::I32(i32::MIN),
5779                            8 => crate::Literal::I64(i64::MIN),
5780                            _ => {
5781                                return Err(Error::GenericValidation(format!(
5782                                    "Unexpected width for scalar {scalar:?}"
5783                                )));
5784                            }
5785                        };
5786                        write!(
5787                            self.out,
5788                            "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5789                        )?;
5790                        self.put_literal(min_val)?;
5791                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5792                    }
5793                    crate::ScalarKind::Uint => {
5794                        let suffix = if scalar.width < 4 { "" } else { "u" };
5795                        writeln!(
5796                            self.out,
5797                            "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5798                        )?
5799                    }
5800                    _ => unreachable!(),
5801                }
5802                writeln!(self.out, "}}")?;
5803                writeln!(self.out)?;
5804            }
5805            // Integer modulo where one or both operands are negative, or the
5806            // divisor is zero, is undefined behaviour in MSL. To avoid this
5807            // we use the following equation:
5808            //
5809            // dividend - (dividend / divisor) * divisor
5810            //
5811            // overriding the divisor to 1 if either it is 0, or it is -1
5812            // and the dividend is TYPE_MIN.
5813            //
5814            // This adheres to the WGSL spec in that:
5815            // * TYPE_MIN % -1 == 0
5816            // * x % 0 == 0
5817            (
5818                crate::BinaryOperator::Modulo,
5819                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5820            ) if self.emit_int_div_checks => {
5821                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5822                    return Ok(());
5823                };
5824                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5825                else {
5826                    return Ok(());
5827                };
5828                let wrapped = WrappedFunction::BinaryOp {
5829                    op,
5830                    left_ty: left_wrapped_ty,
5831                    right_ty: (right_vector_size, right_scalar),
5832                };
5833                if !self.wrapped_functions.insert(wrapped) {
5834                    return Ok(());
5835                }
5836
5837                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5838                    return Ok(());
5839                };
5840                let mut type_name = String::new();
5841                match vector_size {
5842                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5843                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5844                };
5845                let mut rhs_type_name = String::new();
5846                match right_vector_size {
5847                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5848                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5849                };
5850
5851                writeln!(
5852                    self.out,
5853                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5854                )?;
5855                let level = back::Level(1);
5856                let (lp, rp) = if scalar.width < 4 {
5857                    (format!("{type_name}("), ")".to_string())
5858                } else {
5859                    (String::new(), String::new())
5860                };
5861                match scalar.kind {
5862                    crate::ScalarKind::Sint => {
5863                        let min_val = match scalar.width {
5864                            2 => crate::Literal::I16(i16::MIN),
5865                            4 => crate::Literal::I32(i32::MIN),
5866                            8 => crate::Literal::I64(i64::MIN),
5867                            _ => {
5868                                return Err(Error::GenericValidation(format!(
5869                                    "Unexpected width for scalar {scalar:?}"
5870                                )));
5871                            }
5872                        };
5873                        write!(
5874                            self.out,
5875                            "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5876                        )?;
5877                        self.put_literal(min_val)?;
5878                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5879                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5880                    }
5881                    crate::ScalarKind::Uint => {
5882                        let suffix = if scalar.width < 4 { "" } else { "u" };
5883                        writeln!(
5884                            self.out,
5885                            "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5886                        )?
5887                    }
5888                    _ => unreachable!(),
5889                }
5890                writeln!(self.out, "}}")?;
5891                writeln!(self.out)?;
5892            }
5893            _ => {}
5894        }
5895        Ok(())
5896    }
5897
5898    /// Build the mangled helper name for integer vector dot products.
5899    ///
5900    /// `scalar` must be a concrete integer scalar type.
5901    ///
5902    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5903    fn get_dot_wrapper_function_helper_name(
5904        &self,
5905        scalar: crate::Scalar,
5906        size: crate::VectorSize,
5907    ) -> String {
5908        // Check for consistency with [`super::keywords::RESERVED_SET`]
5909        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5910
5911        let type_name = scalar.to_msl_name();
5912        let size_suffix = common::vector_size_str(size);
5913        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5914    }
5915
5916    #[allow(clippy::too_many_arguments)]
5917    fn write_wrapped_math_function(
5918        &mut self,
5919        module: &crate::Module,
5920        func_ctx: &back::FunctionCtx,
5921        fun: crate::MathFunction,
5922        arg: Handle<crate::Expression>,
5923        _arg1: Option<Handle<crate::Expression>>,
5924        _arg2: Option<Handle<crate::Expression>>,
5925        _arg3: Option<Handle<crate::Expression>>,
5926    ) -> BackendResult {
5927        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5928        match fun {
5929            // Taking the absolute value of the TYPE_MIN of a two's
5930            // complement signed integer type causes overflow, which is
5931            // undefined behaviour in MSL. To avoid this, when the value is
5932            // negative we bitcast the value to unsigned and negate it, then
5933            // bitcast back to signed.
5934            // This adheres to the WGSL spec in that the absolute of the
5935            // type's minimum value should equal to the minimum value.
5936            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5937                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5938                    return Ok(());
5939                };
5940                let wrapped = WrappedFunction::Math {
5941                    fun,
5942                    arg_ty: (vector_size, scalar),
5943                };
5944                if !self.wrapped_functions.insert(wrapped) {
5945                    return Ok(());
5946                }
5947
5948                let unsigned_scalar = crate::Scalar {
5949                    kind: crate::ScalarKind::Uint,
5950                    ..scalar
5951                };
5952                let mut type_name = String::new();
5953                let mut unsigned_type_name = String::new();
5954                match vector_size {
5955                    None => {
5956                        put_numeric_type(&mut type_name, scalar, &[])?;
5957                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5958                    }
5959                    Some(size) => {
5960                        put_numeric_type(&mut type_name, scalar, &[size])?;
5961                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5962                    }
5963                };
5964
5965                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5966                let level = back::Level(1);
5967                let zero = if scalar.width < 4 {
5968                    format!("{type_name}(0)")
5969                } else {
5970                    "0".to_string()
5971                };
5972                let neg_expr = if scalar.width < 4 {
5973                    format!(
5974                        "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5975                    )
5976                } else {
5977                    format!("-as_type<{unsigned_type_name}>(val)")
5978                };
5979                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5980                writeln!(self.out, "}}")?;
5981                writeln!(self.out)?;
5982            }
5983
5984            crate::MathFunction::Dot => match *arg_ty {
5985                crate::TypeInner::Vector { size, scalar }
5986                    if matches!(
5987                        scalar.kind,
5988                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5989                    ) =>
5990                {
5991                    // De-duplicate per (fun, arg type) like other wrapped math functions
5992                    let wrapped = WrappedFunction::Math {
5993                        fun,
5994                        arg_ty: (Some(size), scalar),
5995                    };
5996                    if !self.wrapped_functions.insert(wrapped) {
5997                        return Ok(());
5998                    }
5999
6000                    let mut vec_ty = String::new();
6001                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
6002                    let mut ret_ty = String::new();
6003                    put_numeric_type(&mut ret_ty, scalar, &[])?;
6004
6005                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6006
6007                    // Emit function signature and body using put_dot_product for the expression
6008                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6009                    let level = back::Level(1);
6010                    write!(self.out, "{level}return ")?;
6011                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6012                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6013                        Ok(())
6014                    })?;
6015                    writeln!(self.out, ";")?;
6016                    writeln!(self.out, "}}")?;
6017                    writeln!(self.out)?;
6018                }
6019                _ => {}
6020            },
6021
6022            _ => {}
6023        }
6024        Ok(())
6025    }
6026
6027    fn write_wrapped_cast(
6028        &mut self,
6029        module: &crate::Module,
6030        func_ctx: &back::FunctionCtx,
6031        expr: Handle<crate::Expression>,
6032        kind: crate::ScalarKind,
6033        convert: Option<crate::Bytes>,
6034    ) -> BackendResult {
6035        // Avoid undefined behaviour when casting from a float to integer
6036        // when the value is out of range for the target type. Additionally
6037        // ensure we clamp to the correct value as per the WGSL spec.
6038        //
6039        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6040        // * If X is exactly representable in the target type T, then the
6041        //   result is that value.
6042        // * Otherwise, the result is the value in T closest to
6043        //   truncate(X) and also exactly representable in the original
6044        //   floating point type.
6045        let src_ty = func_ctx.resolve_type(expr, &module.types);
6046        let Some(width) = convert else {
6047            return Ok(());
6048        };
6049        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6050            return Ok(());
6051        };
6052        let dst_scalar = crate::Scalar { kind, width };
6053        if src_scalar.kind != crate::ScalarKind::Float
6054            || (dst_scalar.kind != crate::ScalarKind::Sint
6055                && dst_scalar.kind != crate::ScalarKind::Uint)
6056        {
6057            return Ok(());
6058        }
6059        let wrapped = WrappedFunction::Cast {
6060            src_scalar,
6061            vector_size,
6062            dst_scalar,
6063        };
6064        if !self.wrapped_functions.insert(wrapped) {
6065            return Ok(());
6066        }
6067        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6068
6069        let mut src_type_name = String::new();
6070        match vector_size {
6071            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6072            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6073        };
6074        let mut dst_type_name = String::new();
6075        match vector_size {
6076            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6077            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6078        };
6079        let fun_name = match dst_scalar {
6080            crate::Scalar::I32 => F2I32_FUNCTION,
6081            crate::Scalar::U32 => F2U32_FUNCTION,
6082            crate::Scalar::I64 => F2I64_FUNCTION,
6083            crate::Scalar::U64 => F2U64_FUNCTION,
6084            _ => unreachable!(),
6085        };
6086
6087        writeln!(
6088            self.out,
6089            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6090        )?;
6091        let level = back::Level(1);
6092        write!(
6093            self.out,
6094            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6095        )?;
6096        self.put_literal(min)?;
6097        write!(self.out, ", ")?;
6098        self.put_literal(max)?;
6099        writeln!(self.out, "));")?;
6100        writeln!(self.out, "}}")?;
6101        writeln!(self.out)?;
6102        Ok(())
6103    }
6104
6105    /// Helper function used by [`Self::write_wrapped_image_load`] and
6106    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6107    /// conversion code for external textures. Expects the preceding code to
6108    /// declare the Y component as a `float` variable of name `y`, the UV
6109    /// components as a `float2` variable of name `uv`, and the external
6110    /// texture params as a variable of name `params`. The emitted code will
6111    /// return the result.
6112    fn write_convert_yuv_to_rgb_and_return(
6113        &mut self,
6114        level: back::Level,
6115        y: &str,
6116        uv: &str,
6117        params: &str,
6118    ) -> BackendResult {
6119        let l1 = level;
6120        let l2 = l1.next();
6121
6122        // Convert from YUV to non-linear RGB in the source color space.
6123        writeln!(
6124            self.out,
6125            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6126        )?;
6127
6128        // Apply the inverse of the source transfer function to convert to
6129        // linear RGB in the source color space.
6130        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6131        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6132        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6133        writeln!(
6134            self.out,
6135            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6136        )?;
6137
6138        // Multiply by the gamut conversion matrix to convert to linear RGB in
6139        // the destination color space.
6140        writeln!(
6141            self.out,
6142            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6143        )?;
6144
6145        // Finally, apply the dest transfer function to convert to non-linear
6146        // RGB in the destination color space, and return the result.
6147        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6148        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6149        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6150        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6151
6152        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6153        Ok(())
6154    }
6155
6156    #[allow(clippy::too_many_arguments)]
6157    fn write_wrapped_image_load(
6158        &mut self,
6159        module: &crate::Module,
6160        func_ctx: &back::FunctionCtx,
6161        image: Handle<crate::Expression>,
6162        _coordinate: Handle<crate::Expression>,
6163        _array_index: Option<Handle<crate::Expression>>,
6164        _sample: Option<Handle<crate::Expression>>,
6165        _level: Option<Handle<crate::Expression>>,
6166    ) -> BackendResult {
6167        // We currently only need to wrap image loads for external textures
6168        let class = match *func_ctx.resolve_type(image, &module.types) {
6169            crate::TypeInner::Image { class, .. } => class,
6170            _ => unreachable!(),
6171        };
6172        if class != crate::ImageClass::External {
6173            return Ok(());
6174        }
6175        let wrapped = WrappedFunction::ImageLoad { class };
6176        if !self.wrapped_functions.insert(wrapped) {
6177            return Ok(());
6178        }
6179
6180        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6181        let l1 = back::Level(1);
6182        let l2 = l1.next();
6183        let l3 = l2.next();
6184        writeln!(
6185            self.out,
6186            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6187        )?;
6188        // Clamp coords to provided size of external texture to prevent OOB
6189        // read. If params.size is zero then clamp to the actual size of the
6190        // texture.
6191        writeln!(
6192            self.out,
6193            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6194        )?;
6195        writeln!(
6196            self.out,
6197            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6198        )?;
6199
6200        // Apply load transformation
6201        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6202        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6203        // For single plane, simply read from plane0
6204        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6205        writeln!(self.out, "{l1}}} else {{")?;
6206
6207        // Chroma planes may be subsampled so we must scale the coords accordingly.
6208        writeln!(
6209            self.out,
6210            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6211        )?;
6212        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6213
6214        // For multi-plane, read the Y value from plane 0
6215        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6216
6217        writeln!(self.out, "{l2}float2 uv;")?;
6218        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6219        // For 2 planes, read UV from interleaved plane 1
6220        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6221        writeln!(self.out, "{l2}}} else {{")?;
6222        // For 3 planes, read U and V from planes 1 and 2 respectively
6223        writeln!(
6224            self.out,
6225            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6226        )?;
6227        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6228        writeln!(
6229            self.out,
6230            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6231        )?;
6232        writeln!(self.out, "{l2}}}")?;
6233
6234        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6235
6236        writeln!(self.out, "{l1}}}")?;
6237        writeln!(self.out, "}}")?;
6238        writeln!(self.out)?;
6239        Ok(())
6240    }
6241
6242    #[allow(clippy::too_many_arguments)]
6243    fn write_wrapped_image_sample(
6244        &mut self,
6245        module: &crate::Module,
6246        func_ctx: &back::FunctionCtx,
6247        image: Handle<crate::Expression>,
6248        _sampler: Handle<crate::Expression>,
6249        _gather: Option<crate::SwizzleComponent>,
6250        _coordinate: Handle<crate::Expression>,
6251        _array_index: Option<Handle<crate::Expression>>,
6252        _offset: Option<Handle<crate::Expression>>,
6253        _level: crate::SampleLevel,
6254        _depth_ref: Option<Handle<crate::Expression>>,
6255        clamp_to_edge: bool,
6256    ) -> BackendResult {
6257        // We currently only need to wrap textureSampleBaseClampToEdge, for
6258        // both sampled and external textures.
6259        if !clamp_to_edge {
6260            return Ok(());
6261        }
6262        let class = match *func_ctx.resolve_type(image, &module.types) {
6263            crate::TypeInner::Image { class, .. } => class,
6264            _ => unreachable!(),
6265        };
6266        let wrapped = WrappedFunction::ImageSample {
6267            class,
6268            clamp_to_edge: true,
6269        };
6270        if !self.wrapped_functions.insert(wrapped) {
6271            return Ok(());
6272        }
6273        match class {
6274            crate::ImageClass::External => {
6275                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6276                let l1 = back::Level(1);
6277                let l2 = l1.next();
6278                let l3 = l2.next();
6279                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6280                writeln!(
6281                    self.out,
6282                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6283                )?;
6284
6285                // Calculate the sample bounds. The purported size of the texture
6286                // (params.size) is irrelevant here as we are dealing with normalized
6287                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6288                // apply the sample transformation to that, also bearing in mind that it
6289                // may contain a flip on either axis. We calculate and adjust for the
6290                // half-texel separately for each plane as it depends on the actual
6291                // texture size which may vary between planes.
6292                writeln!(
6293                    self.out,
6294                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6295                )?;
6296                writeln!(
6297                    self.out,
6298                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6299                )?;
6300                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6301                writeln!(
6302                    self.out,
6303                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6304                )?;
6305                writeln!(
6306                    self.out,
6307                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6308                )?;
6309                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6310                // For single plane, simply sample from plane0
6311                writeln!(
6312                    self.out,
6313                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6314                )?;
6315                writeln!(self.out, "{l1}}} else {{")?;
6316                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6317                writeln!(
6318                    self.out,
6319                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6320                )?;
6321                writeln!(
6322                    self.out,
6323                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6324                )?;
6325
6326                // For multi-plane, sample the Y value from plane 0
6327                writeln!(
6328                    self.out,
6329                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6330                )?;
6331                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6332                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6333                // For 2 planes, sample UV from interleaved plane 1
6334                writeln!(
6335                    self.out,
6336                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6337                )?;
6338                writeln!(self.out, "{l2}}} else {{")?;
6339                // For 3 planes, sample U and V from planes 1 and 2 respectively
6340                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6341                writeln!(
6342                    self.out,
6343                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6344                )?;
6345                writeln!(
6346                    self.out,
6347                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6348                )?;
6349                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6350                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6351                writeln!(self.out, "{l2}}}")?;
6352
6353                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6354
6355                writeln!(self.out, "{l1}}}")?;
6356                writeln!(self.out, "}}")?;
6357                writeln!(self.out)?;
6358            }
6359            _ => {
6360                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6361                let l1 = back::Level(1);
6362                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6363                writeln!(
6364                    self.out,
6365                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6366                )?;
6367                writeln!(self.out, "}}")?;
6368                writeln!(self.out)?;
6369            }
6370        }
6371        Ok(())
6372    }
6373
6374    fn write_wrapped_image_query(
6375        &mut self,
6376        module: &crate::Module,
6377        func_ctx: &back::FunctionCtx,
6378        image: Handle<crate::Expression>,
6379        query: crate::ImageQuery,
6380    ) -> BackendResult {
6381        // We currently only need to wrap size image queries for external textures
6382        if !matches!(query, crate::ImageQuery::Size { .. }) {
6383            return Ok(());
6384        }
6385        let class = match *func_ctx.resolve_type(image, &module.types) {
6386            crate::TypeInner::Image { class, .. } => class,
6387            _ => unreachable!(),
6388        };
6389        if class != crate::ImageClass::External {
6390            return Ok(());
6391        }
6392        let wrapped = WrappedFunction::ImageQuerySize { class };
6393        if !self.wrapped_functions.insert(wrapped) {
6394            return Ok(());
6395        }
6396        writeln!(
6397            self.out,
6398            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6399        )?;
6400        let l1 = back::Level(1);
6401        let l2 = l1.next();
6402        writeln!(
6403            self.out,
6404            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6405        )?;
6406        writeln!(self.out, "{l2}return tex.params.size;")?;
6407        writeln!(self.out, "{l1}}} else {{")?;
6408        // params.size == (0, 0) indicates to query and return plane 0's actual size
6409        writeln!(
6410            self.out,
6411            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6412        )?;
6413        writeln!(self.out, "{l1}}}")?;
6414        writeln!(self.out, "}}")?;
6415        writeln!(self.out)?;
6416        Ok(())
6417    }
6418
6419    fn write_wrapped_cooperative_load(
6420        &mut self,
6421        module: &crate::Module,
6422        func_ctx: &back::FunctionCtx,
6423        columns: crate::CooperativeSize,
6424        rows: crate::CooperativeSize,
6425        pointer: Handle<crate::Expression>,
6426    ) -> BackendResult {
6427        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6428        let space = ptr_ty.pointer_space().unwrap();
6429        let space_name = space.to_msl_name().unwrap_or_default();
6430        let scalar = ptr_ty
6431            .pointer_base_type()
6432            .unwrap()
6433            .inner_with(&module.types)
6434            .scalar()
6435            .unwrap();
6436        let wrapped = WrappedFunction::CooperativeLoad {
6437            space_name,
6438            columns,
6439            rows,
6440            scalar,
6441        };
6442        if !self.wrapped_functions.insert(wrapped) {
6443            return Ok(());
6444        }
6445        let scalar_name = scalar.to_msl_name();
6446        writeln!(
6447            self.out,
6448            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6449            columns as u32, rows as u32,
6450        )?;
6451        let l1 = back::Level(1);
6452        writeln!(
6453            self.out,
6454            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6455            columns as u32, rows as u32
6456        )?;
6457        let matrix_origin = "0";
6458        writeln!(
6459            self.out,
6460            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6461        )?;
6462        writeln!(self.out, "{l1}return m;")?;
6463        writeln!(self.out, "}}")?;
6464        writeln!(self.out)?;
6465        Ok(())
6466    }
6467
6468    fn write_wrapped_cooperative_multiply_add(
6469        &mut self,
6470        module: &crate::Module,
6471        func_ctx: &back::FunctionCtx,
6472        space: crate::AddressSpace,
6473        a: Handle<crate::Expression>,
6474        b: Handle<crate::Expression>,
6475    ) -> BackendResult {
6476        let space_name = space.to_msl_name().unwrap_or_default();
6477        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6478            crate::TypeInner::CooperativeMatrix {
6479                columns,
6480                rows,
6481                scalar,
6482                ..
6483            } => (columns, rows, scalar),
6484            _ => unreachable!(),
6485        };
6486        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6487            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6488            _ => unreachable!(),
6489        };
6490        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6491            space_name,
6492            columns: b_c,
6493            rows: a_r,
6494            intermediate: a_c,
6495            scalar,
6496        };
6497        if !self.wrapped_functions.insert(wrapped) {
6498            return Ok(());
6499        }
6500        let scalar_name = scalar.to_msl_name();
6501        writeln!(
6502            self.out,
6503            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6504            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6505        )?;
6506        let l1 = back::Level(1);
6507        writeln!(
6508            self.out,
6509            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6510            b_c as u32, a_r as u32
6511        )?;
6512        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6513        writeln!(self.out, "{l1}return d;")?;
6514        writeln!(self.out, "}}")?;
6515        writeln!(self.out)?;
6516        Ok(())
6517    }
6518
6519    pub(super) fn write_wrapped_functions(
6520        &mut self,
6521        module: &crate::Module,
6522        func_ctx: &back::FunctionCtx,
6523    ) -> BackendResult {
6524        for (expr_handle, expr) in func_ctx.expressions.iter() {
6525            match *expr {
6526                crate::Expression::Unary { op, expr: operand } => {
6527                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6528                }
6529                crate::Expression::Binary { op, left, right } => {
6530                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6531                }
6532                crate::Expression::Math {
6533                    fun,
6534                    arg,
6535                    arg1,
6536                    arg2,
6537                    arg3,
6538                } => {
6539                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6540                }
6541                crate::Expression::As {
6542                    expr,
6543                    kind,
6544                    convert,
6545                } => {
6546                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6547                }
6548                crate::Expression::ImageLoad {
6549                    image,
6550                    coordinate,
6551                    array_index,
6552                    sample,
6553                    level,
6554                } => {
6555                    self.write_wrapped_image_load(
6556                        module,
6557                        func_ctx,
6558                        image,
6559                        coordinate,
6560                        array_index,
6561                        sample,
6562                        level,
6563                    )?;
6564                }
6565                crate::Expression::ImageSample {
6566                    image,
6567                    sampler,
6568                    gather,
6569                    coordinate,
6570                    array_index,
6571                    offset,
6572                    level,
6573                    depth_ref,
6574                    clamp_to_edge,
6575                } => {
6576                    self.write_wrapped_image_sample(
6577                        module,
6578                        func_ctx,
6579                        image,
6580                        sampler,
6581                        gather,
6582                        coordinate,
6583                        array_index,
6584                        offset,
6585                        level,
6586                        depth_ref,
6587                        clamp_to_edge,
6588                    )?;
6589                }
6590                crate::Expression::ImageQuery { image, query } => {
6591                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6592                }
6593                crate::Expression::CooperativeLoad {
6594                    columns,
6595                    rows,
6596                    role: _,
6597                    ref data,
6598                } => {
6599                    self.write_wrapped_cooperative_load(
6600                        module,
6601                        func_ctx,
6602                        columns,
6603                        rows,
6604                        data.pointer,
6605                    )?;
6606                }
6607                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6608                    let space = crate::AddressSpace::Private;
6609                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6610                }
6611                crate::Expression::RayQueryGetIntersection { committed, .. } => {
6612                    self.write_rq_get_intersection_function(module, committed)?;
6613                }
6614                _ => {}
6615            }
6616        }
6617
6618        Ok(())
6619    }
6620
6621    // Returns the array of mapped entry point names.
6622    fn write_functions(
6623        &mut self,
6624        module: &crate::Module,
6625        mod_info: &valid::ModuleInfo,
6626        options: &Options,
6627        pipeline_options: &PipelineOptions,
6628    ) -> Result<TranslationInfo, Error> {
6629        use back::msl::VertexFormat;
6630
6631        // Define structs to hold resolved/generated data for vertex buffers and
6632        // their attributes.
6633        struct AttributeMappingResolved {
6634            ty_name: String,
6635            dimension: Option<crate::VectorSize>,
6636            scalar: crate::Scalar,
6637            name: String,
6638        }
6639        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6640
6641        struct VertexBufferMappingResolved<'a> {
6642            id: u32,
6643            stride: u32,
6644            step_mode: back::msl::VertexBufferStepMode,
6645            ty_name: String,
6646            param_name: String,
6647            elem_name: String,
6648            attributes: &'a Vec<back::msl::AttributeMapping>,
6649        }
6650        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6651
6652        // Define a struct to hold a named reference to a byte-unpacking function.
6653        struct UnpackingFunction {
6654            name: String,
6655            byte_count: u32,
6656            dimension: Option<crate::VectorSize>,
6657            scalar: crate::Scalar,
6658        }
6659        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6660
6661        // Check if we are attempting vertex pulling. If we are, generate some
6662        // names we'll need, and iterate the vertex buffer mappings to output
6663        // all the conversion functions we'll need to unpack the attribute data.
6664        // We can re-use these names for all entry points that need them, since
6665        // those entry points also use self.namer.
6666        let mut needs_vertex_id = false;
6667        let v_id = self.namer.call("v_id");
6668
6669        let mut needs_instance_id = false;
6670        let i_id = self.namer.call("i_id");
6671        if pipeline_options.vertex_pulling_transform {
6672            for vbm in &pipeline_options.vertex_buffer_mappings {
6673                let buffer_id = vbm.id;
6674                let buffer_stride = vbm.stride;
6675
6676                assert!(
6677                    buffer_stride > 0,
6678                    "Vertex pulling requires a non-zero buffer stride."
6679                );
6680
6681                match vbm.step_mode {
6682                    back::msl::VertexBufferStepMode::Constant => {}
6683                    back::msl::VertexBufferStepMode::ByVertex => {
6684                        needs_vertex_id = true;
6685                    }
6686                    back::msl::VertexBufferStepMode::ByInstance => {
6687                        needs_instance_id = true;
6688                    }
6689                }
6690
6691                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6692                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6693                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6694
6695                vbm_resolved.push(VertexBufferMappingResolved {
6696                    id: buffer_id,
6697                    stride: buffer_stride,
6698                    step_mode: vbm.step_mode,
6699                    ty_name: buffer_ty,
6700                    param_name: buffer_param,
6701                    elem_name: buffer_elem,
6702                    attributes: &vbm.attributes,
6703                });
6704
6705                // Iterate the attributes and generate needed unpacking functions.
6706                for attribute in &vbm.attributes {
6707                    if unpacking_functions.contains_key(&attribute.format) {
6708                        continue;
6709                    }
6710                    let (name, byte_count, dimension, scalar) =
6711                        match self.write_unpacking_function(attribute.format) {
6712                            Ok((name, byte_count, dimension, scalar)) => {
6713                                (name, byte_count, dimension, scalar)
6714                            }
6715                            _ => {
6716                                continue;
6717                            }
6718                        };
6719                    unpacking_functions.insert(
6720                        attribute.format,
6721                        UnpackingFunction {
6722                            name,
6723                            byte_count,
6724                            dimension,
6725                            scalar,
6726                        },
6727                    );
6728                }
6729            }
6730        }
6731
6732        let mut pass_through_globals = Vec::new();
6733        for (fun_handle, fun) in module.functions.iter() {
6734            log::trace!(
6735                "function {:?}, handle {:?}",
6736                fun.name.as_deref().unwrap_or("(anonymous)"),
6737                fun_handle
6738            );
6739
6740            let ctx = back::FunctionCtx {
6741                ty: back::FunctionType::Function(fun_handle),
6742                info: &mod_info[fun_handle],
6743                expressions: &fun.expressions,
6744                named_expressions: &fun.named_expressions,
6745            };
6746
6747            writeln!(self.out)?;
6748            self.write_wrapped_functions(module, &ctx)?;
6749
6750            let fun_info = &mod_info[fun_handle];
6751            pass_through_globals.clear();
6752            let mut needs_buffer_sizes = false;
6753            for (handle, var) in module.global_variables.iter() {
6754                if !fun_info[handle].is_empty() {
6755                    if var.space.needs_pass_through() {
6756                        pass_through_globals.push(handle);
6757                    }
6758                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6759                }
6760            }
6761
6762            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6763            match fun.result {
6764                Some(ref result) => {
6765                    let ty_name = TypeContext {
6766                        handle: result.ty,
6767                        gctx: module.to_ctx(),
6768                        names: &self.names,
6769                        access: crate::StorageAccess::empty(),
6770                        first_time: false,
6771                    };
6772                    write!(self.out, "{ty_name}")?;
6773                }
6774                None => {
6775                    write!(self.out, "void")?;
6776                }
6777            }
6778            writeln!(self.out, " {fun_name}(")?;
6779
6780            for (index, arg) in fun.arguments.iter().enumerate() {
6781                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6782                let param_type_name = TypeContext {
6783                    handle: arg.ty,
6784                    gctx: module.to_ctx(),
6785                    names: &self.names,
6786                    access: crate::StorageAccess::empty(),
6787                    first_time: false,
6788                };
6789                let separator = separate(
6790                    !pass_through_globals.is_empty()
6791                        || index + 1 != fun.arguments.len()
6792                        || needs_buffer_sizes,
6793                );
6794                writeln!(
6795                    self.out,
6796                    "{}{} {}{}",
6797                    back::INDENT,
6798                    param_type_name,
6799                    name,
6800                    separator
6801                )?;
6802            }
6803            for (index, &handle) in pass_through_globals.iter().enumerate() {
6804                let tyvar = TypedGlobalVariable {
6805                    module,
6806                    names: &self.names,
6807                    handle,
6808                    usage: fun_info[handle],
6809                    reference: true,
6810                };
6811                let separator =
6812                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6813                write!(self.out, "{}", back::INDENT)?;
6814                tyvar.try_fmt(&mut self.out)?;
6815                writeln!(self.out, "{separator}")?;
6816            }
6817
6818            if needs_buffer_sizes {
6819                writeln!(
6820                    self.out,
6821                    "{}constant _mslBufferSizes& _buffer_sizes",
6822                    back::INDENT
6823                )?;
6824            }
6825
6826            writeln!(self.out, ") {{")?;
6827
6828            let guarded_indices =
6829                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6830
6831            let context = StatementContext {
6832                expression: ExpressionContext {
6833                    function: fun,
6834                    origin: FunctionOrigin::Handle(fun_handle),
6835                    info: fun_info,
6836                    lang_version: options.lang_version,
6837                    policies: options.bounds_check_policies,
6838                    guarded_indices,
6839                    module,
6840                    mod_info,
6841                    pipeline_options,
6842                    force_loop_bounding: options.force_loop_bounding,
6843                    emit_int_div_checks: options.emit_int_div_checks,
6844                },
6845                result_struct: None,
6846            };
6847
6848            self.put_locals(&context.expression)?;
6849            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6850            self.put_block(back::Level(1), &fun.body, &context)?;
6851            writeln!(self.out, "}}")?;
6852            self.named_expressions.clear();
6853        }
6854
6855        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6856            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6857
6858        let mut info = TranslationInfo {
6859            entry_point_names: Vec::with_capacity(ep_range.len()),
6860        };
6861
6862        for ep_index in ep_range {
6863            let ep = &module.entry_points[ep_index];
6864            let fun = &ep.function;
6865            let fun_info = mod_info.get_entry_point(ep_index);
6866            let mut ep_error = None;
6867
6868            // For vertex_id and instance_id arguments, presume that we'll
6869            // use our generated names, but switch to the name of an
6870            // existing @builtin param, if we find one.
6871            let mut v_existing_id = None;
6872            let mut i_existing_id = None;
6873
6874            log::trace!(
6875                "entry point {:?}, index {:?}",
6876                fun.name.as_deref().unwrap_or("(anonymous)"),
6877                ep_index
6878            );
6879
6880            let ctx = back::FunctionCtx {
6881                ty: back::FunctionType::EntryPoint(ep_index as u16),
6882                info: fun_info,
6883                expressions: &fun.expressions,
6884                named_expressions: &fun.named_expressions,
6885            };
6886
6887            self.write_wrapped_functions(module, &ctx)?;
6888
6889            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6890                crate::ShaderStage::Vertex => (
6891                    Some("vertex"),
6892                    LocationMode::VertexInput,
6893                    LocationMode::VertexOutput,
6894                    true,
6895                ),
6896                crate::ShaderStage::Fragment => (
6897                    Some("fragment"),
6898                    LocationMode::FragmentInput,
6899                    LocationMode::FragmentOutput,
6900                    false,
6901                ),
6902                crate::ShaderStage::Compute => (
6903                    Some("kernel"),
6904                    LocationMode::Uniform,
6905                    LocationMode::Uniform,
6906                    false,
6907                ),
6908                crate::ShaderStage::Task => {
6909                    (None, LocationMode::Uniform, LocationMode::Uniform, false)
6910                }
6911                crate::ShaderStage::Mesh => {
6912                    (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6913                }
6914                crate::ShaderStage::RayGeneration
6915                | crate::ShaderStage::AnyHit
6916                | crate::ShaderStage::ClosestHit
6917                | crate::ShaderStage::Miss => unimplemented!(),
6918            };
6919
6920            // Should this entry point be modified to do vertex pulling?
6921            let do_vertex_pulling = can_vertex_pull
6922                && pipeline_options.vertex_pulling_transform
6923                && !pipeline_options.vertex_buffer_mappings.is_empty();
6924
6925            // Is any global variable used by this entry point dynamically sized?
6926            let needs_buffer_sizes = do_vertex_pulling
6927                || module
6928                    .global_variables
6929                    .iter()
6930                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6931                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6932
6933            // skip this entry point if any global bindings are missing,
6934            // or their types are incompatible.
6935            if !options.fake_missing_bindings {
6936                for (var_handle, var) in module.global_variables.iter() {
6937                    if fun_info[var_handle].is_empty() {
6938                        continue;
6939                    }
6940                    match var.space {
6941                        crate::AddressSpace::Uniform
6942                        | crate::AddressSpace::Storage { .. }
6943                        | crate::AddressSpace::Handle => {
6944                            let br = match var.binding {
6945                                Some(ref br) => br,
6946                                None => {
6947                                    let var_name = var.name.clone().unwrap_or_default();
6948                                    ep_error =
6949                                        Some(super::EntryPointError::MissingBinding(var_name));
6950                                    break;
6951                                }
6952                            };
6953                            let target = options.get_resource_binding_target(ep, br);
6954                            let good = match target {
6955                                Some(target) => {
6956                                    // We intentionally don't dereference binding_arrays here,
6957                                    // so that binding arrays fall to the buffer location.
6958
6959                                    match module.types[var.ty].inner {
6960                                        crate::TypeInner::Image {
6961                                            class: crate::ImageClass::External,
6962                                            ..
6963                                        } => target.external_texture.is_some(),
6964                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6965                                        crate::TypeInner::Sampler { .. } => {
6966                                            target.sampler.is_some()
6967                                        }
6968                                        _ => target.buffer.is_some(),
6969                                    }
6970                                }
6971                                None => false,
6972                            };
6973                            if !good {
6974                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6975                                break;
6976                            }
6977                        }
6978                        crate::AddressSpace::Immediate => {
6979                            if let Err(e) = options.resolve_immediates(ep) {
6980                                ep_error = Some(e);
6981                                break;
6982                            }
6983                        }
6984                        crate::AddressSpace::Function
6985                        | crate::AddressSpace::Private
6986                        | crate::AddressSpace::WorkGroup
6987                        | crate::AddressSpace::TaskPayload => {}
6988                        crate::AddressSpace::RayPayload
6989                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6990                    }
6991                }
6992                if needs_buffer_sizes {
6993                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6994                        ep_error = Some(err);
6995                    }
6996                }
6997            }
6998
6999            if let Some(err) = ep_error {
7000                info.entry_point_names.push(Err(err));
7001                continue;
7002            }
7003            let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7004            info.entry_point_names.push(Ok(fun_name.clone()));
7005
7006            writeln!(self.out)?;
7007
7008            // Since `Namer.reset` wasn't expecting struct members to be
7009            // suddenly injected into another namespace like this,
7010            // `self.names` doesn't keep them distinct from other variables.
7011            // Generate fresh names for these arguments, and remember the
7012            // mapping.
7013            let mut flattened_member_names = FastHashMap::default();
7014            // Varyings' members get their own namespace
7015            let mut varyings_namer = proc::Namer::default();
7016
7017            let mut empty_names = FastHashMap::default(); // Create a throwaway map
7018            varyings_namer.reset(
7019                module,
7020                &super::keywords::RESERVED_SET,
7021                proc::KeywordSet::empty(),
7022                proc::CaseInsensitiveKeywordSet::empty(),
7023                &[CLAMPED_LOD_LOAD_PREFIX],
7024                &mut empty_names,
7025            );
7026
7027            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7028            // flattening structs into their members. In Metal, we will pass
7029            // each of these values to the entry point as a separate argument—
7030            // except for the varyings, handled next.
7031            let mut flattened_arguments = Vec::new();
7032            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7033                match module.types[arg.ty].inner {
7034                    crate::TypeInner::Struct { ref members, .. } => {
7035                        for (member_index, member) in members.iter().enumerate() {
7036                            let member_index = member_index as u32;
7037                            flattened_arguments.push((
7038                                NameKey::StructMember(arg.ty, member_index),
7039                                member.ty,
7040                                member.binding.as_ref(),
7041                            ));
7042                            let name_key = NameKey::StructMember(arg.ty, member_index);
7043                            let name = match member.binding {
7044                                Some(crate::Binding::Location { .. }) => {
7045                                    if do_vertex_pulling {
7046                                        self.namer.call(&self.names[&name_key])
7047                                    } else {
7048                                        varyings_namer.call(&self.names[&name_key])
7049                                    }
7050                                }
7051                                _ => self.namer.call(&self.names[&name_key]),
7052                            };
7053                            flattened_member_names.insert(name_key, name);
7054                        }
7055                    }
7056                    _ => flattened_arguments.push((
7057                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7058                        arg.ty,
7059                        arg.binding.as_ref(),
7060                    )),
7061                }
7062            }
7063
7064            // Identify the varyings among the argument values, and maybe emit
7065            // a struct type named `<fun>Input` to hold them. If we are doing
7066            // vertex pulling, we instead update our attribute mapping to
7067            // note the types, names, and zero values of the attributes.
7068            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7069            let varyings_member_name = self.namer.call("varyings");
7070            let mut has_varyings = false;
7071
7072            if !flattened_arguments.is_empty() {
7073                if !do_vertex_pulling {
7074                    writeln!(self.out, "struct {stage_in_name} {{")?;
7075                }
7076                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7077                    let Some(binding) = binding else {
7078                        continue;
7079                    };
7080                    let name = match *name_key {
7081                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7082                        _ => &self.names[name_key],
7083                    };
7084                    let ty_name = TypeContext {
7085                        handle: ty,
7086                        gctx: module.to_ctx(),
7087                        names: &self.names,
7088                        access: crate::StorageAccess::empty(),
7089                        first_time: false,
7090                    };
7091                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7092                    let location = match *binding {
7093                        crate::Binding::Location { location, .. } => Some(location),
7094                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7095                        crate::Binding::BuiltIn(_) => continue,
7096                    };
7097                    if do_vertex_pulling {
7098                        let Some(location) = location else {
7099                            continue;
7100                        };
7101                        // Update our attribute mapping.
7102                        am_resolved.insert(
7103                            location,
7104                            AttributeMappingResolved {
7105                                ty_name: ty_name.to_string(),
7106                                dimension: ty_name.vector_size(),
7107                                scalar: ty_name.scalar().unwrap(),
7108                                name: name.to_string(),
7109                            },
7110                        );
7111                    } else {
7112                        has_varyings = true;
7113                        if let super::ResolvedBinding::User {
7114                            prefix,
7115                            index,
7116                            interpolation: Some(super::ResolvedInterpolation::PerVertex),
7117                        } = resolved
7118                        {
7119                            if options.lang_version < (4, 0) {
7120                                return Err(Error::PerVertexNotSupported);
7121                            }
7122                            write!(
7123                                self.out,
7124                                "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7125                                back::INDENT,
7126                                ty_name.unwrap_array()
7127                            )?;
7128                        } else {
7129                            write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7130                            resolved.try_fmt(&mut self.out)?;
7131                        }
7132                        writeln!(self.out, ";")?;
7133                    }
7134                }
7135                if !do_vertex_pulling {
7136                    writeln!(self.out, "}};")?;
7137                }
7138            }
7139
7140            // Define a struct type named for the return value, if any, named
7141            // `<fun>Output`.
7142            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7143            let result_member_name = self.namer.call("member");
7144            let result_type_name = match fun.result {
7145                Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7146                    let mut result_members = Vec::new();
7147                    if let crate::TypeInner::Struct { ref members, .. } =
7148                        module.types[result.ty].inner
7149                    {
7150                        for (member_index, member) in members.iter().enumerate() {
7151                            result_members.push((
7152                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7153                                member.ty,
7154                                member.binding.as_ref(),
7155                            ));
7156                        }
7157                    } else {
7158                        result_members.push((
7159                            &result_member_name,
7160                            result.ty,
7161                            result.binding.as_ref(),
7162                        ));
7163                    }
7164
7165                    writeln!(self.out, "struct {stage_out_name} {{")?;
7166                    let mut has_point_size = false;
7167                    for (name, ty, binding) in result_members {
7168                        let ty_name = TypeContext {
7169                            handle: ty,
7170                            gctx: module.to_ctx(),
7171                            names: &self.names,
7172                            access: crate::StorageAccess::empty(),
7173                            first_time: true,
7174                        };
7175                        let binding = binding.ok_or_else(|| {
7176                            Error::GenericValidation("Expected binding, got None".into())
7177                        })?;
7178
7179                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7180                            has_point_size = true;
7181                            if !pipeline_options.allow_and_force_point_size {
7182                                continue;
7183                            }
7184                        }
7185
7186                        let array_len = match module.types[ty].inner {
7187                            crate::TypeInner::Array {
7188                                size: crate::ArraySize::Constant(size),
7189                                ..
7190                            } => Some(size),
7191                            _ => None,
7192                        };
7193                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7194                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7195                        resolved.try_fmt(&mut self.out)?;
7196                        if let Some(array_len) = array_len {
7197                            write!(self.out, " [{array_len}]")?;
7198                        }
7199                        writeln!(self.out, ";")?;
7200                    }
7201
7202                    if pipeline_options.allow_and_force_point_size
7203                        && ep.stage == crate::ShaderStage::Vertex
7204                        && !has_point_size
7205                    {
7206                        // inject the point size output last
7207                        writeln!(
7208                            self.out,
7209                            "{}float _point_size [[point_size]];",
7210                            back::INDENT
7211                        )?;
7212                    }
7213                    writeln!(self.out, "}};")?;
7214                    &stage_out_name
7215                }
7216                Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7217                    assert_eq!(
7218                        module.types[result.ty].inner,
7219                        crate::TypeInner::Vector {
7220                            size: crate::VectorSize::Tri,
7221                            scalar: crate::Scalar::U32
7222                        }
7223                    );
7224
7225                    "metal::uint3"
7226                }
7227                _ => "void",
7228            };
7229
7230            let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7231                Some(self.write_mesh_output_types(
7232                    mesh_info,
7233                    &fun_name,
7234                    module,
7235                    pipeline_options.allow_and_force_point_size,
7236                    options,
7237                )?)
7238            } else {
7239                None
7240            };
7241
7242            // If we're doing a vertex pulling transform, define the buffer
7243            // structure types.
7244            if do_vertex_pulling {
7245                for vbm in &vbm_resolved {
7246                    let buffer_stride = vbm.stride;
7247                    let buffer_ty = &vbm.ty_name;
7248
7249                    // Define a structure of bytes of the appropriate size.
7250                    // When we access the attributes, we'll be unpacking these
7251                    // bytes at some offset.
7252                    writeln!(
7253                        self.out,
7254                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7255                    )?;
7256                }
7257            }
7258
7259            let is_wrapped = matches!(
7260                ep.stage,
7261                crate::ShaderStage::Task | crate::ShaderStage::Mesh
7262            );
7263            let fun_name = fun_name.clone();
7264            let nested_fun_name = if is_wrapped {
7265                self.namer.call(&format!("_{fun_name}"))
7266            } else {
7267                fun_name.clone()
7268            };
7269
7270            // https://web.archive.org/web/20181029003926/https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
7271            if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7272                let total_threads =
7273                    ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7274                write!(
7275                    self.out,
7276                    "[[max_total_threads_per_threadgroup({total_threads})]] "
7277                )?;
7278            }
7279
7280            // Write the entry point function's name, and begin its argument list.
7281            if let Some(em_str) = em_str {
7282                write!(self.out, "{em_str} ")?;
7283            }
7284            writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7285
7286            let mut args = Vec::new();
7287
7288            // If we have produced a struct holding the `EntryPoint`'s
7289            // `Function`'s arguments' varyings, pass that struct first.
7290            if has_varyings {
7291                args.push(EntryPointArgument {
7292                    ty_name: stage_in_name,
7293                    name: varyings_member_name.clone(),
7294                    binding: " [[stage_in]]".to_string(),
7295                    init: None,
7296                });
7297            }
7298
7299            let mut local_invocation_index = None;
7300
7301            // Then pass the remaining arguments not included in the varyings
7302            // struct.
7303            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7304                let binding = match binding {
7305                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7306                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7307                    _ => continue,
7308                };
7309                let name = match *name_key {
7310                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7311                    _ => &self.names[name_key],
7312                };
7313
7314                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7315                    local_invocation_index = Some(name_key);
7316                }
7317
7318                let ty_name = TypeContext {
7319                    handle: ty,
7320                    gctx: module.to_ctx(),
7321                    names: &self.names,
7322                    access: crate::StorageAccess::empty(),
7323                    first_time: false,
7324                };
7325
7326                match *binding {
7327                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7328                        v_existing_id = Some(name.clone());
7329                    }
7330                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7331                        i_existing_id = Some(name.clone());
7332                    }
7333                    _ => {}
7334                };
7335
7336                let resolved = options.resolve_local_binding(binding, in_mode)?;
7337                let mut binding = String::new();
7338                resolved.try_fmt(&mut binding)?;
7339
7340                args.push(EntryPointArgument {
7341                    ty_name: format!("{ty_name}"),
7342                    name: name.clone(),
7343                    binding,
7344                    init: None,
7345                });
7346            }
7347
7348            let need_workgroup_variables_initialization =
7349                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7350
7351            if local_invocation_index.is_none()
7352                && (need_workgroup_variables_initialization
7353                    || ep.stage == crate::ShaderStage::Task
7354                    || ep.stage == crate::ShaderStage::Mesh)
7355            {
7356                args.push(EntryPointArgument {
7357                    ty_name: "uint".to_string(),
7358                    name: "__local_invocation_index".to_string(),
7359                    binding: " [[thread_index_in_threadgroup]]".to_string(),
7360                    init: None,
7361                });
7362            }
7363
7364            // Those global variables used by this entry point and its callees
7365            // get passed as arguments. `Private` globals are an exception, they
7366            // don't outlive this invocation, so we declare them below as locals
7367            // within the entry point.
7368            for (handle, var) in module.global_variables.iter() {
7369                let usage = fun_info[handle];
7370                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7371                    continue;
7372                }
7373
7374                if options.lang_version < (1, 2) {
7375                    match var.space {
7376                        // This restriction is not documented in the MSL spec
7377                        // but validation will fail if it is not upheld.
7378                        //
7379                        // We infer the required version from the "Function
7380                        // Buffer Read-Writes" section of [what's new], where
7381                        // the feature sets listed correspond with the ones
7382                        // supporting MSL 1.2.
7383                        //
7384                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7385                        crate::AddressSpace::Storage { access }
7386                            if access.contains(crate::StorageAccess::STORE)
7387                                && ep.stage == crate::ShaderStage::Fragment =>
7388                        {
7389                            return Err(Error::UnsupportedWritableStorageBuffer)
7390                        }
7391                        crate::AddressSpace::Handle => {
7392                            match module.types[var.ty].inner {
7393                                crate::TypeInner::Image {
7394                                    class: crate::ImageClass::Storage { access, .. },
7395                                    ..
7396                                } => {
7397                                    // This restriction is not documented in the MSL spec
7398                                    // but validation will fail if it is not upheld.
7399                                    //
7400                                    // We infer the required version from the "Function
7401                                    // Texture Read-Writes" section of [what's new], where
7402                                    // the feature sets listed correspond with the ones
7403                                    // supporting MSL 1.2.
7404                                    //
7405                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7406                                    if access.contains(crate::StorageAccess::STORE)
7407                                        && (ep.stage == crate::ShaderStage::Vertex
7408                                            || ep.stage == crate::ShaderStage::Fragment)
7409                                    {
7410                                        return Err(Error::UnsupportedWritableStorageTexture(
7411                                            ep.stage,
7412                                        ));
7413                                    }
7414
7415                                    if access.contains(
7416                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7417                                    ) {
7418                                        return Err(Error::UnsupportedRWStorageTexture);
7419                                    }
7420                                }
7421                                _ => {}
7422                            }
7423                        }
7424                        _ => {}
7425                    }
7426                }
7427
7428                // Check min MSL version for binding arrays
7429                match var.space {
7430                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7431                        crate::TypeInner::BindingArray { base, .. } => {
7432                            match module.types[base].inner {
7433                                crate::TypeInner::Sampler { .. } => {
7434                                    if options.lang_version < (2, 0) {
7435                                        return Err(Error::UnsupportedArrayOf(
7436                                            "samplers".to_string(),
7437                                        ));
7438                                    }
7439                                }
7440                                crate::TypeInner::Image { class, .. } => match class {
7441                                    crate::ImageClass::Sampled { .. }
7442                                    | crate::ImageClass::Depth { .. }
7443                                    | crate::ImageClass::Storage {
7444                                        access: crate::StorageAccess::LOAD,
7445                                        ..
7446                                    } => {
7447                                        // Array of textures since:
7448                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7449                                        // - macOS: Metal 2
7450
7451                                        if options.lang_version < (2, 0) {
7452                                            return Err(Error::UnsupportedArrayOf(
7453                                                "textures".to_string(),
7454                                            ));
7455                                        }
7456                                    }
7457                                    crate::ImageClass::Storage {
7458                                        access: crate::StorageAccess::STORE,
7459                                        ..
7460                                    } => {
7461                                        // Array of write-only textures since:
7462                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7463                                        // - macOS: Metal 2
7464
7465                                        if options.lang_version < (2, 0) {
7466                                            return Err(Error::UnsupportedArrayOf(
7467                                                "write-only textures".to_string(),
7468                                            ));
7469                                        }
7470                                    }
7471                                    crate::ImageClass::Storage { .. } => {
7472                                        if options.lang_version < (3, 0) {
7473                                            return Err(Error::UnsupportedArrayOf(
7474                                                "read-write textures".to_string(),
7475                                            ));
7476                                        }
7477                                    }
7478                                    crate::ImageClass::External => {
7479                                        return Err(Error::UnsupportedArrayOf(
7480                                            "external textures".to_string(),
7481                                        ));
7482                                    }
7483                                },
7484                                _ => {
7485                                    return Err(Error::UnsupportedArrayOfType(base));
7486                                }
7487                            }
7488                        }
7489                        _ => {}
7490                    },
7491                    _ => {}
7492                }
7493
                // The resolved bindings have already been validated above in the
                // `!fake_missing_bindings` case.
7495                let resolved = match var.space {
7496                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7497                    crate::AddressSpace::WorkGroup => None,
7498                    crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7499                    _ => options
7500                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7501                        .ok(),
7502                };
7503                if let Some(ref resolved) = resolved {
                    // Inline samplers are defined in the EP body
7505                    if resolved.as_inline_sampler(options).is_some() {
7506                        continue;
7507                    }
7508                }
7509
7510                match module.types[var.ty].inner {
7511                    crate::TypeInner::Image {
7512                        class: crate::ImageClass::External,
7513                        ..
7514                    } => {
7515                        // External texture global variables get lowered to 3 textures
7516                        // and a constant buffer. We must emit a separate argument for
7517                        // each of these.
7518                        let target = match resolved {
7519                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7520                                target.external_texture
7521                            }
7522                            _ => None,
7523                        };
7524
7525                        for i in 0..3 {
7526                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7527                                handle,
7528                                ExternalTextureNameKey::Plane(i),
7529                            )];
7530                            let ty_name = format!(
7531                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7532                            );
7533                            let name = plane_name.clone();
7534                            let binding = if let Some(ref target) = target {
7535                                format!(" [[texture({})]]", target.planes[i])
7536                            } else {
7537                                String::new()
7538                            };
7539                            args.push(EntryPointArgument {
7540                                ty_name,
7541                                name,
7542                                binding,
7543                                init: None,
7544                            });
7545                        }
7546                        let params_ty_name = &self.names
7547                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7548                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7549                            handle,
7550                            ExternalTextureNameKey::Params,
7551                        )];
7552                        let binding = if let Some(ref target) = target {
7553                            format!(" [[buffer({})]]", target.params)
7554                        } else {
7555                            String::new()
7556                        };
7557
7558                        args.push(EntryPointArgument {
7559                            ty_name: format!("constant {params_ty_name}&"),
7560                            name: params_name.clone(),
7561                            binding,
7562                            init: None,
7563                        });
7564                    }
7565                    _ => {
7566                        if var.space == crate::AddressSpace::WorkGroup
7567                            && ep.stage == crate::ShaderStage::Mesh
7568                        {
7569                            continue;
7570                        }
7571                        let tyvar = TypedGlobalVariable {
7572                            module,
7573                            names: &self.names,
7574                            handle,
7575                            usage,
7576                            reference: true,
7577                        };
7578                        let parts = tyvar.to_parts()?;
7579                        let mut binding = String::new();
7580                        if let Some(resolved) = resolved {
7581                            resolved.try_fmt(&mut binding)?;
7582                        }
7583                        args.push(EntryPointArgument {
7584                            ty_name: parts.ty_name,
7585                            name: parts.var_name,
7586                            binding,
7587                            init: var.init,
7588                        });
7589                    }
7590                }
7591            }
7592
7593            if do_vertex_pulling {
7594                if needs_vertex_id && v_existing_id.is_none() {
7595                    // Write the [[vertex_id]] argument.
7596                    args.push(EntryPointArgument {
7597                        ty_name: "uint".to_string(),
7598                        name: v_id.clone(),
7599                        binding: " [[vertex_id]]".to_string(),
7600                        init: None,
7601                    });
7602                }
7603
7604                if needs_instance_id && i_existing_id.is_none() {
7605                    args.push(EntryPointArgument {
7606                        ty_name: "uint".to_string(),
7607                        name: i_id.clone(),
7608                        binding: " [[instance_id]]".to_string(),
7609                        init: None,
7610                    });
7611                }
7612
7613                // Iterate vbm_resolved, output one argument for every vertex buffer,
7614                // using the names we generated earlier.
7615                for vbm in &vbm_resolved {
7616                    let id = &vbm.id;
7617                    let ty_name = &vbm.ty_name;
7618                    let param_name = &vbm.param_name;
7619                    args.push(EntryPointArgument {
7620                        ty_name: format!("const device {ty_name}*"),
7621                        name: param_name.clone(),
7622                        binding: format!(" [[buffer({id})]]"),
7623                        init: None,
7624                    });
7625                }
7626            }
7627
7628            // If this entry uses any variable-length arrays, their sizes are
7629            // passed as a final struct-typed argument.
7630            if needs_buffer_sizes {
7631                // this is checked earlier
7632                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7633                let mut binding = String::new();
7634                resolved.try_fmt(&mut binding)?;
7635                args.push(EntryPointArgument {
7636                    ty_name: "constant _mslBufferSizes&".to_string(),
7637                    name: "_buffer_sizes".to_string(),
7638                    binding,
7639                    init: None,
7640                });
7641            }
7642
7643            let mut is_first_arg = true;
7644            for arg in &args {
7645                if is_first_arg {
7646                    write!(self.out, "  ")?;
7647                } else {
7648                    write!(self.out, ", ")?;
7649                }
7650                is_first_arg = false;
7651                write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7652                if !is_wrapped {
7653                    write!(self.out, "{}", arg.binding)?;
7654                    if let Some(init) = arg.init {
7655                        write!(self.out, " = ")?;
7656                        self.put_const_expression(
7657                            init,
7658                            module,
7659                            mod_info,
7660                            &module.global_expressions,
7661                        )?;
7662                    }
7663                }
7664                writeln!(self.out)?;
7665            }
7666            if ep.stage == crate::ShaderStage::Mesh {
7667                for (handle, var) in module.global_variables.iter() {
7668                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7669                        continue;
7670                    }
7671                    if is_first_arg {
7672                        write!(self.out, "  ")?;
7673                    } else {
7674                        write!(self.out, ", ")?;
7675                    }
7676                    let ty_context = TypeContext {
7677                        handle: module.global_variables[handle].ty,
7678                        gctx: module.to_ctx(),
7679                        names: &self.names,
7680                        access: crate::StorageAccess::empty(),
7681                        first_time: false,
7682                    };
7683                    writeln!(
7684                        self.out,
7685                        "threadgroup {ty_context}& {}",
7686                        self.names[&NameKey::GlobalVariable(handle)]
7687                    )?;
7688                }
7689            }
7690
7691            // end of the entry point argument list
7692            writeln!(self.out, ") {{")?;
7693
7694            // Starting the function body.
7695            if do_vertex_pulling {
7696                // Provide zero values for all the attributes, which we will overwrite with
7697                // real data from the vertex attribute buffers, if the indices are in-bounds.
7698                for vbm in &vbm_resolved {
7699                    for attribute in vbm.attributes {
7700                        let location = attribute.shader_location;
7701                        let am_option = am_resolved.get(&location);
7702                        if am_option.is_none() {
7703                            // This bound attribute isn't used in this entry point, so
7704                            // don't bother zero-initializing it.
7705                            continue;
7706                        }
7707                        let am = am_option.unwrap();
7708                        let attribute_ty_name = &am.ty_name;
7709                        let attribute_name = &am.name;
7710
7711                        writeln!(
7712                            self.out,
7713                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7714                            back::Level(1)
7715                        )?;
7716                    }
7717
7718                    // Output a bounds check block that will set real values for the
7719                    // attributes, if the bounds are satisfied.
7720                    write!(self.out, "{}if (", back::Level(1))?;
7721
7722                    let idx = &vbm.id;
7723                    let stride = &vbm.stride;
7724                    let index_name = match vbm.step_mode {
7725                        back::msl::VertexBufferStepMode::Constant => "0",
7726                        back::msl::VertexBufferStepMode::ByVertex => {
7727                            if let Some(ref name) = v_existing_id {
7728                                name
7729                            } else {
7730                                &v_id
7731                            }
7732                        }
7733                        back::msl::VertexBufferStepMode::ByInstance => {
7734                            if let Some(ref name) = i_existing_id {
7735                                name
7736                            } else {
7737                                &i_id
7738                            }
7739                        }
7740                    };
7741                    write!(
7742                        self.out,
7743                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7744                    )?;
7745
7746                    writeln!(self.out, ") {{")?;
7747
7748                    // Pull the bytes out of the vertex buffer.
7749                    let ty_name = &vbm.ty_name;
7750                    let elem_name = &vbm.elem_name;
7751                    let param_name = &vbm.param_name;
7752
7753                    writeln!(
7754                        self.out,
7755                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7756                        back::Level(2),
7757                    )?;
7758
7759                    // Now set real values for each of the attributes, by unpacking the data
7760                    // from the buffer elements.
7761                    for attribute in vbm.attributes {
7762                        let location = attribute.shader_location;
7763                        let Some(am) = am_resolved.get(&location) else {
7764                            // This bound attribute isn't used in this entry point, so
7765                            // don't bother extracting the data. Too bad we emitted the
7766                            // unpacking function earlier -- it might not get used.
7767                            continue;
7768                        };
7769                        let attribute_name = &am.name;
7770                        let attribute_ty_name = &am.ty_name;
7771
7772                        let offset = attribute.offset;
7773                        let func = unpacking_functions
7774                            .get(&attribute.format)
7775                            .expect("Should have generated this unpacking function earlier.");
7776                        let func_name = &func.name;
7777
7778                        // Check dimensionality of the attribute compared to the unpacking
7779                        // function. If attribute dimension > unpack dimension, we have to
7780                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7781                        // scalar type. Otherwise, if attribute dimension is < unpack
7782                        // dimension, then we need to explicitly truncate the result.
7783                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7784
7785                        // We need an extra type conversion if the shader type does not
7786                        // match the type returned from the unpacking function.
7787                        let needs_conversion = am.scalar != func.scalar;
7788
7789                        if needs_padding_or_truncation != Ordering::Equal {
7790                            // Emit a comment flagging that a conversion is happening,
7791                            // since the actual logic can be at the end of a long line.
7792                            writeln!(
7793                                self.out,
7794                                "{}// {attribute_ty_name} <- {:?}",
7795                                back::Level(2),
7796                                attribute.format
7797                            )?;
7798                        }
7799
7800                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7801
7802                        if needs_padding_or_truncation == Ordering::Greater {
7803                            // Needs padding: emit constructor call for wider type
7804                            write!(self.out, "{attribute_ty_name}(")?;
7805                        }
7806
7807                        // Emit call to unpacking function
7808                        if needs_conversion {
7809                            put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7810                            write!(self.out, "(")?;
7811                        }
7812                        write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7813                        for i in (offset + 1)..(offset + func.byte_count) {
7814                            write!(self.out, ", {elem_name}.data[{i}]")?;
7815                        }
7816                        write!(self.out, ")")?;
7817                        if needs_conversion {
7818                            write!(self.out, ")")?;
7819                        }
7820
7821                        match needs_padding_or_truncation {
7822                            Ordering::Greater => {
7823                                // Padding
7824                                let ty_is_int = scalar_is_int(am.scalar);
7825                                let zero_value = if ty_is_int { "0" } else { "0.0" };
7826                                let one_value = if ty_is_int { "1" } else { "1.0" };
7827                                for i in func.dimension.map_or(1, u8::from)
7828                                    ..am.dimension.map_or(1, u8::from)
7829                                {
7830                                    write!(
7831                                        self.out,
7832                                        ", {}",
7833                                        if i == 3 { one_value } else { zero_value }
7834                                    )?;
7835                                }
7836                            }
7837                            Ordering::Less => {
7838                                // Truncate to the first `am.dimension` components
7839                                write!(
7840                                    self.out,
7841                                    ".{}",
7842                                    &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7843                                )?;
7844                            }
7845                            Ordering::Equal => {}
7846                        }
7847
7848                        if needs_padding_or_truncation == Ordering::Greater {
7849                            write!(self.out, ")")?;
7850                        }
7851
7852                        writeln!(self.out, ";")?;
7853                    }
7854
7855                    // End the bounds check / attribute setting block.
7856                    writeln!(self.out, "{}}}", back::Level(1))?;
7857                }
7858            }
7859
7860            // Metal doesn't support private mutable variables outside of functions,
7861            // so we put them here, just like the locals.
7862            for (handle, var) in module.global_variables.iter() {
7863                let usage = fun_info[handle];
7864                if usage.is_empty() {
7865                    continue;
7866                }
7867                if var.space == crate::AddressSpace::Private {
7868                    let tyvar = TypedGlobalVariable {
7869                        module,
7870                        names: &self.names,
7871                        handle,
7872                        usage,
7873
7874                        reference: false,
7875                    };
7876                    write!(self.out, "{}", back::INDENT)?;
7877                    tyvar.try_fmt(&mut self.out)?;
7878                    match var.init {
7879                        Some(value) => {
7880                            write!(self.out, " = ")?;
7881                            self.put_const_expression(
7882                                value,
7883                                module,
7884                                mod_info,
7885                                &module.global_expressions,
7886                            )?;
7887                            writeln!(self.out, ";")?;
7888                        }
7889                        None => {
7890                            writeln!(self.out, " = {{}};")?;
7891                        }
7892                    };
7893                } else if let Some(ref binding) = var.binding {
7894                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7895                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7896                        // write an inline sampler
7897                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7898                        writeln!(
7899                            self.out,
7900                            "{}constexpr {}::sampler {}(",
7901                            back::INDENT,
7902                            NAMESPACE,
7903                            name
7904                        )?;
7905                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7906                        writeln!(self.out, "{});", back::INDENT)?;
7907                    } else if let crate::TypeInner::Image {
7908                        class: crate::ImageClass::External,
7909                        ..
7910                    } = module.types[var.ty].inner
7911                    {
7912                        // Wrap the individual arguments for each external texture global
7913                        // in a struct which can be easily passed around.
7914                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7915                        let l1 = back::Level(1);
7916                        let l2 = l1.next();
7917                        writeln!(
7918                            self.out,
7919                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7920                        )?;
7921                        for i in 0..3 {
7922                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7923                                handle,
7924                                ExternalTextureNameKey::Plane(i),
7925                            )];
7926                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7927                        }
7928                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7929                            handle,
7930                            ExternalTextureNameKey::Params,
7931                        )];
7932                        writeln!(self.out, "{l2}.params = {params_name},")?;
7933                        writeln!(self.out, "{l1}}};")?;
7934                    }
7935                }
7936            }
7937
7938            if need_workgroup_variables_initialization {
7939                self.write_workgroup_variables_initialization(
7940                    module,
7941                    mod_info,
7942                    fun_info,
7943                    local_invocation_index,
7944                    ep.stage,
7945                )?;
7946            }
7947
7948            // Now take the arguments that we gathered into structs, and the
7949            // structs that we flattened into arguments, and emit local
7950            // variables with initializers that put everything back the way the
7951            // body code expects.
7952            //
7953            // If we had to generate fresh names for struct members passed as
7954            // arguments, be sure to use those names when rebuilding the struct.
7955            //
7956            // "Each day, I change some zeros to ones, and some ones to zeros.
7957            // The rest, I leave alone."
7958            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7959                let arg_name =
7960                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7961                match module.types[arg.ty].inner {
7962                    crate::TypeInner::Struct { ref members, .. } => {
7963                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7964                        write!(
7965                            self.out,
7966                            "{}const {} {} = {{ ",
7967                            back::INDENT,
7968                            struct_name,
7969                            arg_name
7970                        )?;
7971                        for (member_index, member) in members.iter().enumerate() {
7972                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7973                            let name = &flattened_member_names[&key];
7974                            if member_index != 0 {
7975                                write!(self.out, ", ")?;
7976                            }
7977                            // insert padding initialization, if needed
7978                            if self
7979                                .struct_member_pads
7980                                .contains(&(arg.ty, member_index as u32))
7981                            {
7982                                write!(self.out, "{{}}, ")?;
7983                            }
7984                            match member.binding {
7985                                Some(crate::Binding::Location {
7986                                    interpolation: Some(crate::Interpolation::PerVertex),
7987                                    ..
7988                                }) => {
7989                                    writeln!(
7990                                        self.out,
7991                                        "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7992                                        back::INDENT,
7993                                        varyings_member_name,
7994                                        arg_name,
7995                                    )?;
7996                                    continue;
7997                                }
7998                                Some(crate::Binding::Location { .. }) => {
7999                                    if has_varyings {
8000                                        write!(self.out, "{varyings_member_name}.")?;
8001                                    }
8002                                }
8003                                _ => (),
8004                            }
8005                            write!(self.out, "{name}")?;
8006                        }
8007                        writeln!(self.out, " }};")?;
8008                    }
8009                    _ => match arg.binding {
8010                        Some(crate::Binding::Location {
8011                            interpolation: Some(crate::Interpolation::PerVertex),
8012                            ..
8013                        }) => {
8014                            let ty_name = TypeContext {
8015                                handle: arg.ty,
8016                                gctx: module.to_ctx(),
8017                                names: &self.names,
8018                                access: crate::StorageAccess::empty(),
8019                                first_time: false,
8020                            };
8021                            writeln!(
8022                                self.out,
8023                                "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8024                                back::INDENT,
8025                                varyings_member_name,
8026                                arg_name,
8027                            )?;
8028                        }
8029                        Some(crate::Binding::Location { .. })
8030                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8031                            if has_varyings {
8032                                writeln!(
8033                                    self.out,
8034                                    "{}const auto {} = {}.{};",
8035                                    back::INDENT,
8036                                    arg_name,
8037                                    varyings_member_name,
8038                                    arg_name
8039                                )?;
8040                            }
8041                        }
8042                        _ => {}
8043                    },
8044                }
8045            }
8046
8047            let guarded_indices =
8048                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8049
8050            let context = StatementContext {
8051                expression: ExpressionContext {
8052                    function: fun,
8053                    origin: FunctionOrigin::EntryPoint(ep_index as _),
8054                    info: fun_info,
8055                    lang_version: options.lang_version,
8056                    policies: options.bounds_check_policies,
8057                    guarded_indices,
8058                    module,
8059                    mod_info,
8060                    pipeline_options,
8061                    force_loop_bounding: options.force_loop_bounding,
8062                    emit_int_div_checks: options.emit_int_div_checks,
8063                },
8064                result_struct: if ep.stage == crate::ShaderStage::Task {
8065                    None
8066                } else {
8067                    Some(&stage_out_name)
8068                },
8069            };
8070
8071            // Finally, declare all the local variables that we need
8072            //TODO: we can postpone this till the relevant expressions are emitted
8073            self.put_locals(&context.expression)?;
8074            self.update_expressions_to_bake(fun, fun_info, &context.expression);
8075            self.put_block(back::Level(1), &fun.body, &context)?;
8076            writeln!(self.out, "}}")?;
8077            if ep_index + 1 != module.entry_points.len() {
8078                writeln!(self.out)?;
8079            }
8080            self.named_expressions.clear();
8081
8082            if is_wrapped {
8083                self.write_wrapper_function(NestedFunctionInfo {
8084                    options,
8085                    ep,
8086                    module,
8087                    mod_info,
8088                    fun_info,
8089                    args,
8090                    local_invocation_index,
8091                    nested_name: &nested_fun_name,
8092                    outer_name: &fun_name,
8093                    out_mesh_info,
8094                })?;
8095            }
8096        }
8097
8098        Ok(info)
8099    }
8100
8101    pub(super) fn write_barrier(
8102        &mut self,
8103        flags: crate::Barrier,
8104        level: back::Level,
8105    ) -> BackendResult {
8106        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
8107        // so we try to avoid it here.
8108        if flags.is_empty() {
8109            writeln!(
8110                self.out,
8111                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8112            )?;
8113        }
8114        if flags.contains(crate::Barrier::STORAGE) {
8115            writeln!(
8116                self.out,
8117                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8118            )?;
8119        }
8120        if flags.contains(crate::Barrier::WORK_GROUP) {
8121            writeln!(
8122                self.out,
8123                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8124            )?;
8125            if self.needs_object_memory_barriers {
8126                writeln!(
8127                    self.out,
8128                    "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8129                )?;
8130            }
8131        }
8132        if flags.contains(crate::Barrier::SUB_GROUP) {
8133            writeln!(
8134                self.out,
8135                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8136            )?;
8137        }
8138        if flags.contains(crate::Barrier::TEXTURE) {
8139            writeln!(
8140                self.out,
8141                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8142            )?;
8143        }
8144        Ok(())
8145    }
8146}
8147
/// Initializing workgroup variables is trickier for Metal because we have to deal
/// with atomics at the type level (atomic types don't have a copy constructor).
8150mod workgroup_mem_init {
8151    use crate::EntryPoint;
8152
8153    use super::*;
8154
8155    enum Access {
8156        GlobalVariable(Handle<crate::GlobalVariable>),
8157        StructMember(Handle<crate::Type>, u32),
8158        Array(usize),
8159    }
8160
8161    impl Access {
8162        fn write<W: Write>(
8163            &self,
8164            writer: &mut W,
8165            names: &FastHashMap<NameKey, String>,
8166        ) -> Result<(), core::fmt::Error> {
8167            match *self {
8168                Access::GlobalVariable(handle) => {
8169                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8170                }
8171                Access::StructMember(handle, index) => {
8172                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8173                }
8174                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8175            }
8176        }
8177    }
8178
8179    struct AccessStack {
8180        stack: Vec<Access>,
8181        array_depth: usize,
8182    }
8183
8184    impl AccessStack {
8185        const fn new() -> Self {
8186            Self {
8187                stack: Vec::new(),
8188                array_depth: 0,
8189            }
8190        }
8191
8192        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8193            let array_depth = self.array_depth;
8194            self.stack.push(Access::Array(array_depth));
8195            self.array_depth += 1;
8196            let res = cb(self, array_depth);
8197            self.stack.pop();
8198            self.array_depth -= 1;
8199            res
8200        }
8201
8202        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8203            self.stack.push(new);
8204            let res = cb(self);
8205            self.stack.pop();
8206            res
8207        }
8208
8209        fn write<W: Write>(
8210            &self,
8211            writer: &mut W,
8212            names: &FastHashMap<NameKey, String>,
8213        ) -> Result<(), core::fmt::Error> {
8214            for next in self.stack.iter() {
8215                next.write(writer, names)?;
8216            }
8217            Ok(())
8218        }
8219    }
8220
8221    impl<W: Write> Writer<W> {
8222        pub(super) fn need_workgroup_variables_initialization(
8223            &mut self,
8224            options: &Options,
8225            ep: &EntryPoint,
8226            module: &crate::Module,
8227            fun_info: &valid::FunctionInfo,
8228        ) -> bool {
8229            let is_task = ep.stage == crate::ShaderStage::Task;
8230            options.zero_initialize_workgroup_memory
8231                && ep.stage.compute_like()
8232                && module.global_variables.iter().any(|(handle, var)| {
8233                    let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8234                        || (var.space == crate::AddressSpace::TaskPayload && is_task);
8235                    !fun_info[handle].is_empty() && is_right_address_space
8236                })
8237        }
8238
8239        pub fn write_workgroup_variables_initialization(
8240            &mut self,
8241            module: &crate::Module,
8242            module_info: &valid::ModuleInfo,
8243            fun_info: &valid::FunctionInfo,
8244            local_invocation_index: Option<&NameKey>,
8245            stage: crate::ShaderStage,
8246        ) -> BackendResult {
8247            let level = back::Level(1);
8248
8249            writeln!(
8250                self.out,
8251                "{}if ({} == 0u) {{",
8252                level,
8253                local_invocation_index
8254                    .map(|name_key| self.names[name_key].as_str())
8255                    .unwrap_or("__local_invocation_index"),
8256            )?;
8257
8258            let mut access_stack = AccessStack::new();
8259
8260            let is_task = stage == crate::ShaderStage::Task;
8261            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8262                let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8263                    || (var.space == crate::AddressSpace::TaskPayload && is_task);
8264                !fun_info[handle].is_empty() && is_right_address_space
8265            });
8266
8267            for (handle, var) in vars {
8268                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8269                    self.write_workgroup_variable_initialization(
8270                        module,
8271                        module_info,
8272                        var.ty,
8273                        access_stack,
8274                        level.next(),
8275                    )
8276                })?;
8277            }
8278
8279            writeln!(self.out, "{level}}}")?;
8280            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8281        }
8282
8283        fn write_workgroup_variable_initialization(
8284            &mut self,
8285            module: &crate::Module,
8286            module_info: &valid::ModuleInfo,
8287            ty: Handle<crate::Type>,
8288            access_stack: &mut AccessStack,
8289            level: back::Level,
8290        ) -> BackendResult {
8291            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8292                write!(self.out, "{level}")?;
8293                access_stack.write(&mut self.out, &self.names)?;
8294                writeln!(self.out, " = {{}};")?;
8295            } else {
8296                match module.types[ty].inner {
8297                    crate::TypeInner::Atomic { .. } => {
8298                        write!(
8299                            self.out,
8300                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8301                        )?;
8302                        access_stack.write(&mut self.out, &self.names)?;
8303                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8304                    }
8305                    crate::TypeInner::Array { base, size, .. } => {
8306                        let count = match size.resolve(module.to_ctx())? {
8307                            proc::IndexableLength::Known(count) => count,
8308                            proc::IndexableLength::Dynamic => unreachable!(),
8309                        };
8310
8311                        access_stack.enter_array(|access_stack, array_depth| {
8312                            writeln!(
8313                                self.out,
8314                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8315                            )?;
8316                            self.write_workgroup_variable_initialization(
8317                                module,
8318                                module_info,
8319                                base,
8320                                access_stack,
8321                                level.next(),
8322                            )?;
8323                            writeln!(self.out, "{level}}}")?;
8324                            BackendResult::Ok(())
8325                        })?;
8326                    }
8327                    crate::TypeInner::Struct { ref members, .. } => {
8328                        for (index, member) in members.iter().enumerate() {
8329                            access_stack.enter(
8330                                Access::StructMember(ty, index as u32),
8331                                |access_stack| {
8332                                    self.write_workgroup_variable_initialization(
8333                                        module,
8334                                        module_info,
8335                                        member.ty,
8336                                        access_stack,
8337                                        level,
8338                                    )
8339                                },
8340                            )?;
8341                        }
8342                    }
8343                    _ => unreachable!(),
8344                }
8345            }
8346
8347            Ok(())
8348        }
8349    }
8350}
8351
8352impl crate::AtomicFunction {
8353    const fn to_msl(self) -> &'static str {
8354        match self {
8355            Self::Add => "fetch_add",
8356            Self::Subtract => "fetch_sub",
8357            Self::And => "fetch_and",
8358            Self::InclusiveOr => "fetch_or",
8359            Self::ExclusiveOr => "fetch_xor",
8360            Self::Min => "fetch_min",
8361            Self::Max => "fetch_max",
8362            Self::Exchange { compare: None } => "exchange",
8363            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8364        }
8365    }
8366
8367    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8368        Ok(match self {
8369            Self::Min => "min",
8370            Self::Max => "max",
8371            _ => Err(Error::FeatureNotImplemented(
8372                "64-bit atomic operation other than min/max".to_string(),
8373            ))?,
8374        })
8375    }
8376}