naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17    sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo, NAMESPACE,
18    WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21    arena::{Handle, HandleSet},
22    back::{
23        self, get_entry_points,
24        msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25        Baked,
26    },
27    common,
28    proc::{
29        self, concrete_int_scalars,
30        index::{self, BoundsCheck},
31        ExternalTextureNameKey, NameKey, TypeResolution,
32    },
33    valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// HACK: atomic operations must be given a pointer to the atomic, but in
// general this backend does not put "&" in front of every pointer expression.
// A more general treatment of pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";
43
/// Namespace written in front of Metal ray-tracing type names, e.g. when
/// `TypeContext` formats an acceleration structure type.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Type name written when `TypeContext` formats a ray query type.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// NOTE(review): presumably the member names of the `_RayQuery` struct and the
// name of a helper function operating on it — confirm at the use sites.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
51
// Names of Naga-specific helper functions and wrapper types.
// NOTE(review): presumably these are the identifiers emitted into the generated
// MSL for each helper — confirm against the code that writes them.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you have noticed that this should be exactly the same to the compiler, and you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
83
84/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
85///
86/// The `sizes` slice determines whether this function writes a
87/// scalar, vector, or matrix type:
88///
89/// - An empty slice produces a scalar type.
90/// - A one-element slice produces a vector type.
91/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
92fn put_numeric_type(
93    out: &mut impl Write,
94    scalar: crate::Scalar,
95    sizes: &[crate::VectorSize],
96) -> Result<(), FmtError> {
97    match (scalar, sizes) {
98        (scalar, &[]) => {
99            write!(out, "{}", scalar.to_msl_name())
100        }
101        (scalar, &[rows]) => {
102            write!(
103                out,
104                "{}::{}{}",
105                NAMESPACE,
106                scalar.to_msl_name(),
107                common::vector_size_str(rows)
108            )
109        }
110        (scalar, &[rows, columns]) => {
111            write!(
112                out,
113                "{}::{}{}x{}",
114                NAMESPACE,
115                scalar.to_msl_name(),
116                common::vector_size_str(columns),
117                common::vector_size_str(rows)
118            )
119        }
120        (_, _) => Ok(()), // not meaningful
121    }
122}
123
124const fn scalar_is_int(scalar: crate::Scalar) -> bool {
125    use crate::ScalarKind::*;
126    match scalar.kind {
127        Sint | Uint | AbstractInt | Bool => true,
128        Float | AbstractFloat => false,
129    }
130}
131
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// Used by the `Display` impl of `ClampedLod`.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
///
/// Used by the `Display` impl of `Reinterpreted`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
137
138/// Wrapper for identifier names for clamped level-of-detail values
139///
140/// Values of this type implement [`core::fmt::Display`], formatting as
141/// the name of the variable used to hold the cached clamped
142/// level-of-detail value for an `ImageLoad` expression.
143struct ClampedLod(Handle<crate::Expression>);
144
145impl Display for ClampedLod {
146    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
147        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
148    }
149}
150
151/// Wrapper for generating `struct _mslBufferSizes` member names for
152/// runtime-sized array lengths.
153///
154/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
155/// as an argument to the entry point. This argument's type in the MSL is
156/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
157/// each global variable containing a runtime-sized array.
158///
159/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
160/// runtime-sized array, then the value `ArraySize(global)` implements
161/// [`core::fmt::Display`], formatting as the name of the struct member carrying
162/// the number of elements in that runtime-sized array.
163///
164/// [`GlobalVariable`]: crate::GlobalVariable
165struct ArraySizeMember(Handle<crate::GlobalVariable>);
166
167impl Display for ArraySizeMember {
168    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
169        self.0.write_prefixed(f, "size")
170    }
171}
172
173/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
174///
175/// Implements [`core::fmt::Display`], formatting as a name derived from
176/// `target_type` and the variable name of `orig`.
177#[derive(Clone, Copy)]
178struct Reinterpreted<'a> {
179    target_type: &'a str,
180    orig: Handle<crate::Expression>,
181}
182
183impl<'a> Reinterpreted<'a> {
184    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
185        Self { target_type, orig }
186    }
187}
188
189impl Display for Reinterpreted<'_> {
190    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
191        f.write_str(REINTERPRET_PREFIX)?;
192        f.write_str(self.target_type)?;
193        self.orig.write_prefixed(f, "_e")
194    }
195}
196
/// A type handle bundled with what is needed to format it as MSL.
///
/// The `Display` impl writes the MSL spelling of the type.
pub(super) struct TypeContext<'a> {
    /// The type to format; looked up in `gctx.types`.
    pub handle: Handle<crate::Type>,
    /// Global context providing the type arena.
    pub gctx: proc::GlobalCtx<'a>,
    /// Identifier table; queried with `NameKey::Type(handle)` when the type
    /// is written by its alias.
    pub names: &'a FastHashMap<NameKey, String>,
    /// Storage access flags, used to choose the access qualifier of storage
    /// textures.
    pub access: crate::StorageAccess,
    /// When `false`, a type that needs an alias is written by its name rather
    /// than spelled out in full.
    pub first_time: bool,
}
204
205impl TypeContext<'_> {
206    fn scalar(&self) -> Option<crate::Scalar> {
207        let ty = &self.gctx.types[self.handle];
208        ty.inner.scalar()
209    }
210
211    fn vector_size(&self) -> Option<crate::VectorSize> {
212        let ty = &self.gctx.types[self.handle];
213        match ty.inner {
214            crate::TypeInner::Vector { size, .. } => Some(size),
215            _ => None,
216        }
217    }
218
219    fn unwrap_array(self) -> Self {
220        match self.gctx.types[self.handle].inner {
221            crate::TypeInner::Array { base, .. } => Self {
222                handle: base,
223                ..self
224            },
225            _ => self,
226        }
227    }
228}
229
230impl Display for TypeContext<'_> {
231    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
232        let ty = &self.gctx.types[self.handle];
233        if ty.needs_alias() && !self.first_time {
234            let name = &self.names[&NameKey::Type(self.handle)];
235            return write!(out, "{name}");
236        }
237
238        match ty.inner {
239            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
240            crate::TypeInner::Atomic(scalar) => {
241                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
242            }
243            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
244            crate::TypeInner::Matrix {
245                columns,
246                rows,
247                scalar,
248            } => put_numeric_type(out, scalar, &[rows, columns]),
249            // Requires Metal-2.3
250            crate::TypeInner::CooperativeMatrix {
251                columns,
252                rows,
253                scalar,
254                role: _,
255            } => {
256                write!(
257                    out,
258                    "{NAMESPACE}::simdgroup_{}{}x{}",
259                    scalar.to_msl_name(),
260                    columns as u32,
261                    rows as u32,
262                )
263            }
264            crate::TypeInner::Pointer { base, space } => {
265                let sub = Self {
266                    handle: base,
267                    first_time: false,
268                    ..*self
269                };
270                let space_name = match space.to_msl_name() {
271                    Some(name) => name,
272                    None => return Ok(()),
273                };
274                write!(out, "{space_name} {sub}&")
275            }
276            crate::TypeInner::ValuePointer {
277                size,
278                scalar,
279                space,
280            } => {
281                match space.to_msl_name() {
282                    Some(name) => write!(out, "{name} ")?,
283                    None => return Ok(()),
284                };
285                match size {
286                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
287                    None => put_numeric_type(out, scalar, &[])?,
288                };
289
290                write!(out, "&")
291            }
292            crate::TypeInner::Array { base, .. } => {
293                let sub = Self {
294                    handle: base,
295                    first_time: false,
296                    ..*self
297                };
298                // Array lengths go at the end of the type definition,
299                // so just print the element type here.
300                write!(out, "{sub}")
301            }
302            crate::TypeInner::Struct { .. } => unreachable!(),
303            crate::TypeInner::Image {
304                dim,
305                arrayed,
306                class,
307            } => {
308                let dim_str = match dim {
309                    crate::ImageDimension::D1 => "1d",
310                    crate::ImageDimension::D2 => "2d",
311                    crate::ImageDimension::D3 => "3d",
312                    crate::ImageDimension::Cube => "cube",
313                };
314                let (texture_str, msaa_str, scalar, access) = match class {
315                    crate::ImageClass::Sampled { kind, multi } => {
316                        let (msaa_str, access) = if multi {
317                            ("_ms", "read")
318                        } else {
319                            ("", "sample")
320                        };
321                        let scalar = crate::Scalar { kind, width: 4 };
322                        ("texture", msaa_str, scalar, access)
323                    }
324                    crate::ImageClass::Depth { multi } => {
325                        let (msaa_str, access) = if multi {
326                            ("_ms", "read")
327                        } else {
328                            ("", "sample")
329                        };
330                        let scalar = crate::Scalar {
331                            kind: crate::ScalarKind::Float,
332                            width: 4,
333                        };
334                        ("depth", msaa_str, scalar, access)
335                    }
336                    crate::ImageClass::Storage { format, .. } => {
337                        let access = if self
338                            .access
339                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
340                        {
341                            "read_write"
342                        } else if self.access.contains(crate::StorageAccess::STORE) {
343                            "write"
344                        } else if self.access.contains(crate::StorageAccess::LOAD) {
345                            "read"
346                        } else {
347                            log::warn!(
348                                "Storage access for {:?} (name '{}'): {:?}",
349                                self.handle,
350                                ty.name.as_deref().unwrap_or_default(),
351                                self.access
352                            );
353                            unreachable!("module is not valid");
354                        };
355                        ("texture", "", format.into(), access)
356                    }
357                    crate::ImageClass::External => {
358                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
359                    }
360                };
361                let base_name = scalar.to_msl_name();
362                let array_str = if arrayed { "_array" } else { "" };
363                write!(
364                    out,
365                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
366                )
367            }
368            crate::TypeInner::Sampler { comparison: _ } => {
369                write!(out, "{NAMESPACE}::sampler")
370            }
371            crate::TypeInner::AccelerationStructure { vertex_return } => {
372                if vertex_return {
373                    unimplemented!("metal does not support vertex ray hit return")
374                }
375                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
376            }
377            crate::TypeInner::RayQuery { vertex_return } => {
378                if vertex_return {
379                    unimplemented!("metal does not support vertex ray hit return")
380                }
381                write!(out, "{RAY_QUERY_TYPE}")
382            }
383            crate::TypeInner::BindingArray { base, .. } => {
384                let base_tyname = Self {
385                    handle: base,
386                    first_time: false,
387                    ..*self
388                };
389
390                write!(
391                    out,
392                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
393                )
394            }
395        }
396    }
397}
398
/// A global variable bundled with what is needed to write its MSL type and
/// name (see `try_fmt`).
pub(super) struct TypedGlobalVariable<'a> {
    /// Module providing the global variable and type arenas.
    pub module: &'a crate::Module,
    /// Identifier table; queried with `NameKey::GlobalVariable(handle)`.
    pub names: &'a FastHashMap<NameKey, String>,
    /// The global variable to write.
    pub handle: Handle<crate::GlobalVariable>,
    /// How the variable is used; a variable that is never written may get a
    /// `const` qualifier.
    pub usage: valid::GlobalUse,
    /// Whether the variable is written as a reference (`&`).
    pub reference: bool,
}
406
/// The two pieces of a global variable declaration, as produced by
/// `TypedGlobalVariable::to_parts`.
struct TypedGlobalVariableParts {
    /// Full type text, including any address space and qualifiers.
    ty_name: String,
    /// Identifier of the variable.
    var_name: String,
}
411
412impl TypedGlobalVariable<'_> {
413    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
414        let var = &self.module.global_variables[self.handle];
415        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
416
417        let storage_access = match var.space {
418            crate::AddressSpace::Storage { access } => access,
419            _ => match self.module.types[var.ty].inner {
420                crate::TypeInner::Image {
421                    class: crate::ImageClass::Storage { access, .. },
422                    ..
423                } => access,
424                crate::TypeInner::BindingArray { base, .. } => {
425                    match self.module.types[base].inner {
426                        crate::TypeInner::Image {
427                            class: crate::ImageClass::Storage { access, .. },
428                            ..
429                        } => access,
430                        _ => crate::StorageAccess::default(),
431                    }
432                }
433                _ => crate::StorageAccess::default(),
434            },
435        };
436        let ty_name = TypeContext {
437            handle: var.ty,
438            gctx: self.module.to_ctx(),
439            names: self.names,
440            access: storage_access,
441            first_time: false,
442        };
443
444        let access = if var.space.needs_access_qualifier()
445            && !self.usage.intersects(valid::GlobalUse::WRITE)
446        {
447            "const"
448        } else {
449            ""
450        };
451        let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
452            (Some(space), crate::AddressSpace::WorkGroup) => {
453                ("", space, access, if self.reference { "&" } else { "" })
454            }
455            (Some(space), _) if self.reference => {
456                let coherent = if var
457                    .memory_decorations
458                    .contains(crate::MemoryDecorations::COHERENT)
459                {
460                    "coherent "
461                } else {
462                    ""
463                };
464                (coherent, space, access, "&")
465            }
466            _ => ("", "", "", ""),
467        };
468
469        let ty = format!(
470            "{coherent}{space}{}{ty_name}{}{access}{reference}",
471            if space.is_empty() { "" } else { " " },
472            if access.is_empty() { "" } else { " " },
473        );
474
475        Ok(TypedGlobalVariableParts {
476            ty_name: ty,
477            var_name: name.clone(),
478        })
479    }
480    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
481        let parts = self.to_parts()?;
482
483        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
484    }
485}
486
/// Key identifying one wrapper function, made of the operation it wraps and
/// the types it is specialized for.
///
/// Values of this type are stored in the `wrapped_functions` set of `Writer`.
// NOTE(review): presumably that set records which wrappers have already been
// written so that each is emitted only once — confirm against the code that
// populates it.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Wrapper for a unary operator applied to an operand described by its
    /// optional vector size and scalar.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a binary operator, keyed by both operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a math function, keyed by its argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a cast between two scalar types, optionally on vectors.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Wrapper for an image load, keyed by image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Wrapper for an image sample, keyed by image class and the
    /// `clamp_to_edge` flag.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Wrapper for an image size query, keyed by image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Wrapper for a cooperative matrix load, keyed by address space name,
    /// matrix dimensions and scalar.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Wrapper for a cooperative matrix multiply-add, keyed by address space
    /// name, the three dimensions involved and scalar.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
531
/// State of the MSL writer.
// NOTE(review): the per-field notes marked "presumably" are inferred from
// names and types only — confirm against the methods that use them.
pub struct Writer<W> {
    /// Presumably the sink the generated MSL is written to.
    pub(super) out: W,
    /// Identifier chosen for each [`NameKey`].
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    /// Presumably the generator that fills `names`.
    pub(super) namer: proc::Namer,
    /// Set of wrapper-function keys (see [`WrappedFunction`]).
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    pub(super) needs_object_memory_barriers: bool,
}
549
550impl crate::Scalar {
551    pub(super) fn to_msl_name(self) -> &'static str {
552        use crate::ScalarKind as Sk;
553        match self {
554            Self {
555                kind: Sk::Float,
556                width: 4,
557            } => "float",
558            Self {
559                kind: Sk::Float,
560                width: 2,
561            } => "half",
562            Self {
563                kind: Sk::Sint,
564                width: 4,
565            } => "int",
566            Self {
567                kind: Sk::Uint,
568                width: 4,
569            } => "uint",
570            Self {
571                kind: Sk::Sint,
572                width: 8,
573            } => "long",
574            Self {
575                kind: Sk::Uint,
576                width: 8,
577            } => "ulong",
578            Self {
579                kind: Sk::Bool,
580                width: _,
581            } => "bool",
582            Self {
583                kind: Sk::AbstractInt | Sk::AbstractFloat,
584                width: _,
585            } => unreachable!("Found Abstract scalar kind"),
586            _ => unreachable!("Unsupported scalar kind: {:?}", self),
587        }
588    }
589}
590
/// Returns `","` when a separator is needed, and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
598
599fn should_pack_struct_member(
600    members: &[crate::StructMember],
601    span: u32,
602    index: usize,
603    module: &crate::Module,
604) -> Option<crate::Scalar> {
605    let member = &members[index];
606
607    let ty_inner = &module.types[member.ty].inner;
608    let last_offset = member.offset + ty_inner.size(module.to_ctx());
609    let next_offset = match members.get(index + 1) {
610        Some(next) => next.offset,
611        None => span,
612    };
613    let is_tight = next_offset == last_offset;
614
615    match *ty_inner {
616        crate::TypeInner::Vector {
617            size: crate::VectorSize::Tri,
618            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
619        } if is_tight => Some(scalar),
620        _ => None,
621    }
622}
623
624fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
625    match arena[ty].inner {
626        crate::TypeInner::Struct { ref members, .. } => {
627            if let Some(member) = members.last() {
628                if let crate::TypeInner::Array {
629                    size: crate::ArraySize::Dynamic,
630                    ..
631                } = arena[member.ty].inner
632                {
633                    return true;
634                }
635            }
636            false
637        }
638        crate::TypeInner::Array {
639            size: crate::ArraySize::Dynamic,
640            ..
641        } => true,
642        _ => false,
643    }
644}
645
646impl crate::AddressSpace {
647    /// Returns true if global variables in this address space are
648    /// passed in function arguments. These arguments need to be
649    /// passed through any functions called from the entry point.
650    const fn needs_pass_through(&self) -> bool {
651        match *self {
652            Self::Uniform
653            | Self::Storage { .. }
654            | Self::Private
655            | Self::WorkGroup
656            | Self::Immediate
657            | Self::Handle
658            | Self::TaskPayload => true,
659            Self::Function => false,
660            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
661        }
662    }
663
664    /// Returns true if the address space may need a "const" qualifier.
665    const fn needs_access_qualifier(&self) -> bool {
666        match *self {
667            //Note: we are ignoring the storage access here, and instead
668            // rely on the actual use of a global by functions. This means we
669            // may end up with "const" even if the binding is read-write,
670            // and that should be OK.
671            Self::Storage { .. } => true,
672            Self::TaskPayload => true,
673            Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
674            // These should always be read-write.
675            Self::Private | Self::WorkGroup => false,
676            // These translate to `constant` address space, no need for qualifiers.
677            Self::Uniform | Self::Immediate => false,
678            // Not applicable.
679            Self::Handle | Self::Function => false,
680        }
681    }
682
683    const fn to_msl_name(self) -> Option<&'static str> {
684        match self {
685            Self::Handle => None,
686            Self::Uniform | Self::Immediate => Some("constant"),
687            Self::Storage { .. } => Some("device"),
688            // note for `RayPayload`, this probably needs to be emulated as a
689            // private variable, as metal has essentially an inout input
690            // for where it is passed.
691            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
692            Self::WorkGroup => Some("threadgroup"),
693            Self::TaskPayload => Some("object_data"),
694            Self::IncomingRayPayload => Some("ray_data"),
695        }
696    }
697}
698
699impl crate::Type {
700    // Returns `true` if we need to emit an alias for this type.
701    const fn needs_alias(&self) -> bool {
702        use crate::TypeInner as Ti;
703
704        match self.inner {
705            // value types are concise enough, we only alias them if they are named
706            Ti::Scalar(_)
707            | Ti::Vector { .. }
708            | Ti::Matrix { .. }
709            | Ti::CooperativeMatrix { .. }
710            | Ti::Atomic(_)
711            | Ti::Pointer { .. }
712            | Ti::ValuePointer { .. } => self.name.is_some(),
713            // composite types are better to be aliased, regardless of the name
714            Ti::Struct { .. } | Ti::Array { .. } => true,
715            // handle types may be different, depending on the global var access, so we always inline them
716            Ti::Image { .. }
717            | Ti::Sampler { .. }
718            | Ti::AccelerationStructure { .. }
719            | Ti::RayQuery { .. }
720            | Ti::BindingArray { .. } => false,
721        }
722    }
723}
724
/// Identifies where a function comes from: a regular function referred to by
/// handle, or an entry point referred to by index.
///
/// `NameKeyExt` uses this to pick the matching [`NameKey`] variant.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A function referred to by its handle.
    Handle(Handle<crate::Function>),
    /// An entry point referred to by its index.
    EntryPoint(proc::EntryPointIndex),
}
730
731trait NameKeyExt {
732    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
733        match origin {
734            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
735            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
736        }
737    }
738
739    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
740    /// policy when it needs to produce a pointer-typed result for an OOB access. These
741    /// are unique per accessed type, so the second argument is a type handle. See docs
742    /// for [`crate::back::msl`].
743    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
744        match origin {
745            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
746            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
747        }
748    }
749}
750
751impl NameKeyExt for NameKey {}
752
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the level-of-detail expression directly.
    Direct(Handle<crate::Expression>),
    /// Use the saved, clamped level of detail.
    // NOTE(review): the handle is presumably that of the `ImageLoad`
    // expression the temporary variable is named after — confirm.
    Restricted(Handle<crate::Expression>),
}
767
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
// NOTE(review): the field notes below are inferred from the field names —
// confirm against the code that builds and consumes this struct.
struct TexelAddress {
    /// Expression for the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression for the array index, if any.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression for the sample index, if any.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any.
    level: Option<LevelOfDetail>,
}
783
/// Context shared by the helpers that inspect the expressions of one function.
pub(super) struct ExpressionContext<'a> {
    /// The function whose expression arena is consulted.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; used to resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module containing `function`; provides the type arena.
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably (major, minor) of the targeted MSL version — confirm.
    lang_version: (u8, u8),
    /// Bounds check policies, consulted by `choose_bounds_check_policy`.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
801
802impl<'a> ExpressionContext<'a> {
803    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
804        self.info[handle].ty.inner_with(&self.module.types)
805    }
806
807    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
808    ///
809    /// Only mipmapped images need to specify a level of detail. Since 1D
810    /// textures cannot have mipmaps, MSL requires that the level argument to
811    /// texture1d queries and accesses must be a constexpr 0. It's easiest
812    /// just to omit the level entirely for 1D textures.
813    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
814        let image_ty = self.resolve_type(image);
815        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
816            class.is_mipmapped() && dim != crate::ImageDimension::D1
817        } else {
818            false
819        }
820    }
821
822    fn choose_bounds_check_policy(
823        &self,
824        pointer: Handle<crate::Expression>,
825    ) -> index::BoundsCheckPolicy {
826        self.policies
827            .choose_policy(pointer, &self.module.types, self.info)
828    }
829
830    /// See docs for [`proc::index::access_needs_check`].
831    fn access_needs_check(
832        &self,
833        base: Handle<crate::Expression>,
834        index: index::GuardedIndex,
835    ) -> Option<index::IndexableLength> {
836        index::access_needs_check(
837            base,
838            index,
839            self.module,
840            &self.function.expressions,
841            self.info,
842        )
843    }
844
845    /// See docs for [`proc::index::bounds_check_iter`].
846    fn bounds_check_iter(
847        &self,
848        chain: Handle<crate::Expression>,
849    ) -> impl Iterator<Item = BoundsCheck> + '_ {
850        index::bounds_check_iter(chain, self.module, self.function, self.info)
851    }
852
853    /// See docs for [`proc::index::oob_local_types`].
854    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
855        index::oob_local_types(self.module, self.function, self.info, self.policies)
856    }
857
858    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
859        match self.function.expressions[expr_handle] {
860            crate::Expression::AccessIndex { base, index } => {
861                let ty = match *self.resolve_type(base) {
862                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
863                    ref ty => ty,
864                };
865                match *ty {
866                    crate::TypeInner::Struct {
867                        ref members, span, ..
868                    } => should_pack_struct_member(members, span, index as usize, self.module),
869                    _ => None,
870                }
871            }
872            _ => None,
873        }
874    }
875}
876
/// State consulted while writing statements: the expression context plus
/// statement-specific extras.
struct StatementContext<'a> {
    /// Context used for every expression appearing in the statements.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct type wrapping the
    /// function's result, when there is one; confirm against the code that
    /// handles `Return` statements.
    result_struct: Option<&'a str>,
}
881
882impl<W: Write> Writer<W> {
883    /// Creates a new `Writer` instance.
884    pub fn new(out: W) -> Self {
885        Writer {
886            out,
887            names: FastHashMap::default(),
888            named_expressions: Default::default(),
889            need_bake_expressions: Default::default(),
890            namer: proc::Namer::default(),
891            wrapped_functions: FastHashSet::default(),
892            #[cfg(test)]
893            put_expression_stack_pointers: Default::default(),
894            #[cfg(test)]
895            put_block_stack_pointers: Default::default(),
896            struct_member_pads: FastHashSet::default(),
897            needs_object_memory_barriers: false,
898        }
899    }
900
901    /// Finishes writing and returns the output.
902    // See https://github.com/rust-lang/rust-clippy/issues/4979.
903    pub fn finish(self) -> W {
904        self.out
905    }
906
907    /// Generates statements to be inserted immediately before and at the very
908    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
909    /// The 0th item of the returned tuple should be inserted immediately prior
910    /// to the loop and the 1st item should be inserted at the very start of
911    /// the loop body.
912    ///
913    /// # What is this trying to solve?
914    ///
915    /// In Metal Shading Language, an infinite loop has undefined behavior.
916    /// (This rule is inherited from C++14.) This means that, if the MSL
917    /// compiler determines that a given loop will never exit, it may assume
918    /// that it is never reached. It may thus assume that any conditions
919    /// sufficient to cause the loop to be reached must be false. Like many
920    /// optimizing compilers, MSL uses this kind of analysis to establish limits
921    /// on the range of values variables involved in those conditions might
922    /// hold.
923    ///
924    /// For example, suppose the MSL compiler sees the code:
925    ///
926    /// ```ignore
927    /// if (i >= 10) {
928    ///     while (true) { }
929    /// }
930    /// ```
931    ///
932    /// It will recognize that the `while` loop will never terminate, conclude
933    /// that it must be unreachable, and thus infer that, if this code is
934    /// reached, then `i < 10` at that point.
935    ///
936    /// Now suppose that, at some point where `i` has the same value as above,
937    /// the compiler sees the code:
938    ///
939    /// ```ignore
940    /// if (i < 10) {
941    ///     a[i] = 1;
942    /// }
943    /// ```
944    ///
945    /// Because the compiler is confident that `i < 10`, it will make the
946    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
947    ///
948    /// ```ignore
949    /// a[i] = 1;
950    /// ```
951    ///
952    /// If that `if` condition was injected by Naga to implement a bounds check,
953    /// the MSL compiler's optimizations could allow out-of-bounds array
954    /// accesses to occur.
955    ///
956    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
957    /// that a loop is infinite, so an attacker could craft a Naga module
958    /// containing an infinite loop protected by conditions that cause the Metal
959    /// compiler to remove bounds checks that Naga injected elsewhere in the
960    /// function.
961    ///
962    /// This rewrite could occur even if the conditional assignment appears
963    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
964    /// reached. This would allow the attacker to save the results of
965    /// unauthorized reads somewhere accessible before entering the infinite
966    /// loop. But even worse, the MSL compiler has been observed to simply
967    /// delete the infinite loop entirely, so that even code dominated by the
968    /// loop becomes reachable. This would make the attack even more flexible,
969    /// since shaders that would appear to never terminate would actually exit
970    /// nicely, after having stolen data from elsewhere in the GPU address
971    /// space.
972    ///
973    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
974    /// generates is infinite. One approach would be to add inline assembly to
975    /// each loop that is annotated as potentially branching out of the loop,
976    /// but which in fact generates no instructions. Unfortunately, inline
977    /// assembly is not handled correctly by some Metal device drivers.
978    ///
979    /// A previously used approach was to add the following code to the bottom
980    /// of every loop:
981    ///
982    /// ```ignore
983    /// if (volatile bool unpredictable = false; unpredictable)
984    ///     break;
985    /// ```
986    ///
987    /// Although the `if` condition will always be false in any real execution,
988    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
989    /// it must assume that the `break` might be reached, and hence that the
990    /// loop is not unbounded. This prevents the range analysis impact described
991    /// above. Unfortunately this prevented the compiler from making important,
992    /// and safe, optimizations such as loop unrolling and was observed to
993    /// significantly hurt performance.
994    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations. (The
    /// counter counts down from its maximum value rather than up from zero
    /// to avoid a hang on certain Intel drivers.)
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
1005    ///
1006    /// This convinces the compiler that the loop is finite and therefore may
1007    /// execute, whilst at the same time allowing optimizations such as loop
1008    /// unrolling. Furthermore the 64-bit counter is large enough it seems
1009    /// implausible that it would affect the execution of any shader.
1010    ///
1011    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
1012    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
1013    fn gen_force_bounded_loop_statements(
1014        &mut self,
1015        level: back::Level,
1016        context: &StatementContext,
1017    ) -> Option<(String, String)> {
1018        if !context.expression.force_loop_bounding {
1019            return None;
1020        }
1021
1022        let loop_bound_name = self.namer.call("loop_bound");
1023        // Count down from u32::MAX rather than up from 0 to avoid hang on
1024        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
1025        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1026        let level = level.next();
1027        let break_and_inc = format!(
1028            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1029{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1030        );
1031
1032        Some((decl, break_and_inc))
1033    }
1034
    /// Writes a parenthesized, comma-separated argument list.
    ///
    /// Each handle in `parameters` is emitted with `put_expression`, passing
    /// `true` as its final argument; the surrounding parentheses and the
    /// separators come from `put_call_parameters_impl`.
    fn put_call_parameters(
        &mut self,
        parameters: impl Iterator<Item = Handle<crate::Expression>>,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
            writer.put_expression(expr, context, true)
        })
    }
1044
1045    fn put_call_parameters_impl<C, E>(
1046        &mut self,
1047        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1048        ctx: &C,
1049        put_expression: E,
1050    ) -> BackendResult
1051    where
1052        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1053    {
1054        write!(self.out, "(")?;
1055        for (i, handle) in parameters.enumerate() {
1056            if i != 0 {
1057                write!(self.out, ", ")?;
1058            }
1059            put_expression(self, ctx, handle)?;
1060        }
1061        write!(self.out, ")")?;
1062        Ok(())
1063    }
1064
1065    /// Writes the local variables of the given function, as well as any extra
1066    /// out-of-bounds locals that are needed.
1067    ///
1068    /// The names of the OOB locals are also added to `self.names` at the same
1069    /// time.
1070    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1071        let oob_local_types = context.oob_local_types();
1072        for &ty in oob_local_types.iter() {
1073            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1074            self.names.insert(name_key, self.namer.call("oob"));
1075        }
1076
1077        for (name_key, ty, init) in context
1078            .function
1079            .local_variables
1080            .iter()
1081            .map(|(local_handle, local)| {
1082                let name_key = NameKey::local(context.origin, local_handle);
1083                (name_key, local.ty, local.init)
1084            })
1085            .chain(oob_local_types.iter().map(|&ty| {
1086                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1087                (name_key, ty, None)
1088            }))
1089        {
1090            let ty_name = TypeContext {
1091                handle: ty,
1092                gctx: context.module.to_ctx(),
1093                names: &self.names,
1094                access: crate::StorageAccess::empty(),
1095                first_time: false,
1096            };
1097            write!(
1098                self.out,
1099                "{}{} {}",
1100                back::INDENT,
1101                ty_name,
1102                self.names[&name_key]
1103            )?;
1104            match init {
1105                Some(value) => {
1106                    write!(self.out, " = ")?;
1107                    self.put_expression(value, context, true)?;
1108                }
1109                None => {
1110                    write!(self.out, " = {{}}")?;
1111                }
1112            };
1113            writeln!(self.out, ";")?;
1114        }
1115        Ok(())
1116    }
1117
1118    fn put_level_of_detail(
1119        &mut self,
1120        level: LevelOfDetail,
1121        context: &ExpressionContext,
1122    ) -> BackendResult {
1123        match level {
1124            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1125            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1126        }
1127        Ok(())
1128    }
1129
1130    fn put_image_query(
1131        &mut self,
1132        image: Handle<crate::Expression>,
1133        query: &str,
1134        level: Option<LevelOfDetail>,
1135        context: &ExpressionContext,
1136    ) -> BackendResult {
1137        self.put_expression(image, context, false)?;
1138        write!(self.out, ".get_{query}(")?;
1139        if let Some(level) = level {
1140            self.put_level_of_detail(level, context)?;
1141        }
1142        write!(self.out, ")")?;
1143        Ok(())
1144    }
1145
1146    fn put_image_size_query(
1147        &mut self,
1148        image: Handle<crate::Expression>,
1149        level: Option<LevelOfDetail>,
1150        kind: crate::ScalarKind,
1151        context: &ExpressionContext,
1152    ) -> BackendResult {
1153        if let crate::TypeInner::Image {
1154            class: crate::ImageClass::External,
1155            ..
1156        } = *context.resolve_type(image)
1157        {
1158            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1159            self.put_expression(image, context, true)?;
1160            write!(self.out, ")")?;
1161            return Ok(());
1162        }
1163
1164        //Note: MSL only has separate width/height/depth queries,
1165        // so compose the result of them.
1166        let dim = match *context.resolve_type(image) {
1167            crate::TypeInner::Image { dim, .. } => dim,
1168            ref other => unreachable!("Unexpected type {:?}", other),
1169        };
1170        let scalar = crate::Scalar { kind, width: 4 };
1171        let coordinate_type = scalar.to_msl_name();
1172        match dim {
1173            crate::ImageDimension::D1 => {
1174                // Since 1D textures never have mipmaps, MSL requires that the
1175                // `level` argument be a constexpr 0. It's simplest for us just
1176                // to pass `None` and omit the level entirely.
1177                if kind == crate::ScalarKind::Uint {
1178                    // No need to construct a vector. No cast needed.
1179                    self.put_image_query(image, "width", None, context)?;
1180                } else {
1181                    // There's no definition for `int` in the `metal` namespace.
1182                    write!(self.out, "int(")?;
1183                    self.put_image_query(image, "width", None, context)?;
1184                    write!(self.out, ")")?;
1185                }
1186            }
1187            crate::ImageDimension::D2 => {
1188                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1189                self.put_image_query(image, "width", level, context)?;
1190                write!(self.out, ", ")?;
1191                self.put_image_query(image, "height", level, context)?;
1192                write!(self.out, ")")?;
1193            }
1194            crate::ImageDimension::D3 => {
1195                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1196                self.put_image_query(image, "width", level, context)?;
1197                write!(self.out, ", ")?;
1198                self.put_image_query(image, "height", level, context)?;
1199                write!(self.out, ", ")?;
1200                self.put_image_query(image, "depth", level, context)?;
1201                write!(self.out, ")")?;
1202            }
1203            crate::ImageDimension::Cube => {
1204                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1205                self.put_image_query(image, "width", level, context)?;
1206                write!(self.out, ")")?;
1207            }
1208        }
1209        Ok(())
1210    }
1211
1212    fn put_cast_to_uint_scalar_or_vector(
1213        &mut self,
1214        expr: Handle<crate::Expression>,
1215        context: &ExpressionContext,
1216    ) -> BackendResult {
1217        // coordinates in IR are int, but Metal expects uint
1218        match *context.resolve_type(expr) {
1219            crate::TypeInner::Scalar(_) => {
1220                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1221            }
1222            crate::TypeInner::Vector { size, .. } => {
1223                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1224            }
1225            _ => {
1226                return Err(Error::GenericValidation(
1227                    "Invalid type for image coordinate".into(),
1228                ))
1229            }
1230        };
1231
1232        write!(self.out, "(")?;
1233        self.put_expression(expr, context, true)?;
1234        write!(self.out, ")")?;
1235        Ok(())
1236    }
1237
1238    fn put_image_sample_level(
1239        &mut self,
1240        image: Handle<crate::Expression>,
1241        level: crate::SampleLevel,
1242        context: &ExpressionContext,
1243    ) -> BackendResult {
1244        let has_levels = context.image_needs_lod(image);
1245        match level {
1246            crate::SampleLevel::Auto => {}
1247            crate::SampleLevel::Zero => {
1248                //TODO: do we support Zero on `Sampled` image classes?
1249            }
1250            _ if !has_levels => {
1251                log::warn!("1D image can't be sampled with level {level:?}");
1252            }
1253            crate::SampleLevel::Exact(h) => {
1254                write!(self.out, ", {NAMESPACE}::level(")?;
1255                self.put_expression(h, context, true)?;
1256                write!(self.out, ")")?;
1257            }
1258            crate::SampleLevel::Bias(h) => {
1259                write!(self.out, ", {NAMESPACE}::bias(")?;
1260                self.put_expression(h, context, true)?;
1261                write!(self.out, ")")?;
1262            }
1263            crate::SampleLevel::Gradient { x, y } => {
1264                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1265                self.put_expression(x, context, true)?;
1266                write!(self.out, ", ")?;
1267                self.put_expression(y, context, true)?;
1268                write!(self.out, ")")?;
1269            }
1270        }
1271        Ok(())
1272    }
1273
    /// Writes the largest valid coordinate of `image` at `level`: the
    /// unsigned size query followed by ` - 1`.
    ///
    /// The output is not parenthesized.
    fn put_image_coordinate_limits(
        &mut self,
        image: Handle<crate::Expression>,
        level: Option<LevelOfDetail>,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
        write!(self.out, " - 1")?;
        Ok(())
    }
1284
    /// General function for writing restricted image indexes.
    ///
    /// This is used to produce restricted mip levels, array indices, and sample
    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
    /// [`Restrict`] bounds check policy.
    ///
    /// This function writes an expression of the form:
    ///
    /// ```ignore
    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
    /// ```
    ///
    /// where `INDEX` is the expression `index`, `IMAGE` is the expression
    /// `image`, and `LIMIT_METHOD` is the method name given in `limit_method`.
    ///
    /// [`ImageLoad`]: crate::Expression::ImageLoad
    /// [`ImageStore`]: crate::Statement::ImageStore
    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
    fn put_restricted_scalar_image_index(
        &mut self,
        image: Handle<crate::Expression>,
        index: Handle<crate::Expression>,
        limit_method: &str,
        context: &ExpressionContext,
    ) -> BackendResult {
        // The index is cast to `uint` before being clamped.
        write!(self.out, "{NAMESPACE}::min(uint(")?;
        self.put_expression(index, context, true)?;
        write!(self.out, "), ")?;
        // Upper bound: the value returned by the limit method, minus one.
        self.put_expression(image, context, false)?;
        write!(self.out, ".{limit_method}() - 1)")?;
        Ok(())
    }
1316
1317    fn put_restricted_texel_address(
1318        &mut self,
1319        image: Handle<crate::Expression>,
1320        address: &TexelAddress,
1321        context: &ExpressionContext,
1322    ) -> BackendResult {
1323        // Write the coordinate.
1324        write!(self.out, "{NAMESPACE}::min(")?;
1325        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1326        write!(self.out, ", ")?;
1327        self.put_image_coordinate_limits(image, address.level, context)?;
1328        write!(self.out, ")")?;
1329
1330        // Write the array index, if present.
1331        if let Some(array_index) = address.array_index {
1332            write!(self.out, ", ")?;
1333            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1334        }
1335
1336        // Write the sample index, if present.
1337        if let Some(sample) = address.sample {
1338            write!(self.out, ", ")?;
1339            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1340        }
1341
1342        // The level of detail should be clamped and cached by
1343        // `put_cache_restricted_level`, so we don't need to clamp it here.
1344        if let Some(level) = address.level {
1345            write!(self.out, ", ")?;
1346            self.put_level_of_detail(level, context)?;
1347        }
1348
1349        Ok(())
1350    }
1351
1352    /// Write an expression that is true if the given image access is in bounds.
1353    fn put_image_access_bounds_check(
1354        &mut self,
1355        image: Handle<crate::Expression>,
1356        address: &TexelAddress,
1357        context: &ExpressionContext,
1358    ) -> BackendResult {
1359        let mut conjunction = "";
1360
1361        // First, check the level of detail. Only if that is in bounds can we
1362        // use it to find the appropriate bounds for the coordinates.
1363        let level = if let Some(level) = address.level {
1364            write!(self.out, "uint(")?;
1365            self.put_level_of_detail(level, context)?;
1366            write!(self.out, ") < ")?;
1367            self.put_expression(image, context, true)?;
1368            write!(self.out, ".get_num_mip_levels()")?;
1369            conjunction = " && ";
1370            Some(level)
1371        } else {
1372            None
1373        };
1374
1375        // Check sample index, if present.
1376        if let Some(sample) = address.sample {
1377            write!(self.out, "uint(")?;
1378            self.put_expression(sample, context, true)?;
1379            write!(self.out, ") < ")?;
1380            self.put_expression(image, context, true)?;
1381            write!(self.out, ".get_num_samples()")?;
1382            conjunction = " && ";
1383        }
1384
1385        // Check array index, if present.
1386        if let Some(array_index) = address.array_index {
1387            write!(self.out, "{conjunction}uint(")?;
1388            self.put_expression(array_index, context, true)?;
1389            write!(self.out, ") < ")?;
1390            self.put_expression(image, context, true)?;
1391            write!(self.out, ".get_array_size()")?;
1392            conjunction = " && ";
1393        }
1394
1395        // Finally, check if the coordinates are within bounds.
1396        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1397            crate::TypeInner::Vector { .. } => true,
1398            _ => false,
1399        };
1400        write!(self.out, "{conjunction}")?;
1401        if coord_is_vector {
1402            write!(self.out, "{NAMESPACE}::all(")?;
1403        }
1404        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1405        write!(self.out, " < ")?;
1406        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1407        if coord_is_vector {
1408            write!(self.out, ")")?;
1409        }
1410
1411        Ok(())
1412    }
1413
1414    fn put_image_load(
1415        &mut self,
1416        load: Handle<crate::Expression>,
1417        image: Handle<crate::Expression>,
1418        mut address: TexelAddress,
1419        context: &ExpressionContext,
1420    ) -> BackendResult {
1421        if let crate::TypeInner::Image {
1422            class: crate::ImageClass::External,
1423            ..
1424        } = *context.resolve_type(image)
1425        {
1426            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1427            self.put_expression(image, context, true)?;
1428            write!(self.out, ", ")?;
1429            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1430            write!(self.out, ")")?;
1431            return Ok(());
1432        }
1433
1434        match context.policies.image_load {
1435            proc::BoundsCheckPolicy::Restrict => {
1436                // Use the cached restricted level of detail, if any. Omit the
1437                // level altogether for 1D textures.
1438                if address.level.is_some() {
1439                    address.level = if context.image_needs_lod(image) {
1440                        Some(LevelOfDetail::Restricted(load))
1441                    } else {
1442                        None
1443                    }
1444                }
1445
1446                self.put_expression(image, context, false)?;
1447                write!(self.out, ".read(")?;
1448                self.put_restricted_texel_address(image, &address, context)?;
1449                write!(self.out, ")")?;
1450            }
1451            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1452                write!(self.out, "(")?;
1453                self.put_image_access_bounds_check(image, &address, context)?;
1454                write!(self.out, " ? ")?;
1455                self.put_unchecked_image_load(image, &address, context)?;
1456                write!(self.out, ": DefaultConstructible())")?;
1457            }
1458            proc::BoundsCheckPolicy::Unchecked => {
1459                self.put_unchecked_image_load(image, &address, context)?;
1460            }
1461        }
1462
1463        Ok(())
1464    }
1465
1466    fn put_unchecked_image_load(
1467        &mut self,
1468        image: Handle<crate::Expression>,
1469        address: &TexelAddress,
1470        context: &ExpressionContext,
1471    ) -> BackendResult {
1472        self.put_expression(image, context, false)?;
1473        write!(self.out, ".read(")?;
1474        // coordinates in IR are int, but Metal expects uint
1475        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1476        if let Some(expr) = address.array_index {
1477            write!(self.out, ", ")?;
1478            self.put_expression(expr, context, true)?;
1479        }
1480        if let Some(sample) = address.sample {
1481            write!(self.out, ", ")?;
1482            self.put_expression(sample, context, true)?;
1483        }
1484        if let Some(level) = address.level {
1485            if context.image_needs_lod(image) {
1486                write!(self.out, ", ")?;
1487                self.put_level_of_detail(level, context)?;
1488            }
1489        }
1490        write!(self.out, ")")?;
1491
1492        Ok(())
1493    }
1494
1495    fn put_image_atomic(
1496        &mut self,
1497        level: back::Level,
1498        image: Handle<crate::Expression>,
1499        address: &TexelAddress,
1500        fun: crate::AtomicFunction,
1501        value: Handle<crate::Expression>,
1502        context: &StatementContext,
1503    ) -> BackendResult {
1504        write!(self.out, "{level}")?;
1505        self.put_expression(image, &context.expression, false)?;
1506        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1507            fun.to_msl_64_bit()?
1508        } else {
1509            fun.to_msl()
1510        };
1511        write!(self.out, ".atomic_{op}(")?;
1512        // coordinates in IR are int, but Metal expects uint
1513        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1514        write!(self.out, ", ")?;
1515        self.put_expression(value, &context.expression, true)?;
1516        writeln!(self.out, ");")?;
1517
1518        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
1519        // texture writes randomly drop unless followed by a standard texture
1520        // write. Insert a dead-code write behind an unprovable condition so
1521        // the compiler emits proper memory safety barriers.
1522        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
1523        let value_ty = context.expression.resolve_type(value);
1524        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1525            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1526            (_, Some(8)) => "ulong4(0uL)",
1527            _ => "uint4(0u)",
1528        };
1529        let coord_ty = context.expression.resolve_type(address.coordinate);
1530        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1531            ""
1532        } else {
1533            ".x"
1534        };
1535        write!(self.out, "{level}if (")?;
1536        self.put_expression(address.coordinate, &context.expression, true)?;
1537        write!(self.out, "{x} == -99999) {{ ")?;
1538        self.put_expression(image, &context.expression, false)?;
1539        write!(self.out, ".write({zero_value}, ")?;
1540        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1541        if let Some(array_index) = address.array_index {
1542            write!(self.out, ", ")?;
1543            self.put_expression(array_index, &context.expression, true)?;
1544        }
1545        writeln!(self.out, "); }}")?;
1546
1547        Ok(())
1548    }
1549
1550    fn put_image_store(
1551        &mut self,
1552        level: back::Level,
1553        image: Handle<crate::Expression>,
1554        address: &TexelAddress,
1555        value: Handle<crate::Expression>,
1556        context: &StatementContext,
1557    ) -> BackendResult {
1558        write!(self.out, "{level}")?;
1559        self.put_expression(image, &context.expression, false)?;
1560        write!(self.out, ".write(")?;
1561        self.put_expression(value, &context.expression, true)?;
1562        write!(self.out, ", ")?;
1563        // coordinates in IR are int, but Metal expects uint
1564        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1565        if let Some(expr) = address.array_index {
1566            write!(self.out, ", ")?;
1567            self.put_expression(expr, &context.expression, true)?;
1568        }
1569        writeln!(self.out, ");")?;
1570
1571        Ok(())
1572    }
1573
1574    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1575    ///
1576    /// The 'maximum valid index' is simply one less than the array's length.
1577    ///
1578    /// This emits an expression of the form `a / b`, so the caller must
1579    /// parenthesize its output if it will be applying operators of higher
1580    /// precedence.
1581    ///
1582    /// `handle` must be the handle of a global variable whose final member is a
1583    /// dynamically sized array.
1584    fn put_dynamic_array_max_index(
1585        &mut self,
1586        handle: Handle<crate::GlobalVariable>,
1587        context: &ExpressionContext,
1588    ) -> BackendResult {
1589        let global = &context.module.global_variables[handle];
1590        let (offset, array_ty) = match context.module.types[global.ty].inner {
1591            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1592                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1593                None => return Err(Error::GenericValidation("Struct has no members".into())),
1594            },
1595            crate::TypeInner::Array {
1596                size: crate::ArraySize::Dynamic,
1597                ..
1598            } => (0, global.ty),
1599            ref ty => {
1600                return Err(Error::GenericValidation(format!(
1601                    "Expected type with dynamic array, got {ty:?}"
1602                )))
1603            }
1604        };
1605
1606        let (size, stride) = match context.module.types[array_ty].inner {
1607            crate::TypeInner::Array { base, stride, .. } => (
1608                context.module.types[base]
1609                    .inner
1610                    .size(context.module.to_ctx()),
1611                stride,
1612            ),
1613            ref ty => {
1614                return Err(Error::GenericValidation(format!(
1615                    "Expected array type, got {ty:?}"
1616                )))
1617            }
1618        };
1619
1620        // When the stride length is larger than the size, the final element's stride of
1621        // bytes would have padding following the value. But the buffer size in
1622        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1623        // enough to hold the actual values' bytes.
1624        //
1625        // So subtract off the size to get a byte size that falls at the start or within
1626        // the final element. Then divide by the stride size, to get one less than the
1627        // length, and then add one. This works even if the buffer size does include the
1628        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1629        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1630        // rules, together with draw-time validation when `minBindingSize` is zero,
1631        // prevent that.
1632        write!(
1633            self.out,
1634            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1635            member = ArraySizeMember(handle),
1636            offset = offset,
1637            size = size,
1638            stride = stride,
1639        )?;
1640        Ok(())
1641    }
1642
1643    /// Emit code for the arithmetic expression of the dot product.
1644    ///
1645    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1646    /// an index. It writes out the expression for the vector component at that index.
1647    fn put_dot_product<T: Copy>(
1648        &mut self,
1649        arg: T,
1650        arg1: T,
1651        size: usize,
1652        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1653    ) -> BackendResult {
1654        // Write parentheses around the dot product expression to prevent operators
1655        // with different precedences from applying earlier.
1656        write!(self.out, "(")?;
1657
1658        // Cycle through all the components of the vector
1659        for index in 0..size {
1660            // Write the addition to the previous product
1661            // This will print an extra '+' at the beginning but that is fine in msl
1662            write!(self.out, " + ")?;
1663            extractor(self, arg, index)?;
1664            write!(self.out, " * ")?;
1665            extractor(self, arg1, index)?;
1666        }
1667
1668        write!(self.out, ")")?;
1669        Ok(())
1670    }
1671
1672    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1673    fn put_pack4x8(
1674        &mut self,
1675        arg: Handle<crate::Expression>,
1676        context: &ExpressionContext<'_>,
1677        was_signed: bool,
1678        clamp_bounds: Option<(&str, &str)>,
1679    ) -> Result<(), Error> {
1680        let write_arg = |this: &mut Self| -> BackendResult {
1681            if let Some((min, max)) = clamp_bounds {
1682                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1683                write!(this.out, "{NAMESPACE}::clamp(")?;
1684                this.put_expression(arg, context, true)?;
1685                write!(this.out, ", {min}, {max})")?;
1686            } else {
1687                this.put_expression(arg, context, true)?;
1688            }
1689            Ok(())
1690        };
1691
1692        if context.lang_version >= (2, 1) {
1693            let packed_type = if was_signed {
1694                "packed_char4"
1695            } else {
1696                "packed_uchar4"
1697            };
1698            // Metal uses little endian byte order, which matches what WGSL expects here.
1699            write!(self.out, "as_type<uint>({packed_type}(")?;
1700            write_arg(self)?;
1701            write!(self.out, "))")?;
1702        } else {
1703            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1704            if was_signed {
1705                write!(self.out, "uint(")?;
1706            }
1707            write!(self.out, "(")?;
1708            write_arg(self)?;
1709            write!(self.out, "[0] & 0xFF) | ((")?;
1710            write_arg(self)?;
1711            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1712            write_arg(self)?;
1713            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1714            write_arg(self)?;
1715            write!(self.out, "[3] & 0xFF) << 24)")?;
1716            if was_signed {
1717                write!(self.out, ")")?;
1718            }
1719        }
1720
1721        Ok(())
1722    }
1723
1724    /// Emit code for the isign expression.
1725    ///
1726    fn put_isign(
1727        &mut self,
1728        arg: Handle<crate::Expression>,
1729        context: &ExpressionContext,
1730    ) -> BackendResult {
1731        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1732        let scalar = context
1733            .resolve_type(arg)
1734            .scalar()
1735            .expect("put_isign should only be called for args which have an integer scalar type")
1736            .to_msl_name();
1737        match context.resolve_type(arg) {
1738            &crate::TypeInner::Vector { size, .. } => {
1739                let size = common::vector_size_str(size);
1740                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1741            }
1742            _ => {
1743                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1744            }
1745        }
1746        write!(self.out, ", (")?;
1747        self.put_expression(arg, context, true)?;
1748        write!(self.out, " > 0)), {scalar}(0), (")?;
1749        self.put_expression(arg, context, true)?;
1750        write!(self.out, " == 0))")?;
1751        Ok(())
1752    }
1753
1754    pub(super) fn put_const_expression(
1755        &mut self,
1756        expr_handle: Handle<crate::Expression>,
1757        module: &crate::Module,
1758        mod_info: &valid::ModuleInfo,
1759        arena: &crate::Arena<crate::Expression>,
1760    ) -> BackendResult {
1761        self.put_possibly_const_expression(
1762            expr_handle,
1763            arena,
1764            module,
1765            mod_info,
1766            &(module, mod_info),
1767            |&(_, mod_info), expr| &mod_info[expr],
1768            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1769        )
1770    }
1771
1772    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1773        match literal {
1774            crate::Literal::F64(_) => {
1775                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1776            }
1777            crate::Literal::F16(value) => {
1778                if value.is_infinite() {
1779                    let sign = if value.is_sign_negative() { "-" } else { "" };
1780                    write!(self.out, "{sign}INFINITY")?;
1781                } else if value.is_nan() {
1782                    write!(self.out, "NAN")?;
1783                } else {
1784                    let suffix = if value.fract() == f16::from_f32(0.0) {
1785                        ".0h"
1786                    } else {
1787                        "h"
1788                    };
1789                    write!(self.out, "{value}{suffix}")?;
1790                }
1791            }
1792            crate::Literal::F32(value) => {
1793                if value.is_infinite() {
1794                    let sign = if value.is_sign_negative() { "-" } else { "" };
1795                    write!(self.out, "{sign}INFINITY")?;
1796                } else if value.is_nan() {
1797                    write!(self.out, "NAN")?;
1798                } else {
1799                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1800                    write!(self.out, "{value}{suffix}")?;
1801                }
1802            }
1803            crate::Literal::U32(value) => {
1804                write!(self.out, "{value}u")?;
1805            }
1806            crate::Literal::I32(value) => {
1807                // `-2147483648` is parsed as unary negation of positive 2147483648.
1808                // 2147483648 is too large for int32_t meaning the expression gets
1809                // promoted to a int64_t which is not our intention. Avoid this by instead
1810                // using `-2147483647 - 1`.
1811                if value == i32::MIN {
1812                    write!(self.out, "({} - 1)", value + 1)?;
1813                } else {
1814                    write!(self.out, "{value}")?;
1815                }
1816            }
1817            crate::Literal::U64(value) => {
1818                write!(self.out, "{value}uL")?;
1819            }
1820            crate::Literal::I64(value) => {
1821                // `-9223372036854775808` is parsed as unary negation of positive
1822                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1823                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1824                // value to `-9223372036854775808`. Which would then be negated, possibly
1825                // causing undefined behaviour. Avoid this by instead using
1826                // `-9223372036854775808L - 1L`.
1827                if value == i64::MIN {
1828                    write!(self.out, "({}L - 1L)", value + 1)?;
1829                } else {
1830                    write!(self.out, "{value}L")?;
1831                }
1832            }
1833            crate::Literal::Bool(value) => {
1834                write!(self.out, "{value}")?;
1835            }
1836            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1837                return Err(Error::GenericValidation(
1838                    "Unsupported abstract literal".into(),
1839                ));
1840            }
1841        }
1842        Ok(())
1843    }
1844
1845    #[allow(clippy::too_many_arguments)]
1846    fn put_possibly_const_expression<C, I, E>(
1847        &mut self,
1848        expr_handle: Handle<crate::Expression>,
1849        expressions: &crate::Arena<crate::Expression>,
1850        module: &crate::Module,
1851        mod_info: &valid::ModuleInfo,
1852        ctx: &C,
1853        get_expr_ty: I,
1854        put_expression: E,
1855    ) -> BackendResult
1856    where
1857        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1858        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1859    {
1860        match expressions[expr_handle] {
1861            crate::Expression::Literal(literal) => {
1862                self.put_literal(literal)?;
1863            }
1864            crate::Expression::Constant(handle) => {
1865                let constant = &module.constants[handle];
1866                if constant.name.is_some() {
1867                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1868                } else {
1869                    self.put_const_expression(
1870                        constant.init,
1871                        module,
1872                        mod_info,
1873                        &module.global_expressions,
1874                    )?;
1875                }
1876            }
1877            crate::Expression::ZeroValue(ty) => {
1878                let ty_name = TypeContext {
1879                    handle: ty,
1880                    gctx: module.to_ctx(),
1881                    names: &self.names,
1882                    access: crate::StorageAccess::empty(),
1883                    first_time: false,
1884                };
1885                write!(self.out, "{ty_name} {{}}")?;
1886            }
1887            crate::Expression::Compose { ty, ref components } => {
1888                let ty_name = TypeContext {
1889                    handle: ty,
1890                    gctx: module.to_ctx(),
1891                    names: &self.names,
1892                    access: crate::StorageAccess::empty(),
1893                    first_time: false,
1894                };
1895                write!(self.out, "{ty_name}")?;
1896                match module.types[ty].inner {
1897                    crate::TypeInner::Scalar(_)
1898                    | crate::TypeInner::Vector { .. }
1899                    | crate::TypeInner::Matrix { .. } => {
1900                        self.put_call_parameters_impl(
1901                            components.iter().copied(),
1902                            ctx,
1903                            put_expression,
1904                        )?;
1905                    }
1906                    crate::TypeInner::Array { .. } => {
1907                        // Naga Arrays are Metal arrays wrapped in structs, so
1908                        // we need two levels of braces.
1909                        write!(self.out, " {{{{")?;
1910                        for (index, &component) in components.iter().enumerate() {
1911                            if index != 0 {
1912                                write!(self.out, ", ")?;
1913                            }
1914                            put_expression(self, ctx, component)?;
1915                        }
1916                        write!(self.out, "}}}}")?;
1917                    }
1918                    crate::TypeInner::Struct { .. } => {
1919                        write!(self.out, " {{")?;
1920                        for (index, &component) in components.iter().enumerate() {
1921                            if index != 0 {
1922                                write!(self.out, ", ")?;
1923                            }
1924                            // insert padding initialization, if needed
1925                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1926                                write!(self.out, "{{}}, ")?;
1927                            }
1928                            put_expression(self, ctx, component)?;
1929                        }
1930                        write!(self.out, "}}")?;
1931                    }
1932                    _ => return Err(Error::UnsupportedCompose(ty)),
1933                }
1934            }
1935            crate::Expression::Splat { size, value } => {
1936                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1937                    crate::TypeInner::Scalar(scalar) => scalar,
1938                    ref ty => {
1939                        return Err(Error::GenericValidation(format!(
1940                            "Expected splat value type must be a scalar, got {ty:?}",
1941                        )))
1942                    }
1943                };
1944                put_numeric_type(&mut self.out, scalar, &[size])?;
1945                write!(self.out, "(")?;
1946                put_expression(self, ctx, value)?;
1947                write!(self.out, ")")?;
1948            }
1949            _ => {
1950                return Err(Error::Override);
1951            }
1952        }
1953
1954        Ok(())
1955    }
1956
1957    /// Emit code for the expression `expr_handle`.
1958    ///
1959    /// The `is_scoped` argument is true if the surrounding operators have the
1960    /// precedence of the comma operator, or lower. So, for example:
1961    ///
1962    /// - Pass `true` for `is_scoped` when writing function arguments, an
1963    ///   expression statement, an initializer expression, or anything already
1964    ///   wrapped in parenthesis.
1965    ///
1966    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1967    ///   almost anything else.
1968    pub(super) fn put_expression(
1969        &mut self,
1970        expr_handle: Handle<crate::Expression>,
1971        context: &ExpressionContext,
1972        is_scoped: bool,
1973    ) -> BackendResult {
1974        // Add to the set in order to track the stack size.
1975        #[cfg(test)]
1976        self.put_expression_stack_pointers
1977            .insert(ptr::from_ref(&expr_handle).cast());
1978
1979        if let Some(name) = self.named_expressions.get(&expr_handle) {
1980            write!(self.out, "{name}")?;
1981            return Ok(());
1982        }
1983
1984        let expression = &context.function.expressions[expr_handle];
1985        match *expression {
1986            crate::Expression::Literal(_)
1987            | crate::Expression::Constant(_)
1988            | crate::Expression::ZeroValue(_)
1989            | crate::Expression::Compose { .. }
1990            | crate::Expression::Splat { .. } => {
1991                self.put_possibly_const_expression(
1992                    expr_handle,
1993                    &context.function.expressions,
1994                    context.module,
1995                    context.mod_info,
1996                    context,
1997                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1998                    |writer, context, expr| writer.put_expression(expr, context, true),
1999                )?;
2000            }
2001            crate::Expression::Override(_) => return Err(Error::Override),
2002            crate::Expression::Access { base, .. }
2003            | crate::Expression::AccessIndex { base, .. } => {
2004                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
2005                // Since `put_bounds_checks` and `put_access_chain` handle an entire
2006                // access chain at a time, recursing back through `put_expression` only
2007                // for index expressions and the base object, we will never see intermediate
2008                // `Access` or `AccessIndex` expressions here.
2009                let policy = context.choose_bounds_check_policy(base);
2010                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2011                    && self.put_bounds_checks(
2012                        expr_handle,
2013                        context,
2014                        back::Level(0),
2015                        if is_scoped { "" } else { "(" },
2016                    )?
2017                {
2018                    write!(self.out, " ? ")?;
2019                    self.put_access_chain(expr_handle, policy, context)?;
2020                    write!(self.out, " : ")?;
2021
2022                    if context.resolve_type(base).pointer_space().is_some() {
2023                        // We can't just use `DefaultConstructible` if this is a pointer.
2024                        // Instead, we create a dummy local variable to serve as pointer
2025                        // target if the access is out of bounds.
2026                        let result_ty = context.info[expr_handle]
2027                            .ty
2028                            .inner_with(&context.module.types)
2029                            .pointer_base_type();
2030                        let result_ty_handle = match result_ty {
2031                            Some(TypeResolution::Handle(handle)) => handle,
2032                            Some(TypeResolution::Value(_)) => {
2033                                // As long as the result of a pointer access expression is
2034                                // passed to a function or stored in a let binding, the
2035                                // type will be in the arena. If additional uses of
2036                                // pointers become valid, this assumption might no longer
2037                                // hold. Note that the LHS of a load or store doesn't
2038                                // take this path -- there is dedicated code in `put_load`
2039                                // and `put_store`.
2040                                unreachable!(
2041                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2042                                );
2043                            }
2044                            None => {
2045                                unreachable!(
2046                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2047                                )
2048                            }
2049                        };
2050                        let name_key =
2051                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
2052                        self.out.write_str(&self.names[&name_key])?;
2053                    } else {
2054                        write!(self.out, "DefaultConstructible()")?;
2055                    }
2056
2057                    if !is_scoped {
2058                        write!(self.out, ")")?;
2059                    }
2060                } else {
2061                    self.put_access_chain(expr_handle, policy, context)?;
2062                }
2063            }
2064            crate::Expression::Swizzle {
2065                size,
2066                vector,
2067                pattern,
2068            } => {
2069                self.put_wrapped_expression_for_packed_vec3_access(
2070                    vector,
2071                    context,
2072                    false,
2073                    &Self::put_expression,
2074                )?;
2075                write!(self.out, ".")?;
2076                for &sc in pattern[..size as usize].iter() {
2077                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2078                }
2079            }
2080            crate::Expression::FunctionArgument(index) => {
2081                let name_key = match context.origin {
2082                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2083                    FunctionOrigin::EntryPoint(ep_index) => {
2084                        NameKey::EntryPointArgument(ep_index, index)
2085                    }
2086                };
2087                let name = &self.names[&name_key];
2088                write!(self.out, "{name}")?;
2089            }
2090            crate::Expression::GlobalVariable(handle) => {
2091                let name = &self.names[&NameKey::GlobalVariable(handle)];
2092                write!(self.out, "{name}")?;
2093            }
2094            crate::Expression::LocalVariable(handle) => {
2095                let name_key = NameKey::local(context.origin, handle);
2096                let name = &self.names[&name_key];
2097                write!(self.out, "{name}")?;
2098            }
2099            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2100            crate::Expression::ImageSample {
2101                coordinate,
2102                image,
2103                sampler,
2104                clamp_to_edge: true,
2105                gather: None,
2106                array_index: None,
2107                offset: None,
2108                level: crate::SampleLevel::Zero,
2109                depth_ref: None,
2110            } => {
2111                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2112                self.put_expression(image, context, true)?;
2113                write!(self.out, ", ")?;
2114                self.put_expression(sampler, context, true)?;
2115                write!(self.out, ", ")?;
2116                self.put_expression(coordinate, context, true)?;
2117                write!(self.out, ")")?;
2118            }
2119            crate::Expression::ImageSample {
2120                image,
2121                sampler,
2122                gather,
2123                coordinate,
2124                array_index,
2125                offset,
2126                level,
2127                depth_ref,
2128                clamp_to_edge,
2129            } => {
2130                if clamp_to_edge {
2131                    return Err(Error::GenericValidation(
2132                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2133                    ));
2134                }
2135
2136                let main_op = match gather {
2137                    Some(_) => "gather",
2138                    None => "sample",
2139                };
2140                let comparison_op = match depth_ref {
2141                    Some(_) => "_compare",
2142                    None => "",
2143                };
2144                self.put_expression(image, context, false)?;
2145                write!(self.out, ".{main_op}{comparison_op}(")?;
2146                self.put_expression(sampler, context, true)?;
2147                write!(self.out, ", ")?;
2148                self.put_expression(coordinate, context, true)?;
2149                if let Some(expr) = array_index {
2150                    write!(self.out, ", ")?;
2151                    self.put_expression(expr, context, true)?;
2152                }
2153                if let Some(dref) = depth_ref {
2154                    write!(self.out, ", ")?;
2155                    self.put_expression(dref, context, true)?;
2156                }
2157
2158                self.put_image_sample_level(image, level, context)?;
2159
2160                if let Some(offset) = offset {
2161                    write!(self.out, ", ")?;
2162                    self.put_expression(offset, context, true)?;
2163                }
2164
2165                match gather {
2166                    None | Some(crate::SwizzleComponent::X) => {}
2167                    Some(component) => {
2168                        let is_cube_map = match *context.resolve_type(image) {
2169                            crate::TypeInner::Image {
2170                                dim: crate::ImageDimension::Cube,
2171                                ..
2172                            } => true,
2173                            _ => false,
2174                        };
2175                        // Offset always comes before the gather, except
2176                        // in cube maps where it's not applicable
2177                        if offset.is_none() && !is_cube_map {
2178                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2179                        }
2180                        let letter = back::COMPONENTS[component as usize];
2181                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2182                    }
2183                }
2184                write!(self.out, ")")?;
2185            }
2186            crate::Expression::ImageLoad {
2187                image,
2188                coordinate,
2189                array_index,
2190                sample,
2191                level,
2192            } => {
2193                let address = TexelAddress {
2194                    coordinate,
2195                    array_index,
2196                    sample,
2197                    level: level.map(LevelOfDetail::Direct),
2198                };
2199                self.put_image_load(expr_handle, image, address, context)?;
2200            }
2201            //Note: for all the queries, the signed integers are expected,
2202            // so a conversion is needed.
2203            crate::Expression::ImageQuery { image, query } => match query {
2204                crate::ImageQuery::Size { level } => {
2205                    self.put_image_size_query(
2206                        image,
2207                        level.map(LevelOfDetail::Direct),
2208                        crate::ScalarKind::Uint,
2209                        context,
2210                    )?;
2211                }
2212                crate::ImageQuery::NumLevels => {
2213                    self.put_expression(image, context, false)?;
2214                    write!(self.out, ".get_num_mip_levels()")?;
2215                }
2216                crate::ImageQuery::NumLayers => {
2217                    self.put_expression(image, context, false)?;
2218                    write!(self.out, ".get_array_size()")?;
2219                }
2220                crate::ImageQuery::NumSamples => {
2221                    self.put_expression(image, context, false)?;
2222                    write!(self.out, ".get_num_samples()")?;
2223                }
2224            },
2225            crate::Expression::Unary { op, expr } => {
2226                let op_str = match op {
2227                    crate::UnaryOperator::Negate => {
2228                        match context.resolve_type(expr).scalar_kind() {
2229                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2230                            _ => "-",
2231                        }
2232                    }
2233                    crate::UnaryOperator::LogicalNot => "!",
2234                    crate::UnaryOperator::BitwiseNot => "~",
2235                };
2236                write!(self.out, "{op_str}(")?;
2237                self.put_expression(expr, context, false)?;
2238                write!(self.out, ")")?;
2239            }
2240            crate::Expression::Binary { op, left, right } => {
2241                let kind = context
2242                    .resolve_type(left)
2243                    .scalar_kind()
2244                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2245
2246                if op == crate::BinaryOperator::Divide
2247                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2248                {
2249                    write!(self.out, "{DIV_FUNCTION}(")?;
2250                    self.put_expression(left, context, true)?;
2251                    write!(self.out, ", ")?;
2252                    self.put_expression(right, context, true)?;
2253                    write!(self.out, ")")?;
2254                } else if op == crate::BinaryOperator::Modulo
2255                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2256                {
2257                    write!(self.out, "{MOD_FUNCTION}(")?;
2258                    self.put_expression(left, context, true)?;
2259                    write!(self.out, ", ")?;
2260                    self.put_expression(right, context, true)?;
2261                    write!(self.out, ")")?;
2262                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2263                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2264                    //
2265                    // float:
2266                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2267                    write!(self.out, "{NAMESPACE}::fmod(")?;
2268                    self.put_expression(left, context, true)?;
2269                    write!(self.out, ", ")?;
2270                    self.put_expression(right, context, true)?;
2271                    write!(self.out, ")")?;
2272                } else if (op == crate::BinaryOperator::Add
2273                    || op == crate::BinaryOperator::Subtract
2274                    || op == crate::BinaryOperator::Multiply)
2275                    && kind == crate::ScalarKind::Sint
2276                {
2277                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2278                        crate::TypeInner::Scalar(scalar) => {
2279                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2280                                kind: crate::ScalarKind::Uint,
2281                                ..scalar
2282                            }))
2283                        }
2284                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2285                            size,
2286                            scalar: crate::Scalar {
2287                                kind: crate::ScalarKind::Uint,
2288                                ..scalar
2289                            },
2290                        }),
2291                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2292                    };
2293
2294                    // Avoid undefined behaviour due to overflowing signed
2295                    // integer arithmetic. Cast the operands to unsigned prior
2296                    // to performing the operation, then cast the result back
2297                    // to signed.
2298                    self.put_bitcasted_expression(
2299                        context.resolve_type(expr_handle),
2300                        expr_handle,
2301                        context,
2302                        &|writer, context, is_scoped| {
2303                            writer.put_binop(
2304                                op,
2305                                left,
2306                                right,
2307                                context,
2308                                is_scoped,
2309                                &|writer, expr, context, _is_scoped| {
2310                                    writer.put_bitcasted_expression(
2311                                        &to_unsigned(context.resolve_type(expr))?,
2312                                        expr,
2313                                        context,
2314                                        &|writer, context, is_scoped| {
2315                                            writer.put_expression(expr, context, is_scoped)
2316                                        },
2317                                    )
2318                                },
2319                            )
2320                        },
2321                    )?;
2322                } else {
2323                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2324                }
2325            }
2326            crate::Expression::Select {
2327                condition,
2328                accept,
2329                reject,
2330            } => match *context.resolve_type(condition) {
2331                crate::TypeInner::Scalar(crate::Scalar {
2332                    kind: crate::ScalarKind::Bool,
2333                    ..
2334                }) => {
2335                    if !is_scoped {
2336                        write!(self.out, "(")?;
2337                    }
2338                    self.put_expression(condition, context, false)?;
2339                    write!(self.out, " ? ")?;
2340                    self.put_expression(accept, context, false)?;
2341                    write!(self.out, " : ")?;
2342                    self.put_expression(reject, context, false)?;
2343                    if !is_scoped {
2344                        write!(self.out, ")")?;
2345                    }
2346                }
2347                crate::TypeInner::Vector {
2348                    scalar:
2349                        crate::Scalar {
2350                            kind: crate::ScalarKind::Bool,
2351                            ..
2352                        },
2353                    ..
2354                } => {
2355                    write!(self.out, "{NAMESPACE}::select(")?;
2356                    self.put_expression(reject, context, true)?;
2357                    write!(self.out, ", ")?;
2358                    self.put_expression(accept, context, true)?;
2359                    write!(self.out, ", ")?;
2360                    self.put_expression(condition, context, true)?;
2361                    write!(self.out, ")")?;
2362                }
2363                ref ty => {
2364                    return Err(Error::GenericValidation(format!(
2365                        "Expected select condition to be a non-bool type, got {ty:?}",
2366                    )))
2367                }
2368            },
2369            crate::Expression::Derivative { axis, expr, .. } => {
2370                use crate::DerivativeAxis as Axis;
2371                let op = match axis {
2372                    Axis::X => "dfdx",
2373                    Axis::Y => "dfdy",
2374                    Axis::Width => "fwidth",
2375                };
2376                write!(self.out, "{NAMESPACE}::{op}")?;
2377                self.put_call_parameters(iter::once(expr), context)?;
2378            }
2379            crate::Expression::Relational { fun, argument } => {
2380                let op = match fun {
2381                    crate::RelationalFunction::Any => "any",
2382                    crate::RelationalFunction::All => "all",
2383                    crate::RelationalFunction::IsNan => "isnan",
2384                    crate::RelationalFunction::IsInf => "isinf",
2385                };
2386                write!(self.out, "{NAMESPACE}::{op}")?;
2387                self.put_call_parameters(iter::once(argument), context)?;
2388            }
2389            crate::Expression::Math {
2390                fun,
2391                arg,
2392                arg1,
2393                arg2,
2394                arg3,
2395            } => {
2396                use crate::MathFunction as Mf;
2397
2398                let arg_type = context.resolve_type(arg);
2399                let scalar_argument = match arg_type {
2400                    &crate::TypeInner::Scalar(_) => true,
2401                    _ => false,
2402                };
2403
2404                let fun_name = match fun {
2405                    // comparison
2406                    Mf::Abs => "abs",
2407                    Mf::Min => "min",
2408                    Mf::Max => "max",
2409                    Mf::Clamp => "clamp",
2410                    Mf::Saturate => "saturate",
2411                    // trigonometry
2412                    Mf::Cos => "cos",
2413                    Mf::Cosh => "cosh",
2414                    Mf::Sin => "sin",
2415                    Mf::Sinh => "sinh",
2416                    Mf::Tan => "tan",
2417                    Mf::Tanh => "tanh",
2418                    Mf::Acos => "acos",
2419                    Mf::Asin => "asin",
2420                    Mf::Atan => "atan",
2421                    Mf::Atan2 => "atan2",
2422                    Mf::Asinh => "asinh",
2423                    Mf::Acosh => "acosh",
2424                    Mf::Atanh => "atanh",
2425                    Mf::Radians => "",
2426                    Mf::Degrees => "",
2427                    // decomposition
2428                    Mf::Ceil => "ceil",
2429                    Mf::Floor => "floor",
2430                    Mf::Round => "rint",
2431                    Mf::Fract => "fract",
2432                    Mf::Trunc => "trunc",
2433                    Mf::Modf => MODF_FUNCTION,
2434                    Mf::Frexp => FREXP_FUNCTION,
2435                    Mf::Ldexp => "ldexp",
2436                    // exponent
2437                    Mf::Exp => "exp",
2438                    Mf::Exp2 => "exp2",
2439                    Mf::Log => "log",
2440                    Mf::Log2 => "log2",
2441                    Mf::Pow => "pow",
2442                    // geometry
2443                    Mf::Dot => match *context.resolve_type(arg) {
2444                        crate::TypeInner::Vector {
2445                            scalar:
2446                                crate::Scalar {
2447                                    // Resolve float values to MSL's builtin dot function.
2448                                    kind: crate::ScalarKind::Float,
2449                                    ..
2450                                },
2451                            ..
2452                        } => "dot",
2453                        crate::TypeInner::Vector {
2454                            size,
2455                            scalar:
2456                                scalar @ crate::Scalar {
2457                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2458                                    ..
2459                                },
2460                        } => {
2461                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2462                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2463                            write!(self.out, "{fun_name}(")?;
2464                            self.put_expression(arg, context, true)?;
2465                            write!(self.out, ", ")?;
2466                            self.put_expression(arg1.unwrap(), context, true)?;
2467                            write!(self.out, ")")?;
2468                            return Ok(());
2469                        }
2470                        _ => unreachable!(
2471                            "Correct TypeInner for dot product should be already validated"
2472                        ),
2473                    },
2474                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2475                        if context.lang_version >= (2, 1) {
2476                            // Write potentially optimizable code using `packed_(u?)char4`.
2477                            // The two function arguments were already reinterpreted as packed (signed
2478                            // or unsigned) chars in `Self::put_block`.
2479                            let packed_type = match fun {
2480                                Mf::Dot4I8Packed => "packed_char4",
2481                                Mf::Dot4U8Packed => "packed_uchar4",
2482                                _ => unreachable!(),
2483                            };
2484
2485                            return self.put_dot_product(
2486                                Reinterpreted::new(packed_type, arg),
2487                                Reinterpreted::new(packed_type, arg1.unwrap()),
2488                                4,
2489                                |writer, arg, index| {
2490                                    // MSL implicitly promotes these (signed or unsigned) chars to
2491                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2492                                    write!(writer.out, "{arg}[{index}]")?;
2493                                    Ok(())
2494                                },
2495                            );
2496                        } else {
2497                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2498                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2499                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2500                            let conversion = match fun {
2501                                Mf::Dot4I8Packed => "int",
2502                                Mf::Dot4U8Packed => "",
2503                                _ => unreachable!(),
2504                            };
2505
2506                            return self.put_dot_product(
2507                                arg,
2508                                arg1.unwrap(),
2509                                4,
2510                                |writer, arg, index| {
2511                                    write!(writer.out, "({conversion}(")?;
2512                                    writer.put_expression(arg, context, true)?;
2513                                    if index == 3 {
2514                                        write!(writer.out, ") >> 24)")?;
2515                                    } else {
2516                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2517                                    }
2518                                    Ok(())
2519                                },
2520                            );
2521                        }
2522                    }
2523                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2524                    Mf::Cross => "cross",
2525                    Mf::Distance => "distance",
2526                    Mf::Length if scalar_argument => "abs",
2527                    Mf::Length => "length",
2528                    Mf::Normalize => "normalize",
2529                    Mf::FaceForward => "faceforward",
2530                    Mf::Reflect => "reflect",
2531                    Mf::Refract => "refract",
2532                    // computational
2533                    Mf::Sign => match arg_type.scalar_kind() {
2534                        Some(crate::ScalarKind::Sint) => {
2535                            return self.put_isign(arg, context);
2536                        }
2537                        _ => "sign",
2538                    },
2539                    Mf::Fma => "fma",
2540                    Mf::Mix => "mix",
2541                    Mf::Step => "step",
2542                    Mf::SmoothStep => "smoothstep",
2543                    Mf::Sqrt => "sqrt",
2544                    Mf::InverseSqrt => "rsqrt",
2545                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2546                    Mf::Transpose => "transpose",
2547                    Mf::Determinant => "determinant",
2548                    Mf::QuantizeToF16 => "",
2549                    // bits
2550                    Mf::CountTrailingZeros => "ctz",
2551                    Mf::CountLeadingZeros => "clz",
2552                    Mf::CountOneBits => "popcount",
2553                    Mf::ReverseBits => "reverse_bits",
2554                    Mf::ExtractBits => "",
2555                    Mf::InsertBits => "",
2556                    Mf::FirstTrailingBit => "",
2557                    Mf::FirstLeadingBit => "",
2558                    // data packing
2559                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2560                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2561                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2562                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2563                    Mf::Pack2x16float => "",
2564                    Mf::Pack4xI8 => "",
2565                    Mf::Pack4xU8 => "",
2566                    Mf::Pack4xI8Clamp => "",
2567                    Mf::Pack4xU8Clamp => "",
2568                    // data unpacking
2569                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2570                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2571                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2572                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2573                    Mf::Unpack2x16float => "",
2574                    Mf::Unpack4xI8 => "",
2575                    Mf::Unpack4xU8 => "",
2576                };
2577
2578                match fun {
2579                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2580                        // reverse_bits is listed as requiring MSL 2.1 but that
2581                        // is a copy/paste error. Looking at previous snapshots
2582                        // on web.archive.org it's present in MSL 1.2.
2583                        //
2584                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2585                        // also talks about MSL 1.2 adding "New integer
2586                        // functions to extract, insert, and reverse bits, as
2587                        // described in Integer Functions."
2588                        if context.lang_version < (1, 2) {
2589                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2590                        }
2591                    }
2592                    _ => {}
2593                }
2594
2595                match fun {
2596                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2597                        write!(self.out, "{ABS_FUNCTION}(")?;
2598                        self.put_expression(arg, context, true)?;
2599                        write!(self.out, ")")?;
2600                    }
2601                    Mf::Distance if scalar_argument => {
2602                        write!(self.out, "{NAMESPACE}::abs(")?;
2603                        self.put_expression(arg, context, false)?;
2604                        write!(self.out, " - ")?;
2605                        self.put_expression(arg1.unwrap(), context, false)?;
2606                        write!(self.out, ")")?;
2607                    }
2608                    Mf::FirstTrailingBit => {
2609                        let scalar = context.resolve_type(arg).scalar().unwrap();
2610                        let constant = scalar.width * 8 + 1;
2611
2612                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2613                        self.put_expression(arg, context, true)?;
2614                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2615                    }
2616                    Mf::FirstLeadingBit => {
2617                        let inner = context.resolve_type(arg);
2618                        let scalar = inner.scalar().unwrap();
2619                        let constant = scalar.width * 8 - 1;
2620
2621                        write!(
2622                            self.out,
2623                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2624                        )?;
2625
2626                        if scalar.kind == crate::ScalarKind::Sint {
2627                            write!(self.out, "{NAMESPACE}::select(")?;
2628                            self.put_expression(arg, context, true)?;
2629                            write!(self.out, ", ~")?;
2630                            self.put_expression(arg, context, true)?;
2631                            write!(self.out, ", ")?;
2632                            self.put_expression(arg, context, true)?;
2633                            write!(self.out, " < 0)")?;
2634                        } else {
2635                            self.put_expression(arg, context, true)?;
2636                        }
2637
2638                        write!(self.out, "), ")?;
2639
2640                        // Give `-1` an explicit type, or Metal will complain that `select` is ambiguous.
2641                        match *inner {
2642                            crate::TypeInner::Vector { size, scalar } => {
2643                                let size = common::vector_size_str(size);
2644                                let name = scalar.to_msl_name();
2645                                write!(self.out, "{name}{size}")?;
2646                            }
2647                            crate::TypeInner::Scalar(scalar) => {
2648                                let name = scalar.to_msl_name();
2649                                write!(self.out, "{name}")?;
2650                            }
2651                            _ => (),
2652                        }
2653
2654                        write!(self.out, "(-1), ")?;
2655                        self.put_expression(arg, context, true)?;
2656                        write!(self.out, " == 0")?;
2657                        if scalar.kind == crate::ScalarKind::Sint {
2658                            write!(self.out, " || ")?;
2659                            self.put_expression(arg, context, true)?;
2660                            write!(self.out, " == -1")?;
2661                        }
2662                        write!(self.out, ")")?;
2663                    }
2664                    Mf::Unpack2x16float => {
2665                        write!(self.out, "float2(as_type<half2>(")?;
2666                        self.put_expression(arg, context, false)?;
2667                        write!(self.out, "))")?;
2668                    }
2669                    Mf::Pack2x16float => {
2670                        write!(self.out, "as_type<uint>(half2(")?;
2671                        self.put_expression(arg, context, false)?;
2672                        write!(self.out, "))")?;
2673                    }
2674                    Mf::ExtractBits => {
2675                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2676                        // to sanitize the offset and count first. If we don't do this, Apple chips
2677                        // will return out-of-spec values if the extracted range is not within the bit width.
2678                        //
2679                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2680                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2681                        //
2682                        // w = sizeof(x) * 8
2683                        // o = min(offset, w)
2684                        // tmp = w - o
2685                        // c = min(count, tmp)
2686                        //
2687                        // bitfieldExtract(x, o, c)
2688                        //
2689                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))
2690
2691                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2692
2693                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2694                        self.put_expression(arg, context, true)?;
2695                        write!(self.out, ", {NAMESPACE}::min(")?;
2696                        self.put_expression(arg1.unwrap(), context, true)?;
2697                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2698                        self.put_expression(arg2.unwrap(), context, true)?;
2699                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2700                        self.put_expression(arg1.unwrap(), context, true)?;
2701                        write!(self.out, ", {scalar_bits}u)))")?;
2702                    }
2703                    Mf::InsertBits => {
2704                        // The behavior of InsertBits has the same issue as ExtractBits.
2705                        //
2706                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))
2707
2708                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2709
2710                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2711                        self.put_expression(arg, context, true)?;
2712                        write!(self.out, ", ")?;
2713                        self.put_expression(arg1.unwrap(), context, true)?;
2714                        write!(self.out, ", {NAMESPACE}::min(")?;
2715                        self.put_expression(arg2.unwrap(), context, true)?;
2716                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2717                        self.put_expression(arg3.unwrap(), context, true)?;
2718                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2719                        self.put_expression(arg2.unwrap(), context, true)?;
2720                        write!(self.out, ", {scalar_bits}u)))")?;
2721                    }
2722                    Mf::Radians => {
2723                        write!(self.out, "((")?;
2724                        self.put_expression(arg, context, false)?;
2725                        write!(self.out, ") * 0.017453292519943295474)")?;
2726                    }
2727                    Mf::Degrees => {
2728                        write!(self.out, "((")?;
2729                        self.put_expression(arg, context, false)?;
2730                        write!(self.out, ") * 57.295779513082322865)")?;
2731                    }
2732                    Mf::Modf | Mf::Frexp => {
2733                        write!(self.out, "{fun_name}")?;
2734                        self.put_call_parameters(iter::once(arg), context)?;
2735                    }
2736                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2737                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2738                    Mf::Pack4xI8Clamp => {
2739                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2740                    }
2741                    Mf::Pack4xU8Clamp => {
2742                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2743                    }
2744                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2745                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2746                            "u"
2747                        } else {
2748                            ""
2749                        };
2750
2751                        if context.lang_version >= (2, 1) {
2752                            // Metal uses little endian byte order, which matches what WGSL expects here.
2753                            write!(
2754                                self.out,
2755                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2756                            )?;
2757                            self.put_expression(arg, context, true)?;
2758                            write!(self.out, "))")?;
2759                        } else {
2760                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2761                            write!(self.out, "({sign_prefix}int4(")?;
2762                            self.put_expression(arg, context, true)?;
2763                            write!(self.out, ", ")?;
2764                            self.put_expression(arg, context, true)?;
2765                            write!(self.out, " >> 8, ")?;
2766                            self.put_expression(arg, context, true)?;
2767                            write!(self.out, " >> 16, ")?;
2768                            self.put_expression(arg, context, true)?;
2769                            write!(self.out, " >> 24) << 24 >> 24)")?;
2770                        }
2771                    }
2772                    Mf::QuantizeToF16 => {
2773                        match *context.resolve_type(arg) {
2774                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2775                            crate::TypeInner::Vector { size, .. } => write!(
2776                                self.out,
2777                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2778                                size = common::vector_size_str(size),
2779                            )?,
2780                            _ => unreachable!(
2781                                "Correct TypeInner for QuantizeToF16 should be already validated"
2782                            ),
2783                        };
2784
2785                        self.put_expression(arg, context, true)?;
2786                        write!(self.out, "))")?;
2787                    }
2788                    _ => {
2789                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2790                        self.put_call_parameters(
2791                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2792                            context,
2793                        )?;
2794                    }
2795                }
2796            }
2797            crate::Expression::As {
2798                expr,
2799                kind,
2800                convert,
2801            } => match *context.resolve_type(expr) {
2802                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2803                    if src.kind == crate::ScalarKind::Float
2804                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2805                        && convert.is_some()
2806                    {
2807                        // Use helper functions for float to int casts in order to avoid
2808                        // undefined behaviour when value is out of range for the target
2809                        // type.
2810                        let fun_name = match (kind, convert) {
2811                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2812                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2813                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2814                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2815                            _ => unreachable!(),
2816                        };
2817                        write!(self.out, "{fun_name}(")?;
2818                        self.put_expression(expr, context, true)?;
2819                        write!(self.out, ")")?;
2820                    } else {
2821                        let target_scalar = crate::Scalar {
2822                            kind,
2823                            width: convert.unwrap_or(src.width),
2824                        };
2825                        let op = match convert {
2826                            Some(_) => "static_cast",
2827                            None => "as_type",
2828                        };
2829                        write!(self.out, "{op}<")?;
2830                        match *context.resolve_type(expr) {
2831                            crate::TypeInner::Vector { size, .. } => {
2832                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2833                            }
2834                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2835                        };
2836                        write!(self.out, ">(")?;
2837                        self.put_expression(expr, context, true)?;
2838                        write!(self.out, ")")?;
2839                    }
2840                }
2841                crate::TypeInner::Matrix {
2842                    columns,
2843                    rows,
2844                    scalar,
2845                } => {
2846                    let target_scalar = crate::Scalar {
2847                        kind,
2848                        width: convert.unwrap_or(scalar.width),
2849                    };
2850                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2851                    write!(self.out, "(")?;
2852                    self.put_expression(expr, context, true)?;
2853                    write!(self.out, ")")?;
2854                }
2855                ref ty => {
2856                    return Err(Error::GenericValidation(format!(
2857                        "Unsupported type for As: {ty:?}"
2858                    )))
2859                }
2860            },
2861            // has to be a named expression
2862            crate::Expression::CallResult(_)
2863            | crate::Expression::AtomicResult { .. }
2864            | crate::Expression::WorkGroupUniformLoadResult { .. }
2865            | crate::Expression::SubgroupBallotResult
2866            | crate::Expression::SubgroupOperationResult { .. }
2867            | crate::Expression::RayQueryProceedResult => {
2868                unreachable!()
2869            }
2870            crate::Expression::ArrayLength(expr) => {
2871                // Find the global to which the array belongs.
2872                let global = match context.function.expressions[expr] {
2873                    crate::Expression::AccessIndex { base, .. } => {
2874                        match context.function.expressions[base] {
2875                            crate::Expression::GlobalVariable(handle) => handle,
2876                            ref ex => {
2877                                return Err(Error::GenericValidation(format!(
2878                                    "Expected global variable in AccessIndex, got {ex:?}"
2879                                )))
2880                            }
2881                        }
2882                    }
2883                    crate::Expression::GlobalVariable(handle) => handle,
2884                    ref ex => {
2885                        return Err(Error::GenericValidation(format!(
2886                            "Unexpected expression in ArrayLength, got {ex:?}"
2887                        )))
2888                    }
2889                };
2890
2891                if !is_scoped {
2892                    write!(self.out, "(")?;
2893                }
2894                write!(self.out, "1 + ")?;
2895                self.put_dynamic_array_max_index(global, context)?;
2896                if !is_scoped {
2897                    write!(self.out, ")")?;
2898                }
2899            }
2900            crate::Expression::RayQueryVertexPositions { .. } => {
2901                unimplemented!()
2902            }
2903            crate::Expression::RayQueryGetIntersection {
2904                query,
2905                committed: _,
2906            } => {
2907                if context.lang_version < (2, 4) {
2908                    return Err(Error::UnsupportedRayTracing);
2909                }
2910
2911                let ty = context.module.special_types.ray_intersection.unwrap();
2912                let type_name = &self.names[&NameKey::Type(ty)];
2913                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2914                self.put_expression(query, context, true)?;
2915                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2916                let fields = [
2917                    "distance",
2918                    "user_instance_id", // req Metal 2.4
2919                    "instance_id",
2920                    "", // SBT offset
2921                    "geometry_id",
2922                    "primitive_id",
2923                    "triangle_barycentric_coord",
2924                    "triangle_front_facing",
2925                    "",                          // padding
2926                    "object_to_world_transform", // req Metal 2.4
2927                    "world_to_object_transform", // req Metal 2.4
2928                ];
2929                for field in fields {
2930                    write!(self.out, ", ")?;
2931                    if field.is_empty() {
2932                        write!(self.out, "{{}}")?;
2933                    } else {
2934                        self.put_expression(query, context, true)?;
2935                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2936                    }
2937                }
2938                write!(self.out, "}}")?;
2939            }
2940            crate::Expression::CooperativeLoad { ref data, .. } => {
2941                if context.lang_version < (2, 3) {
2942                    return Err(Error::UnsupportedCooperativeMatrix);
2943                }
2944                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2945                write!(self.out, "&")?;
2946                self.put_access_chain(data.pointer, context.policies.index, context)?;
2947                write!(self.out, ", ")?;
2948                self.put_expression(data.stride, context, true)?;
2949                write!(self.out, ", {})", data.row_major)?;
2950            }
2951            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2952                if context.lang_version < (2, 3) {
2953                    return Err(Error::UnsupportedCooperativeMatrix);
2954                }
2955                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2956                self.put_expression(a, context, true)?;
2957                write!(self.out, ", ")?;
2958                self.put_expression(b, context, true)?;
2959                write!(self.out, ", ")?;
2960                self.put_expression(c, context, true)?;
2961                write!(self.out, ")")?;
2962            }
2963        }
2964        Ok(())
2965    }
2966
2967    /// Emits code for a binary operation, using the provided callback to emit
2968    /// the left and right operands.
2969    fn put_binop<F>(
2970        &mut self,
2971        op: crate::BinaryOperator,
2972        left: Handle<crate::Expression>,
2973        right: Handle<crate::Expression>,
2974        context: &ExpressionContext,
2975        is_scoped: bool,
2976        put_expression: &F,
2977    ) -> BackendResult
2978    where
2979        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2980    {
2981        let op_str = back::binary_operation_str(op);
2982
2983        if !is_scoped {
2984            write!(self.out, "(")?;
2985        }
2986
2987        // Cast packed vector if necessary
2988        // Packed vector - matrix multiplications are not supported in MSL
2989        if op == crate::BinaryOperator::Multiply
2990            && matches!(
2991                context.resolve_type(right),
2992                &crate::TypeInner::Matrix { .. }
2993            )
2994        {
2995            self.put_wrapped_expression_for_packed_vec3_access(
2996                left,
2997                context,
2998                false,
2999                put_expression,
3000            )?;
3001        } else {
3002            put_expression(self, left, context, false)?;
3003        }
3004
3005        write!(self.out, " {op_str} ")?;
3006
3007        // See comment above
3008        if op == crate::BinaryOperator::Multiply
3009            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3010        {
3011            self.put_wrapped_expression_for_packed_vec3_access(
3012                right,
3013                context,
3014                false,
3015                put_expression,
3016            )?;
3017        } else {
3018            put_expression(self, right, context, false)?;
3019        }
3020
3021        if !is_scoped {
3022            write!(self.out, ")")?;
3023        }
3024
3025        Ok(())
3026    }
3027
3028    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
3029    fn put_wrapped_expression_for_packed_vec3_access<F>(
3030        &mut self,
3031        expr_handle: Handle<crate::Expression>,
3032        context: &ExpressionContext,
3033        is_scoped: bool,
3034        put_expression: &F,
3035    ) -> BackendResult
3036    where
3037        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3038    {
3039        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3040            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3041            put_expression(self, expr_handle, context, is_scoped)?;
3042            write!(self.out, ")")?;
3043        } else {
3044            put_expression(self, expr_handle, context, is_scoped)?;
3045        }
3046        Ok(())
3047    }
3048
3049    /// Emits code for an expression using the provided callback, wrapping the
3050    /// result in a bitcast to the type `cast_to`.
3051    fn put_bitcasted_expression<F>(
3052        &mut self,
3053        cast_to: &crate::TypeInner,
3054        inner_expr: Handle<crate::Expression>,
3055        context: &ExpressionContext,
3056        put_expression: &F,
3057    ) -> BackendResult
3058    where
3059        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3060    {
3061        write!(self.out, "as_type<")?;
3062        match *cast_to {
3063            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3064            crate::TypeInner::Vector { size, scalar } => {
3065                put_numeric_type(&mut self.out, scalar, &[size])?
3066            }
3067            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3068        };
3069        write!(self.out, ">(")?;
3070
3071        // if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
3072        if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3073            put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3074            write!(self.out, "(")?;
3075            put_expression(self, context, true)?;
3076            write!(self.out, ")")?;
3077        } else {
3078            put_expression(self, context, true)?;
3079        }
3080
3081        write!(self.out, ")")?;
3082        Ok(())
3083    }
3084
3085    /// Write a `GuardedIndex` as a Metal expression.
3086    fn put_index(
3087        &mut self,
3088        index: index::GuardedIndex,
3089        context: &ExpressionContext,
3090        is_scoped: bool,
3091    ) -> BackendResult {
3092        match index {
3093            index::GuardedIndex::Expression(expr) => {
3094                self.put_expression(expr, context, is_scoped)?
3095            }
3096            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3097        }
3098        Ok(())
3099    }
3100
3101    /// Emit an index bounds check condition for `chain`, if required.
3102    ///
3103    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3104    /// operating either on a pointer to a value, or on a value directly. If we cannot
3105    /// statically determine that all indexing operations in `chain` are within
3106    /// bounds, then write a conditional expression to check them dynamically,
3107    /// and return true. All accesses in the chain are checked by the generated
3108    /// expression.
3109    ///
3110    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3111    ///
3112    /// The text written is of the form:
3113    ///
3114    /// ```ignore
3115    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3116    /// ```
3117    ///
3118    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3119    /// statements, presumably these arguments start an indented `if` statement; for
3120    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3121    /// expression. In either case, what is written is not a complete syntactic structure
3122    /// in its own right, and the caller will have to finish it off if we return `true`.
3123    ///
3124    /// If no expression is written, return false.
3125    ///
3126    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3127    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3128    /// [`Store`]: crate::Statement::Store
3129    /// [`Load`]: crate::Expression::Load
3130    fn put_bounds_checks(
3131        &mut self,
3132        chain: Handle<crate::Expression>,
3133        context: &ExpressionContext,
3134        level: back::Level,
3135        prefix: &'static str,
3136    ) -> Result<bool, Error> {
3137        let mut check_written = false;
3138
3139        // Iterate over the access chain, handling each required bounds check.
3140        for item in context.bounds_check_iter(chain) {
3141            let BoundsCheck {
3142                base,
3143                index,
3144                length,
3145            } = item;
3146
3147            if check_written {
3148                write!(self.out, " && ")?;
3149            } else {
3150                write!(self.out, "{level}{prefix}")?;
3151                check_written = true;
3152            }
3153
3154            // Check that the index falls within bounds. Do this with a single
3155            // comparison, by casting the index to `uint` first, so that negative
3156            // indices become large positive values.
3157            write!(self.out, "uint(")?;
3158            self.put_index(index, context, true)?;
3159            self.out.write_str(") < ")?;
3160            match length {
3161                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3162                index::IndexableLength::Dynamic => {
3163                    let global = context.function.originating_global(base).ok_or_else(|| {
3164                        Error::GenericValidation("Could not find originating global".into())
3165                    })?;
3166                    write!(self.out, "1 + ")?;
3167                    self.put_dynamic_array_max_index(global, context)?
3168                }
3169            }
3170        }
3171
3172        Ok(check_written)
3173    }
3174
3175    /// Write the access chain `chain`.
3176    ///
3177    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3178    /// operating either on a pointer to a value, or on a value directly.
3179    ///
3180    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3181    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3182    /// that must be handled in the caller.
3183    ///
3184    /// Handle the entire chain, recursing back into `put_expression` only for index
3185    /// expressions and the base expression that originates the pointer or composite value
3186    /// being accessed. This allows `put_expression` to assume that any `Access` or
3187    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3188    /// `ReadZeroSkipWrite` checks.
3189    ///
3190    /// [`Access`]: crate::Expression::Access
3191    /// [`AccessIndex`]: crate::Expression::AccessIndex
3192    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3193    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3194    fn put_access_chain(
3195        &mut self,
3196        chain: Handle<crate::Expression>,
3197        policy: index::BoundsCheckPolicy,
3198        context: &ExpressionContext,
3199    ) -> BackendResult {
3200        match context.function.expressions[chain] {
3201            crate::Expression::Access { base, index } => {
3202                let mut base_ty = context.resolve_type(base);
3203
3204                // Look through any pointers to see what we're really indexing.
3205                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3206                    base_ty = &context.module.types[base].inner;
3207                }
3208
3209                self.put_subscripted_access_chain(
3210                    base,
3211                    base_ty,
3212                    index::GuardedIndex::Expression(index),
3213                    policy,
3214                    context,
3215                )?;
3216            }
3217            crate::Expression::AccessIndex { base, index } => {
3218                let base_resolution = &context.info[base].ty;
3219                let mut base_ty = base_resolution.inner_with(&context.module.types);
3220                let mut base_ty_handle = base_resolution.handle();
3221
3222                // Look through any pointers to see what we're really indexing.
3223                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3224                    base_ty = &context.module.types[base].inner;
3225                    base_ty_handle = Some(base);
3226                }
3227
3228                // Handle structs and anything else that can use `.x` syntax here, so
3229                // `put_subscripted_access_chain` won't have to handle the absurd case of
3230                // indexing a struct with an expression.
3231                match *base_ty {
3232                    crate::TypeInner::Struct { .. } => {
3233                        let base_ty = base_ty_handle.unwrap();
3234                        self.put_access_chain(base, policy, context)?;
3235                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3236                        write!(self.out, ".{name}")?;
3237                    }
3238                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3239                        self.put_access_chain(base, policy, context)?;
3240                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3241                        // however array indexing is
3242                        if context.get_packed_vec_kind(base).is_some() {
3243                            write!(self.out, "[{index}]")?;
3244                        } else {
3245                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3246                        }
3247                    }
3248                    _ => {
3249                        self.put_subscripted_access_chain(
3250                            base,
3251                            base_ty,
3252                            index::GuardedIndex::Known(index),
3253                            policy,
3254                            context,
3255                        )?;
3256                    }
3257                }
3258            }
3259            _ => self.put_expression(chain, context, false)?,
3260        }
3261
3262        Ok(())
3263    }
3264
3265    /// Write a `[]`-style access of `base` by `index`.
3266    ///
3267    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3268    /// values within bounds.
3269    ///
3270    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3271    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3272    /// removed. Our callers often already have this handy.
3273    ///
3274    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3275    /// referencing vector components by name.
3276    ///
3277    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3278    /// [`Array`]: crate::TypeInner::Array
3279    /// [`Vector`]: crate::TypeInner::Vector
3280    /// [`Pointer`]: crate::TypeInner::Pointer
3281    fn put_subscripted_access_chain(
3282        &mut self,
3283        base: Handle<crate::Expression>,
3284        base_ty: &crate::TypeInner,
3285        index: index::GuardedIndex,
3286        policy: index::BoundsCheckPolicy,
3287        context: &ExpressionContext,
3288    ) -> BackendResult {
3289        let accessing_wrapped_array = match *base_ty {
3290            crate::TypeInner::Array {
3291                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3292                ..
3293            } => true,
3294            _ => false,
3295        };
3296        let accessing_wrapped_binding_array =
3297            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3298
3299        self.put_access_chain(base, policy, context)?;
3300        if accessing_wrapped_array {
3301            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3302        }
3303        write!(self.out, "[")?;
3304
3305        // Decide whether this index needs to be clamped to fall within range.
3306        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3307            context.access_needs_check(base, index)
3308        } else {
3309            None
3310        };
3311        if let Some(limit) = restriction_needed {
3312            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3313            self.put_index(index, context, true)?;
3314            write!(self.out, "), ")?;
3315            match limit {
3316                index::IndexableLength::Known(limit) => {
3317                    write!(self.out, "{}u", limit - 1)?;
3318                }
3319                index::IndexableLength::Dynamic => {
3320                    let global = context.function.originating_global(base).ok_or_else(|| {
3321                        Error::GenericValidation("Could not find originating global".into())
3322                    })?;
3323                    self.put_dynamic_array_max_index(global, context)?;
3324                }
3325            }
3326            write!(self.out, ")")?;
3327        } else {
3328            self.put_index(index, context, true)?;
3329        }
3330
3331        write!(self.out, "]")?;
3332
3333        if accessing_wrapped_binding_array {
3334            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3335        }
3336
3337        Ok(())
3338    }
3339
3340    fn put_load(
3341        &mut self,
3342        pointer: Handle<crate::Expression>,
3343        context: &ExpressionContext,
3344        is_scoped: bool,
3345    ) -> BackendResult {
3346        // Since access chains never cross between address spaces, we can just
3347        // check the index bounds check policy once at the top.
3348        let policy = context.choose_bounds_check_policy(pointer);
3349        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3350            && self.put_bounds_checks(
3351                pointer,
3352                context,
3353                back::Level(0),
3354                if is_scoped { "" } else { "(" },
3355            )?
3356        {
3357            write!(self.out, " ? ")?;
3358            self.put_unchecked_load(pointer, policy, context)?;
3359            write!(self.out, " : DefaultConstructible()")?;
3360
3361            if !is_scoped {
3362                write!(self.out, ")")?;
3363            }
3364        } else {
3365            self.put_unchecked_load(pointer, policy, context)?;
3366        }
3367
3368        Ok(())
3369    }
3370
3371    fn put_unchecked_load(
3372        &mut self,
3373        pointer: Handle<crate::Expression>,
3374        policy: index::BoundsCheckPolicy,
3375        context: &ExpressionContext,
3376    ) -> BackendResult {
3377        let is_atomic_pointer = context
3378            .resolve_type(pointer)
3379            .is_atomic_pointer(&context.module.types);
3380
3381        if is_atomic_pointer {
3382            write!(
3383                self.out,
3384                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3385            )?;
3386            self.put_access_chain(pointer, policy, context)?;
3387            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3388        } else {
3389            // We don't do any dereferencing with `*` here as pointer arguments to functions
3390            // are done by `&` references and not `*` pointers. These do not need to be
3391            // dereferenced.
3392            self.put_access_chain(pointer, policy, context)?;
3393        }
3394
3395        Ok(())
3396    }
3397
3398    fn put_return_value(
3399        &mut self,
3400        level: back::Level,
3401        expr_handle: Handle<crate::Expression>,
3402        result_struct: Option<&str>,
3403        context: &ExpressionContext,
3404    ) -> BackendResult {
3405        match result_struct {
3406            Some(struct_name) => {
3407                let mut has_point_size = false;
3408                let result_ty = context.function.result.as_ref().unwrap().ty;
3409                match context.module.types[result_ty].inner {
3410                    crate::TypeInner::Struct { ref members, .. } => {
3411                        let tmp = "_tmp";
3412                        write!(self.out, "{level}const auto {tmp} = ")?;
3413                        self.put_expression(expr_handle, context, true)?;
3414                        writeln!(self.out, ";")?;
3415                        write!(self.out, "{level}return {struct_name} {{")?;
3416
3417                        let mut is_first = true;
3418
3419                        for (index, member) in members.iter().enumerate() {
3420                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3421                                member.binding
3422                            {
3423                                has_point_size = true;
3424                                if !context.pipeline_options.allow_and_force_point_size {
3425                                    continue;
3426                                }
3427                            }
3428
3429                            let comma = if is_first { "" } else { "," };
3430                            is_first = false;
3431                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3432                            // HACK: we are forcefully deduplicating the expression here
3433                            // to convert from a wrapped struct to a raw array, e.g.
3434                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3435                            if let crate::TypeInner::Array {
3436                                size: crate::ArraySize::Constant(size),
3437                                ..
3438                            } = context.module.types[member.ty].inner
3439                            {
3440                                write!(self.out, "{comma} {{")?;
3441                                for j in 0..size.get() {
3442                                    if j != 0 {
3443                                        write!(self.out, ",")?;
3444                                    }
3445                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3446                                }
3447                                write!(self.out, "}}")?;
3448                            } else {
3449                                write!(self.out, "{comma} {tmp}.{name}")?;
3450                            }
3451                        }
3452                    }
3453                    _ => {
3454                        write!(self.out, "{level}return {struct_name} {{ ")?;
3455                        self.put_expression(expr_handle, context, true)?;
3456                    }
3457                }
3458
3459                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3460                    let stage = context.module.entry_points[ep_index as usize].stage;
3461                    if context.pipeline_options.allow_and_force_point_size
3462                        && stage == crate::ShaderStage::Vertex
3463                        && !has_point_size
3464                    {
3465                        // point size was injected and comes last
3466                        write!(self.out, ", 1.0")?;
3467                    }
3468                }
3469                write!(self.out, " }}")?;
3470            }
3471            None => {
3472                write!(self.out, "{level}return ")?;
3473                self.put_expression(expr_handle, context, true)?;
3474            }
3475        }
3476        writeln!(self.out, ";")?;
3477        Ok(())
3478    }
3479
3480    /// Helper method used to find which expressions of a given function require baking
3481    ///
3482    /// # Notes
3483    /// This function overwrites the contents of `self.need_bake_expressions`
3484    fn update_expressions_to_bake(
3485        &mut self,
3486        func: &crate::Function,
3487        info: &valid::FunctionInfo,
3488        context: &ExpressionContext,
3489    ) {
3490        use crate::Expression;
3491        self.need_bake_expressions.clear();
3492
3493        for (expr_handle, expr) in func.expressions.iter() {
3494            // Expressions whose reference count is above the
3495            // threshold should always be stored in temporaries.
3496            let expr_info = &info[expr_handle];
3497            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3498            if min_ref_count <= expr_info.ref_count {
3499                self.need_bake_expressions.insert(expr_handle);
3500            } else {
3501                match expr_info.ty {
3502                    // force ray desc to be baked: it's used multiple times internally
3503                    TypeResolution::Handle(h)
3504                        if Some(h) == context.module.special_types.ray_desc =>
3505                    {
3506                        self.need_bake_expressions.insert(expr_handle);
3507                    }
3508                    _ => {}
3509                }
3510            }
3511
3512            if let Expression::Math {
3513                fun,
3514                arg,
3515                arg1,
3516                arg2,
3517                ..
3518            } = *expr
3519            {
3520                match fun {
3521                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3522                    // works on floating-point vectors, so we emit inline code for
3523                    // integer vector `dot` calls. But that code uses each argument `N`
3524                    // times, once for each component (see `put_dot_product`), so to
3525                    // avoid duplicated evaluation, we must bake integer operands.
3526                    // This applies both when using the polyfill (because of the duplicate
3527                    // evaluation issue) and when we don't use the polyfill (because we
3528                    // need them to be emitted before casting to packed chars -- see the
3529                    // comment at the call to `put_casting_to_packed_chars`).
3530                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3531                        self.need_bake_expressions.insert(arg);
3532                        self.need_bake_expressions.insert(arg1.unwrap());
3533                    }
3534                    crate::MathFunction::FirstLeadingBit => {
3535                        self.need_bake_expressions.insert(arg);
3536                    }
3537                    crate::MathFunction::Pack4xI8
3538                    | crate::MathFunction::Pack4xU8
3539                    | crate::MathFunction::Pack4xI8Clamp
3540                    | crate::MathFunction::Pack4xU8Clamp
3541                    | crate::MathFunction::Unpack4xI8
3542                    | crate::MathFunction::Unpack4xU8 => {
3543                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3544                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3545                        if context.lang_version < (2, 1) {
3546                            self.need_bake_expressions.insert(arg);
3547                        }
3548                    }
3549                    crate::MathFunction::ExtractBits => {
3550                        // Only argument 1 is re-used.
3551                        self.need_bake_expressions.insert(arg1.unwrap());
3552                    }
3553                    crate::MathFunction::InsertBits => {
3554                        // Only argument 2 is re-used.
3555                        self.need_bake_expressions.insert(arg2.unwrap());
3556                    }
3557                    crate::MathFunction::Sign => {
3558                        // WGSL's `sign` function works also on signed ints, but Metal's only
3559                        // works on floating points, so we emit inline code for integer `sign`
3560                        // calls. But that code uses each argument 2 times (see `put_isign`),
3561                        // so to avoid duplicated evaluation, we must bake the argument.
3562                        let inner = context.resolve_type(expr_handle);
3563                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3564                            self.need_bake_expressions.insert(arg);
3565                        }
3566                    }
3567                    _ => {}
3568                }
3569            }
3570        }
3571    }
3572
3573    fn start_baking_expression(
3574        &mut self,
3575        handle: Handle<crate::Expression>,
3576        context: &ExpressionContext,
3577        name: &str,
3578    ) -> BackendResult {
3579        match context.info[handle].ty {
3580            TypeResolution::Handle(ty_handle) => {
3581                let ty_name = TypeContext {
3582                    handle: ty_handle,
3583                    gctx: context.module.to_ctx(),
3584                    names: &self.names,
3585                    access: crate::StorageAccess::empty(),
3586                    first_time: false,
3587                };
3588                write!(self.out, "{ty_name}")?;
3589            }
3590            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3591                put_numeric_type(&mut self.out, scalar, &[])?;
3592            }
3593            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3594                put_numeric_type(&mut self.out, scalar, &[size])?;
3595            }
3596            TypeResolution::Value(crate::TypeInner::Matrix {
3597                columns,
3598                rows,
3599                scalar,
3600            }) => {
3601                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3602            }
3603            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3604                columns,
3605                rows,
3606                scalar,
3607                role: _,
3608            }) => {
3609                write!(
3610                    self.out,
3611                    "{}::simdgroup_{}{}x{}",
3612                    NAMESPACE,
3613                    scalar.to_msl_name(),
3614                    columns as u32,
3615                    rows as u32,
3616                )?;
3617            }
3618            TypeResolution::Value(ref other) => {
3619                log::warn!("Type {other:?} isn't a known local");
3620                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3621            }
3622        }
3623
3624        //TODO: figure out the naming scheme that wouldn't collide with user names.
3625        write!(self.out, " {name} = ")?;
3626
3627        Ok(())
3628    }
3629
3630    /// Cache a clamped level of detail value, if necessary.
3631    ///
3632    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3633    /// properly clamped level of detail value both in the access itself, and
3634    /// for fetching the size of the requested MIP level, needed to clamp the
3635    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3636    /// it in a temporary variable, as part of the [`Emit`] statement covering
3637    /// the [`ImageLoad`] expression.
3638    ///
3639    /// [`ImageLoad`]: crate::Expression::ImageLoad
3640    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3641    /// [`Emit`]: crate::Statement::Emit
3642    fn put_cache_restricted_level(
3643        &mut self,
3644        load: Handle<crate::Expression>,
3645        image: Handle<crate::Expression>,
3646        mip_level: Option<Handle<crate::Expression>>,
3647        indent: back::Level,
3648        context: &StatementContext,
3649    ) -> BackendResult {
3650        // Does this image access actually require (or even permit) a
3651        // level-of-detail, and does the policy require us to restrict it?
3652        let level_of_detail = match mip_level {
3653            Some(level) => level,
3654            None => return Ok(()),
3655        };
3656
3657        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3658            || !context.expression.image_needs_lod(image)
3659        {
3660            return Ok(());
3661        }
3662
3663        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3664        self.put_restricted_scalar_image_index(
3665            image,
3666            level_of_detail,
3667            "get_num_mip_levels",
3668            &context.expression,
3669        )?;
3670        writeln!(self.out, ";")?;
3671
3672        Ok(())
3673    }
3674
3675    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3676    ///
3677    /// Caches the results in temporary variables (whose names are derived from
3678    /// the original variable names). This caching avoids the need to redo the
3679    /// casting for each vector component when emitting the dot product.
3680    fn put_casting_to_packed_chars(
3681        &mut self,
3682        fun: crate::MathFunction,
3683        arg0: Handle<crate::Expression>,
3684        arg1: Handle<crate::Expression>,
3685        indent: back::Level,
3686        context: &StatementContext<'_>,
3687    ) -> Result<(), Error> {
3688        let packed_type = match fun {
3689            crate::MathFunction::Dot4I8Packed => "packed_char4",
3690            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3691            _ => unreachable!(),
3692        };
3693
3694        for arg in [arg0, arg1] {
3695            write!(
3696                self.out,
3697                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3698                Reinterpreted::new(packed_type, arg)
3699            )?;
3700            self.put_expression(arg, &context.expression, true)?;
3701            writeln!(self.out, ");")?;
3702        }
3703
3704        Ok(())
3705    }
3706
3707    fn put_block(
3708        &mut self,
3709        level: back::Level,
3710        statements: &[crate::Statement],
3711        context: &StatementContext,
3712    ) -> BackendResult {
3713        // Add to the set in order to track the stack size.
3714        #[cfg(test)]
3715        self.put_block_stack_pointers
3716            .insert(ptr::from_ref(&level).cast());
3717
3718        for statement in statements {
3719            log::trace!("statement[{}] {:?}", level.0, statement);
3720            match *statement {
3721                crate::Statement::Emit(ref range) => {
3722                    for handle in range.clone() {
3723                        use crate::MathFunction as Mf;
3724
3725                        match context.expression.function.expressions[handle] {
3726                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3727                            // may need to cache a clamped version of their level-of-detail argument.
3728                            crate::Expression::ImageLoad {
3729                                image,
3730                                level: mip_level,
3731                                ..
3732                            } => {
3733                                self.put_cache_restricted_level(
3734                                    handle, image, mip_level, level, context,
3735                                )?;
3736                            }
3737
3738                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3739                            // 2.1+ then we introduce two intermediate variables that recast the two
3740                            // arguments as packed (signed or unsigned) chars. The actual dot product
3741                            // is implemented in `Self::put_expression`, and it uses both of these
3742                            // intermediate variables multiple times. There's no danger that the
3743                            // original arguments get modified between the definition of these
3744                            // intermediate variables and the implementation of the actual dot
3745                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3746                            crate::Expression::Math {
3747                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3748                                arg,
3749                                arg1,
3750                                ..
3751                            } if context.expression.lang_version >= (2, 1) => {
3752                                self.put_casting_to_packed_chars(
3753                                    fun,
3754                                    arg,
3755                                    arg1.unwrap(),
3756                                    level,
3757                                    context,
3758                                )?;
3759                            }
3760
3761                            _ => (),
3762                        }
3763
3764                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3765                        let expr_name = if ptr_class.is_some() {
3766                            None // don't bake pointer expressions (just yet)
3767                        } else if let Some(name) =
3768                            context.expression.function.named_expressions.get(&handle)
3769                        {
3770                            // The `crate::Function::named_expressions` table holds
3771                            // expressions that should be saved in temporaries once they
3772                            // are `Emit`ted. We only add them to `self.named_expressions`
3773                            // when we reach the `Emit` that covers them, so that we don't
3774                            // try to use their names before we've actually initialized
3775                            // the temporary that holds them.
3776                            //
3777                            // Don't assume the names in `named_expressions` are unique,
3778                            // or even valid. Use the `Namer`.
3779                            Some(self.namer.call(name))
3780                        } else {
3781                            // If this expression is an index that we're going to first compare
3782                            // against a limit, and then actually use as an index, then we may
3783                            // want to cache it in a temporary, to avoid evaluating it twice.
3784                            let bake = if context.expression.guarded_indices.contains(handle) {
3785                                true
3786                            } else {
3787                                self.need_bake_expressions.contains(&handle)
3788                            };
3789
3790                            if bake {
3791                                Some(Baked(handle).to_string())
3792                            } else {
3793                                None
3794                            }
3795                        };
3796
3797                        if let Some(name) = expr_name {
3798                            write!(self.out, "{level}")?;
3799                            self.start_baking_expression(handle, &context.expression, &name)?;
3800                            self.put_expression(handle, &context.expression, true)?;
3801                            self.named_expressions.insert(handle, name);
3802                            writeln!(self.out, ";")?;
3803                        }
3804                    }
3805                }
3806                crate::Statement::Block(ref block) => {
3807                    if !block.is_empty() {
3808                        writeln!(self.out, "{level}{{")?;
3809                        self.put_block(level.next(), block, context)?;
3810                        writeln!(self.out, "{level}}}")?;
3811                    }
3812                }
3813                crate::Statement::If {
3814                    condition,
3815                    ref accept,
3816                    ref reject,
3817                } => {
3818                    write!(self.out, "{level}if (")?;
3819                    self.put_expression(condition, &context.expression, true)?;
3820                    writeln!(self.out, ") {{")?;
3821                    self.put_block(level.next(), accept, context)?;
3822                    if !reject.is_empty() {
3823                        writeln!(self.out, "{level}}} else {{")?;
3824                        self.put_block(level.next(), reject, context)?;
3825                    }
3826                    writeln!(self.out, "{level}}}")?;
3827                }
3828                crate::Statement::Switch {
3829                    selector,
3830                    ref cases,
3831                } => {
3832                    write!(self.out, "{level}switch(")?;
3833                    self.put_expression(selector, &context.expression, true)?;
3834                    writeln!(self.out, ") {{")?;
3835                    let lcase = level.next();
3836                    for case in cases.iter() {
3837                        match case.value {
3838                            crate::SwitchValue::I32(value) => {
3839                                write!(self.out, "{lcase}case {value}:")?;
3840                            }
3841                            crate::SwitchValue::U32(value) => {
3842                                write!(self.out, "{lcase}case {value}u:")?;
3843                            }
3844                            crate::SwitchValue::Default => {
3845                                write!(self.out, "{lcase}default:")?;
3846                            }
3847                        }
3848
3849                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3850                        if write_block_braces {
3851                            writeln!(self.out, " {{")?;
3852                        } else {
3853                            writeln!(self.out)?;
3854                        }
3855
3856                        self.put_block(lcase.next(), &case.body, context)?;
3857                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3858                        {
3859                            writeln!(self.out, "{}break;", lcase.next())?;
3860                        }
3861
3862                        if write_block_braces {
3863                            writeln!(self.out, "{lcase}}}")?;
3864                        }
3865                    }
3866                    writeln!(self.out, "{level}}}")?;
3867                }
3868                crate::Statement::Loop {
3869                    ref body,
3870                    ref continuing,
3871                    break_if,
3872                } => {
3873                    let force_loop_bound_statements =
3874                        self.gen_force_bounded_loop_statements(level, context);
3875                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3876                        .then(|| self.namer.call("loop_init"));
3877
3878                    if let Some((ref decl, _)) = force_loop_bound_statements {
3879                        writeln!(self.out, "{decl}")?;
3880                    }
3881                    if let Some(ref gate_name) = gate_name {
3882                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3883                    }
3884
3885                    writeln!(self.out, "{level}while(true) {{",)?;
3886                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3887                        writeln!(self.out, "{break_and_inc}")?;
3888                    }
3889                    if let Some(ref gate_name) = gate_name {
3890                        let lif = level.next();
3891                        let lcontinuing = lif.next();
3892                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3893                        self.put_block(lcontinuing, continuing, context)?;
3894                        if let Some(condition) = break_if {
3895                            write!(self.out, "{lcontinuing}if (")?;
3896                            self.put_expression(condition, &context.expression, true)?;
3897                            writeln!(self.out, ") {{")?;
3898                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3899                            writeln!(self.out, "{lcontinuing}}}")?;
3900                        }
3901                        writeln!(self.out, "{lif}}}")?;
3902                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3903                    }
3904                    self.put_block(level.next(), body, context)?;
3905
3906                    writeln!(self.out, "{level}}}")?;
3907                }
3908                crate::Statement::Break => {
3909                    writeln!(self.out, "{level}break;")?;
3910                }
3911                crate::Statement::Continue => {
3912                    writeln!(self.out, "{level}continue;")?;
3913                }
3914                crate::Statement::Return {
3915                    value: Some(expr_handle),
3916                } => {
3917                    self.put_return_value(
3918                        level,
3919                        expr_handle,
3920                        context.result_struct,
3921                        &context.expression,
3922                    )?;
3923                }
3924                crate::Statement::Return { value: None } => {
3925                    writeln!(self.out, "{level}return;")?;
3926                }
3927                crate::Statement::Kill => {
3928                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3929                }
3930                crate::Statement::ControlBarrier(flags)
3931                | crate::Statement::MemoryBarrier(flags) => {
3932                    self.write_barrier(flags, level)?;
3933                }
3934                crate::Statement::Store { pointer, value } => {
3935                    self.put_store(pointer, value, level, context)?
3936                }
3937                crate::Statement::ImageStore {
3938                    image,
3939                    coordinate,
3940                    array_index,
3941                    value,
3942                } => {
3943                    let address = TexelAddress {
3944                        coordinate,
3945                        array_index,
3946                        sample: None,
3947                        level: None,
3948                    };
3949                    self.put_image_store(level, image, &address, value, context)?
3950                }
3951                crate::Statement::Call {
3952                    function,
3953                    ref arguments,
3954                    result,
3955                } => {
3956                    write!(self.out, "{level}")?;
3957                    if let Some(expr) = result {
3958                        let name = Baked(expr).to_string();
3959                        self.start_baking_expression(expr, &context.expression, &name)?;
3960                        self.named_expressions.insert(expr, name);
3961                    }
3962                    let fun_name = &self.names[&NameKey::Function(function)];
3963                    write!(self.out, "{fun_name}(")?;
3964                    // first, write down the actual arguments
3965                    for (i, &handle) in arguments.iter().enumerate() {
3966                        if i != 0 {
3967                            write!(self.out, ", ")?;
3968                        }
3969                        self.put_expression(handle, &context.expression, true)?;
3970                    }
3971                    // follow-up with any global resources used
3972                    let mut separate = !arguments.is_empty();
3973                    let fun_info = &context.expression.mod_info[function];
3974                    let mut needs_buffer_sizes = false;
3975                    for (handle, var) in context.expression.module.global_variables.iter() {
3976                        if fun_info[handle].is_empty() {
3977                            continue;
3978                        }
3979                        if var.space.needs_pass_through() {
3980                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3981                            if separate {
3982                                write!(self.out, ", ")?;
3983                            } else {
3984                                separate = true;
3985                            }
3986                            write!(self.out, "{name}")?;
3987                        }
3988                        needs_buffer_sizes |=
3989                            needs_array_length(var.ty, &context.expression.module.types);
3990                    }
3991                    if needs_buffer_sizes {
3992                        if separate {
3993                            write!(self.out, ", ")?;
3994                        }
3995                        write!(self.out, "_buffer_sizes")?;
3996                    }
3997
3998                    // done
3999                    writeln!(self.out, ");")?;
4000                }
4001                crate::Statement::Atomic {
4002                    pointer,
4003                    ref fun,
4004                    value,
4005                    result,
4006                } => {
4007                    let context = &context.expression;
4008
4009                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
4010                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
4011                    // `Some`, we are not operating on a 64-bit value, and that if we are
4012                    // operating on a 64-bit value, `result` is `None`.
4013                    write!(self.out, "{level}")?;
4014                    let fun_key = if let Some(result) = result {
4015                        let res_name = Baked(result).to_string();
4016                        self.start_baking_expression(result, context, &res_name)?;
4017                        self.named_expressions.insert(result, res_name);
4018                        fun.to_msl()
4019                    } else if context.resolve_type(value).scalar_width() == Some(8) {
4020                        fun.to_msl_64_bit()?
4021                    } else {
4022                        fun.to_msl()
4023                    };
4024
4025                    // If the pointer we're passing to the atomic operation needs to be conditional
4026                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
4027                    // the pointer operand should be unchecked.
4028                    let policy = context.choose_bounds_check_policy(pointer);
4029                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4030                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4031
4032                    // If requested and successfully put bounds checks, continue the ternary expression.
4033                    if checked {
4034                        write!(self.out, " ? ")?;
4035                    }
4036
4037                    // Put the atomic function invocation.
4038                    match *fun {
4039                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4040                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4041                            self.put_access_chain(pointer, policy, context)?;
4042                            write!(self.out, ", ")?;
4043                            self.put_expression(cmp, context, true)?;
4044                            write!(self.out, ", ")?;
4045                            self.put_expression(value, context, true)?;
4046                            write!(self.out, ")")?;
4047                        }
4048                        _ => {
4049                            write!(
4050                                self.out,
4051                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4052                            )?;
4053                            self.put_access_chain(pointer, policy, context)?;
4054                            write!(self.out, ", ")?;
4055                            self.put_expression(value, context, true)?;
4056                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4057                        }
4058                    }
4059
4060                    // Finish the ternary expression.
4061                    if checked {
4062                        write!(self.out, " : DefaultConstructible()")?;
4063                    }
4064
4065                    // Done
4066                    writeln!(self.out, ";")?;
4067                }
4068                crate::Statement::ImageAtomic {
4069                    image,
4070                    coordinate,
4071                    array_index,
4072                    fun,
4073                    value,
4074                } => {
4075                    let address = TexelAddress {
4076                        coordinate,
4077                        array_index,
4078                        sample: None,
4079                        level: None,
4080                    };
4081                    self.put_image_atomic(level, image, &address, fun, value, context)?
4082                }
4083                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4084                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4085
4086                    write!(self.out, "{level}")?;
4087                    let name = self.namer.call("");
4088                    self.start_baking_expression(result, &context.expression, &name)?;
4089                    self.put_load(pointer, &context.expression, true)?;
4090                    self.named_expressions.insert(result, name);
4091
4092                    writeln!(self.out, ";")?;
4093                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4094                }
4095                crate::Statement::RayQuery { query, ref fun } => {
4096                    if context.expression.lang_version < (2, 4) {
4097                        return Err(Error::UnsupportedRayTracing);
4098                    }
4099
4100                    match *fun {
4101                        crate::RayQueryFunction::Initialize {
4102                            acceleration_structure,
4103                            descriptor,
4104                        } => {
4105                            //TODO: how to deal with winding?
4106                            write!(self.out, "{level}")?;
4107                            self.put_expression(query, &context.expression, true)?;
4108                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4109                            {
4110                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4111                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4112                                write!(self.out, "{level}")?;
4113                                self.put_expression(query, &context.expression, true)?;
4114                                write!(
4115                                    self.out,
4116                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4117                                )?;
4118                                self.put_expression(descriptor, &context.expression, true)?;
4119                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4120                                self.put_expression(descriptor, &context.expression, true)?;
4121                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4122                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4123                            }
4124                            {
4125                                let f_opaque = back::RayFlag::OPAQUE.bits();
4126                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4127                                write!(self.out, "{level}")?;
4128                                self.put_expression(query, &context.expression, true)?;
4129                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4130                                self.put_expression(descriptor, &context.expression, true)?;
4131                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4132                                self.put_expression(descriptor, &context.expression, true)?;
4133                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4134                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4135                            }
4136                            {
4137                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4138                                write!(self.out, "{level}")?;
4139                                self.put_expression(query, &context.expression, true)?;
4140                                write!(
4141                                    self.out,
4142                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4143                                )?;
4144                                self.put_expression(descriptor, &context.expression, true)?;
4145                                writeln!(self.out, ".flags & {flag}) != 0);")?;
4146                            }
4147
4148                            write!(self.out, "{level}")?;
4149                            self.put_expression(query, &context.expression, true)?;
4150                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4151                            self.put_expression(query, &context.expression, true)?;
4152                            write!(
4153                                self.out,
4154                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4155                            )?;
4156                            self.put_expression(descriptor, &context.expression, true)?;
4157                            write!(self.out, ".origin, ")?;
4158                            self.put_expression(descriptor, &context.expression, true)?;
4159                            write!(self.out, ".dir, ")?;
4160                            self.put_expression(descriptor, &context.expression, true)?;
4161                            write!(self.out, ".tmin, ")?;
4162                            self.put_expression(descriptor, &context.expression, true)?;
4163                            write!(self.out, ".tmax), ")?;
4164                            self.put_expression(acceleration_structure, &context.expression, true)?;
4165                            write!(self.out, ", ")?;
4166                            self.put_expression(descriptor, &context.expression, true)?;
4167                            write!(self.out, ".cull_mask);")?;
4168
4169                            write!(self.out, "{level}")?;
4170                            self.put_expression(query, &context.expression, true)?;
4171                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4172                        }
4173                        crate::RayQueryFunction::Proceed { result } => {
4174                            write!(self.out, "{level}")?;
4175                            let name = Baked(result).to_string();
4176                            self.start_baking_expression(result, &context.expression, &name)?;
4177                            self.named_expressions.insert(result, name);
4178                            self.put_expression(query, &context.expression, true)?;
4179                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4180                            if RAY_QUERY_MODERN_SUPPORT {
4181                                write!(self.out, "{level}")?;
4182                                self.put_expression(query, &context.expression, true)?;
4183                                writeln!(self.out, ".?.next();")?;
4184                            }
4185                        }
4186                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4187                            if RAY_QUERY_MODERN_SUPPORT {
4188                                write!(self.out, "{level}")?;
4189                                self.put_expression(query, &context.expression, true)?;
4190                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4191                                self.put_expression(hit_t, &context.expression, true)?;
4192                                writeln!(self.out, ");")?;
4193                            } else {
4194                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4195                            }
4196                        }
4197                        crate::RayQueryFunction::ConfirmIntersection => {
4198                            if RAY_QUERY_MODERN_SUPPORT {
4199                                write!(self.out, "{level}")?;
4200                                self.put_expression(query, &context.expression, true)?;
4201                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4202                            } else {
4203                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4204                            }
4205                        }
4206                        crate::RayQueryFunction::Terminate => {
4207                            if RAY_QUERY_MODERN_SUPPORT {
4208                                write!(self.out, "{level}")?;
4209                                self.put_expression(query, &context.expression, true)?;
4210                                writeln!(self.out, ".?.abort();")?;
4211                            }
4212                            write!(self.out, "{level}")?;
4213                            self.put_expression(query, &context.expression, true)?;
4214                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4215                        }
4216                    }
4217                }
4218                crate::Statement::SubgroupBallot { result, predicate } => {
4219                    write!(self.out, "{level}")?;
4220                    let name = self.namer.call("");
4221                    self.start_baking_expression(result, &context.expression, &name)?;
4222                    self.named_expressions.insert(result, name);
4223                    write!(
4224                        self.out,
4225                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4226                    )?;
4227                    if let Some(predicate) = predicate {
4228                        self.put_expression(predicate, &context.expression, true)?;
4229                    } else {
4230                        write!(self.out, "true")?;
4231                    }
4232                    writeln!(self.out, "), 0, 0, 0);")?;
4233                }
4234                crate::Statement::SubgroupCollectiveOperation {
4235                    op,
4236                    collective_op,
4237                    argument,
4238                    result,
4239                } => {
4240                    write!(self.out, "{level}")?;
4241                    let name = self.namer.call("");
4242                    self.start_baking_expression(result, &context.expression, &name)?;
4243                    self.named_expressions.insert(result, name);
4244                    match (collective_op, op) {
4245                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4246                            write!(self.out, "{NAMESPACE}::simd_all(")?
4247                        }
4248                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4249                            write!(self.out, "{NAMESPACE}::simd_any(")?
4250                        }
4251                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4252                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4253                        }
4254                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4255                            write!(self.out, "{NAMESPACE}::simd_product(")?
4256                        }
4257                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4258                            write!(self.out, "{NAMESPACE}::simd_max(")?
4259                        }
4260                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4261                            write!(self.out, "{NAMESPACE}::simd_min(")?
4262                        }
4263                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4264                            write!(self.out, "{NAMESPACE}::simd_and(")?
4265                        }
4266                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4267                            write!(self.out, "{NAMESPACE}::simd_or(")?
4268                        }
4269                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4270                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4271                        }
4272                        (
4273                            crate::CollectiveOperation::ExclusiveScan,
4274                            crate::SubgroupOperation::Add,
4275                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4276                        (
4277                            crate::CollectiveOperation::ExclusiveScan,
4278                            crate::SubgroupOperation::Mul,
4279                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4280                        (
4281                            crate::CollectiveOperation::InclusiveScan,
4282                            crate::SubgroupOperation::Add,
4283                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4284                        (
4285                            crate::CollectiveOperation::InclusiveScan,
4286                            crate::SubgroupOperation::Mul,
4287                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4288                        _ => unimplemented!(),
4289                    }
4290                    self.put_expression(argument, &context.expression, true)?;
4291                    writeln!(self.out, ");")?;
4292                }
4293                crate::Statement::SubgroupGather {
4294                    mode,
4295                    argument,
4296                    result,
4297                } => {
4298                    write!(self.out, "{level}")?;
4299                    let name = self.namer.call("");
4300                    self.start_baking_expression(result, &context.expression, &name)?;
4301                    self.named_expressions.insert(result, name);
4302                    match mode {
4303                        crate::GatherMode::BroadcastFirst => {
4304                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4305                        }
4306                        crate::GatherMode::Broadcast(_) => {
4307                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4308                        }
4309                        crate::GatherMode::Shuffle(_) => {
4310                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4311                        }
4312                        crate::GatherMode::ShuffleDown(_) => {
4313                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4314                        }
4315                        crate::GatherMode::ShuffleUp(_) => {
4316                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4317                        }
4318                        crate::GatherMode::ShuffleXor(_) => {
4319                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4320                        }
4321                        crate::GatherMode::QuadBroadcast(_) => {
4322                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4323                        }
4324                        crate::GatherMode::QuadSwap(_) => {
4325                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4326                        }
4327                    }
4328                    self.put_expression(argument, &context.expression, true)?;
4329                    match mode {
4330                        crate::GatherMode::BroadcastFirst => {}
4331                        crate::GatherMode::Broadcast(index)
4332                        | crate::GatherMode::Shuffle(index)
4333                        | crate::GatherMode::ShuffleDown(index)
4334                        | crate::GatherMode::ShuffleUp(index)
4335                        | crate::GatherMode::ShuffleXor(index)
4336                        | crate::GatherMode::QuadBroadcast(index) => {
4337                            write!(self.out, ", ")?;
4338                            self.put_expression(index, &context.expression, true)?;
4339                        }
4340                        crate::GatherMode::QuadSwap(direction) => {
4341                            write!(self.out, ", ")?;
4342                            match direction {
4343                                crate::Direction::X => {
4344                                    write!(self.out, "1u")?;
4345                                }
4346                                crate::Direction::Y => {
4347                                    write!(self.out, "2u")?;
4348                                }
4349                                crate::Direction::Diagonal => {
4350                                    write!(self.out, "3u")?;
4351                                }
4352                            }
4353                        }
4354                    }
4355                    writeln!(self.out, ");")?;
4356                }
4357                crate::Statement::CooperativeStore { target, ref data } => {
4358                    write!(self.out, "{level}simdgroup_store(")?;
4359                    self.put_expression(target, &context.expression, true)?;
4360                    write!(self.out, ", &")?;
4361                    self.put_access_chain(
4362                        data.pointer,
4363                        context.expression.policies.index,
4364                        &context.expression,
4365                    )?;
4366                    write!(self.out, ", ")?;
4367                    self.put_expression(data.stride, &context.expression, true)?;
4368                    if data.row_major {
4369                        let matrix_origin = "0";
4370                        let transpose = true;
4371                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4372                    }
4373                    writeln!(self.out, ");")?;
4374                }
4375                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4376            }
4377        }
4378
4379        // un-emit expressions
4380        //TODO: take care of loop/continuing?
4381        for statement in statements {
4382            if let crate::Statement::Emit(ref range) = *statement {
4383                for handle in range.clone() {
4384                    self.named_expressions.shift_remove(&handle);
4385                }
4386            }
4387        }
4388        Ok(())
4389    }
4390
4391    fn put_store(
4392        &mut self,
4393        pointer: Handle<crate::Expression>,
4394        value: Handle<crate::Expression>,
4395        level: back::Level,
4396        context: &StatementContext,
4397    ) -> BackendResult {
4398        let policy = context.expression.choose_bounds_check_policy(pointer);
4399        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4400            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4401        {
4402            writeln!(self.out, ") {{")?;
4403            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4404            writeln!(self.out, "{level}}}")?;
4405        } else {
4406            self.put_unchecked_store(pointer, value, policy, level, context)?;
4407        }
4408
4409        Ok(())
4410    }
4411
4412    fn put_unchecked_store(
4413        &mut self,
4414        pointer: Handle<crate::Expression>,
4415        value: Handle<crate::Expression>,
4416        policy: index::BoundsCheckPolicy,
4417        level: back::Level,
4418        context: &StatementContext,
4419    ) -> BackendResult {
4420        let is_atomic_pointer = context
4421            .expression
4422            .resolve_type(pointer)
4423            .is_atomic_pointer(&context.expression.module.types);
4424
4425        if is_atomic_pointer {
4426            write!(
4427                self.out,
4428                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4429            )?;
4430            self.put_access_chain(pointer, policy, &context.expression)?;
4431            write!(self.out, ", ")?;
4432            self.put_expression(value, &context.expression, true)?;
4433            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4434        } else {
4435            write!(self.out, "{level}")?;
4436            self.put_access_chain(pointer, policy, &context.expression)?;
4437            write!(self.out, " = ")?;
4438            self.put_expression(value, &context.expression, true)?;
4439            writeln!(self.out, ";")?;
4440        }
4441
4442        Ok(())
4443    }
4444
4445    pub fn write(
4446        &mut self,
4447        module: &crate::Module,
4448        info: &valid::ModuleInfo,
4449        options: &Options,
4450        pipeline_options: &PipelineOptions,
4451    ) -> Result<TranslationInfo, Error> {
4452        self.names.clear();
4453        self.namer.reset(
4454            module,
4455            &super::keywords::RESERVED_SET,
4456            proc::KeywordSet::empty(),
4457            proc::CaseInsensitiveKeywordSet::empty(),
4458            &[CLAMPED_LOD_LOAD_PREFIX],
4459            &mut self.names,
4460        );
4461        self.wrapped_functions.clear();
4462        self.struct_member_pads.clear();
4463
4464        writeln!(
4465            self.out,
4466            "// language: metal{}.{}",
4467            options.lang_version.0, options.lang_version.1
4468        )?;
4469        writeln!(self.out, "#include <metal_stdlib>")?;
4470        writeln!(self.out, "#include <simd/simd.h>")?;
4471        writeln!(self.out)?;
4472        // Work around Metal bug where `uint` is not available by default
4473        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4474
4475        if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4476            return Err(Error::UnsupportedMeshShader);
4477        }
4478        self.needs_object_memory_barriers = module
4479            .entry_points
4480            .iter()
4481            .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4482
4483        let mut uses_ray_query = false;
4484        for (_, ty) in module.types.iter() {
4485            match ty.inner {
4486                crate::TypeInner::AccelerationStructure { .. } => {
4487                    if options.lang_version < (2, 4) {
4488                        return Err(Error::UnsupportedRayTracing);
4489                    }
4490                }
4491                crate::TypeInner::RayQuery { .. } => {
4492                    if options.lang_version < (2, 4) {
4493                        return Err(Error::UnsupportedRayTracing);
4494                    }
4495                    uses_ray_query = true;
4496                }
4497                _ => (),
4498            }
4499        }
4500
4501        if module.special_types.ray_desc.is_some()
4502            || module.special_types.ray_intersection.is_some()
4503        {
4504            if options.lang_version < (2, 4) {
4505                return Err(Error::UnsupportedRayTracing);
4506            }
4507        }
4508
4509        if uses_ray_query {
4510            self.put_ray_query_type()?;
4511        }
4512
4513        if options
4514            .bounds_check_policies
4515            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4516        {
4517            self.put_default_constructible()?;
4518        }
4519        writeln!(self.out)?;
4520
4521        {
4522            // Make a `Vec` of all the `GlobalVariable`s that contain
4523            // runtime-sized arrays.
4524            let globals: Vec<Handle<crate::GlobalVariable>> = module
4525                .global_variables
4526                .iter()
4527                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4528                .map(|(handle, _)| handle)
4529                .collect();
4530
4531            let mut buffer_indices = vec![];
4532            for vbm in &pipeline_options.vertex_buffer_mappings {
4533                buffer_indices.push(vbm.id);
4534            }
4535
4536            if !globals.is_empty() || !buffer_indices.is_empty() {
4537                writeln!(self.out, "struct _mslBufferSizes {{")?;
4538
4539                for global in globals {
4540                    writeln!(
4541                        self.out,
4542                        "{}uint {};",
4543                        back::INDENT,
4544                        ArraySizeMember(global)
4545                    )?;
4546                }
4547
4548                for idx in buffer_indices {
4549                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4550                }
4551
4552                writeln!(self.out, "}};")?;
4553                writeln!(self.out)?;
4554            }
4555        };
4556
4557        self.write_type_defs(module)?;
4558        self.write_global_constants(module, info)?;
4559        self.write_functions(module, info, options, pipeline_options)
4560    }
4561
4562    /// Write the definition for the `DefaultConstructible` class.
4563    ///
4564    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4565    /// produce 'zero' values for any type, including structs, arrays, and so
4566    /// on. We could do this by emitting default constructor applications, but
4567    /// that would entail printing the name of the type, which is more trouble
4568    /// than you'd think. Instead, we just construct this magic C++14 class that
4569    /// can be converted to any type that can be default constructed, using
4570    /// template parameter inference to detect which type is needed, so we don't
4571    /// have to figure out the name.
4572    ///
4573    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4574    fn put_default_constructible(&mut self) -> BackendResult {
4575        let tab = back::INDENT;
4576        writeln!(self.out, "struct DefaultConstructible {{")?;
4577        writeln!(self.out, "{tab}template<typename T>")?;
4578        writeln!(self.out, "{tab}operator T() && {{")?;
4579        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4580        writeln!(self.out, "{tab}}}")?;
4581        writeln!(self.out, "}};")?;
4582        Ok(())
4583    }
4584
4585    fn put_ray_query_type(&mut self) -> BackendResult {
4586        let tab = back::INDENT;
4587        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4588        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4589        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4590        writeln!(
4591            self.out,
4592            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4593        )?;
4594        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4595        writeln!(self.out, "}};")?;
4596        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4597        let v_triangle = back::RayIntersectionType::Triangle as u32;
4598        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4599        writeln!(
4600            self.out,
4601            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4602        )?;
4603        writeln!(
4604            self.out,
4605            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4606        )?;
4607        writeln!(self.out, "}}")?;
4608        Ok(())
4609    }
4610
4611    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4612        let mut generated_argument_buffer_wrapper = false;
4613        let mut generated_external_texture_wrapper = false;
4614        for (handle, ty) in module.types.iter() {
4615            match ty.inner {
4616                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4617                    writeln!(self.out, "template <typename T>")?;
4618                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4619                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4620                    writeln!(self.out, "}};")?;
4621                    generated_argument_buffer_wrapper = true;
4622                }
4623                crate::TypeInner::Image {
4624                    class: crate::ImageClass::External,
4625                    ..
4626                } if !generated_external_texture_wrapper => {
4627                    let params_ty_name = &self.names
4628                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4629                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4630                    writeln!(
4631                        self.out,
4632                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4633                        back::INDENT
4634                    )?;
4635                    writeln!(
4636                        self.out,
4637                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4638                        back::INDENT
4639                    )?;
4640                    writeln!(
4641                        self.out,
4642                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4643                        back::INDENT
4644                    )?;
4645                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4646                    writeln!(self.out, "}};")?;
4647                    generated_external_texture_wrapper = true;
4648                }
4649                _ => {}
4650            }
4651
4652            if !ty.needs_alias() {
4653                continue;
4654            }
4655            let name = &self.names[&NameKey::Type(handle)];
4656            match ty.inner {
4657                // Naga IR can pass around arrays by value, but Metal, following
4658                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4659                // on expressions of array type, so assigning the array by value
4660                // isn't possible. However, Metal *does* assign structs by
4661                // value. So in our Metal output, we wrap all array types in
4662                // synthetic struct types:
4663                //
4664                //     struct type1 {
4665                //         float inner[10]
4666                //     };
4667                //
4668                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4669                // any expression that actually wants access to the array.
4670                crate::TypeInner::Array {
4671                    base,
4672                    size,
4673                    stride: _,
4674                } => {
4675                    let base_name = TypeContext {
4676                        handle: base,
4677                        gctx: module.to_ctx(),
4678                        names: &self.names,
4679                        access: crate::StorageAccess::empty(),
4680                        first_time: false,
4681                    };
4682
4683                    match size.resolve(module.to_ctx())? {
4684                        proc::IndexableLength::Known(size) => {
4685                            writeln!(self.out, "struct {name} {{")?;
4686                            writeln!(
4687                                self.out,
4688                                "{}{} {}[{}];",
4689                                back::INDENT,
4690                                base_name,
4691                                WRAPPED_ARRAY_FIELD,
4692                                size
4693                            )?;
4694                            writeln!(self.out, "}};")?;
4695                        }
4696                        proc::IndexableLength::Dynamic => {
4697                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4698                        }
4699                    }
4700                }
4701                crate::TypeInner::Struct {
4702                    ref members, span, ..
4703                } => {
4704                    writeln!(self.out, "struct {name} {{")?;
4705                    let mut last_offset = 0;
4706                    for (index, member) in members.iter().enumerate() {
4707                        if member.offset > last_offset {
4708                            self.struct_member_pads.insert((handle, index as u32));
4709                            let pad = member.offset - last_offset;
4710                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4711                        }
4712                        let ty_inner = &module.types[member.ty].inner;
4713                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4714
4715                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4716
4717                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4718                        match should_pack_struct_member(members, span, index, module) {
4719                            Some(scalar) => {
4720                                writeln!(
4721                                    self.out,
4722                                    "{}{}::packed_{}3 {};",
4723                                    back::INDENT,
4724                                    NAMESPACE,
4725                                    scalar.to_msl_name(),
4726                                    member_name
4727                                )?;
4728                            }
4729                            None => {
4730                                let base_name = TypeContext {
4731                                    handle: member.ty,
4732                                    gctx: module.to_ctx(),
4733                                    names: &self.names,
4734                                    access: crate::StorageAccess::empty(),
4735                                    first_time: false,
4736                                };
4737                                writeln!(
4738                                    self.out,
4739                                    "{}{} {};",
4740                                    back::INDENT,
4741                                    base_name,
4742                                    member_name
4743                                )?;
4744
4745                                // for 3-component vectors, add one component
4746                                if let crate::TypeInner::Vector {
4747                                    size: crate::VectorSize::Tri,
4748                                    scalar,
4749                                } = *ty_inner
4750                                {
4751                                    last_offset += scalar.width as u32;
4752                                }
4753                            }
4754                        }
4755                    }
4756                    if last_offset < span {
4757                        let pad = span - last_offset;
4758                        writeln!(
4759                            self.out,
4760                            "{}char _pad{}[{}];",
4761                            back::INDENT,
4762                            members.len(),
4763                            pad
4764                        )?;
4765                    }
4766                    writeln!(self.out, "}};")?;
4767                }
4768                _ => {
4769                    let ty_name = TypeContext {
4770                        handle,
4771                        gctx: module.to_ctx(),
4772                        names: &self.names,
4773                        access: crate::StorageAccess::empty(),
4774                        first_time: true,
4775                    };
4776                    writeln!(self.out, "typedef {ty_name} {name};")?;
4777                }
4778            }
4779        }
4780
4781        // Write functions to create special types.
4782        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4783            match type_key {
4784                &crate::PredeclaredType::ModfResult { size, scalar }
4785                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4786                    let arg_type_name_owner;
4787                    let arg_type_name = if let Some(size) = size {
4788                        arg_type_name_owner = format!(
4789                            "{NAMESPACE}::{}{}",
4790                            if scalar.width == 8 { "double" } else { "float" },
4791                            size as u8
4792                        );
4793                        &arg_type_name_owner
4794                    } else if scalar.width == 8 {
4795                        "double"
4796                    } else {
4797                        "float"
4798                    };
4799
4800                    let other_type_name_owner;
4801                    let (defined_func_name, called_func_name, other_type_name) =
4802                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4803                            (MODF_FUNCTION, "modf", arg_type_name)
4804                        } else {
4805                            let other_type_name = if let Some(size) = size {
4806                                other_type_name_owner = format!("int{}", size as u8);
4807                                &other_type_name_owner
4808                            } else {
4809                                "int"
4810                            };
4811                            (FREXP_FUNCTION, "frexp", other_type_name)
4812                        };
4813
4814                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4815
4816                    writeln!(self.out)?;
4817                    writeln!(
4818                        self.out,
4819                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4820    {other_type_name} other;
4821    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4822    return {struct_name}{{ fract, other }};
4823}}"
4824                    )?;
4825                }
4826                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4827                    let arg_type_name = scalar.to_msl_name();
4828                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4829                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4830                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4831
4832                    writeln!(self.out)?;
4833
4834                    for address_space_name in ["device", "threadgroup"] {
4835                        writeln!(
4836                            self.out,
4837                            "\
4838template <typename A>
4839{struct_name} {defined_func_name}(
4840    {address_space_name} A *atomic_ptr,
4841    {arg_type_name} cmp,
4842    {arg_type_name} v
4843) {{
4844    bool swapped = {NAMESPACE}::{called_func_name}(
4845        atomic_ptr, &cmp, v,
4846        metal::memory_order_relaxed, metal::memory_order_relaxed
4847    );
4848    return {struct_name}{{cmp, swapped}};
4849}}"
4850                        )?;
4851                    }
4852                }
4853            }
4854        }
4855
4856        Ok(())
4857    }
4858
4859    /// Writes all named constants
4860    fn write_global_constants(
4861        &mut self,
4862        module: &crate::Module,
4863        mod_info: &valid::ModuleInfo,
4864    ) -> BackendResult {
4865        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4866
4867        for (handle, constant) in constants {
4868            let ty_name = TypeContext {
4869                handle: constant.ty,
4870                gctx: module.to_ctx(),
4871                names: &self.names,
4872                access: crate::StorageAccess::empty(),
4873                first_time: false,
4874            };
4875            let name = &self.names[&NameKey::Constant(handle)];
4876            write!(self.out, "constant {ty_name} {name} = ")?;
4877            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4878            writeln!(self.out, ";")?;
4879        }
4880
4881        Ok(())
4882    }
4883
4884    fn put_inline_sampler_properties(
4885        &mut self,
4886        level: back::Level,
4887        sampler: &sm::InlineSampler,
4888    ) -> BackendResult {
4889        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4890            writeln!(
4891                self.out,
4892                "{}{}::{}_address::{},",
4893                level,
4894                NAMESPACE,
4895                letter,
4896                address.as_str(),
4897            )?;
4898        }
4899        writeln!(
4900            self.out,
4901            "{}{}::mag_filter::{},",
4902            level,
4903            NAMESPACE,
4904            sampler.mag_filter.as_str(),
4905        )?;
4906        writeln!(
4907            self.out,
4908            "{}{}::min_filter::{},",
4909            level,
4910            NAMESPACE,
4911            sampler.min_filter.as_str(),
4912        )?;
4913        if let Some(filter) = sampler.mip_filter {
4914            writeln!(
4915                self.out,
4916                "{}{}::mip_filter::{},",
4917                level,
4918                NAMESPACE,
4919                filter.as_str(),
4920            )?;
4921        }
4922        // avoid setting it on platforms that don't support it
4923        if sampler.border_color != sm::BorderColor::TransparentBlack {
4924            writeln!(
4925                self.out,
4926                "{}{}::border_color::{},",
4927                level,
4928                NAMESPACE,
4929                sampler.border_color.as_str(),
4930            )?;
4931        }
4932        //TODO: I'm not able to feed this in a way that MSL likes:
4933        //>error: use of undeclared identifier 'lod_clamp'
4934        //>error: no member named 'max_anisotropy' in namespace 'metal'
4935        if false {
4936            if let Some(ref lod) = sampler.lod_clamp {
4937                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4938            }
4939            if let Some(aniso) = sampler.max_anisotropy {
4940                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4941            }
4942        }
4943        if sampler.compare_func != sm::CompareFunc::Never {
4944            writeln!(
4945                self.out,
4946                "{}{}::compare_func::{},",
4947                level,
4948                NAMESPACE,
4949                sampler.compare_func.as_str(),
4950            )?;
4951        }
4952        writeln!(
4953            self.out,
4954            "{}{}::coord::{}",
4955            level,
4956            NAMESPACE,
4957            sampler.coord.as_str()
4958        )?;
4959        Ok(())
4960    }
4961
4962    fn write_unpacking_function(
4963        &mut self,
4964        format: back::msl::VertexFormat,
4965    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4966        use crate::{Scalar, VectorSize};
4967        use back::msl::VertexFormat::*;
4968        match format {
4969            Uint8 => {
4970                let name = self.namer.call("unpackUint8");
4971                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4972                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4973                writeln!(self.out, "}}")?;
4974                Ok((name, 1, None, Scalar::U32))
4975            }
4976            Uint8x2 => {
4977                let name = self.namer.call("unpackUint8x2");
4978                writeln!(
4979                    self.out,
4980                    "metal::uint2 {name}(metal::uchar b0, \
4981                                         metal::uchar b1) {{"
4982                )?;
4983                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4984                writeln!(self.out, "}}")?;
4985                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4986            }
4987            Uint8x4 => {
4988                let name = self.namer.call("unpackUint8x4");
4989                writeln!(
4990                    self.out,
4991                    "metal::uint4 {name}(metal::uchar b0, \
4992                                         metal::uchar b1, \
4993                                         metal::uchar b2, \
4994                                         metal::uchar b3) {{"
4995                )?;
4996                writeln!(
4997                    self.out,
4998                    "{}return metal::uint4(b0, b1, b2, b3);",
4999                    back::INDENT
5000                )?;
5001                writeln!(self.out, "}}")?;
5002                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
5003            }
5004            Sint8 => {
5005                let name = self.namer.call("unpackSint8");
5006                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
5007                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
5008                writeln!(self.out, "}}")?;
5009                Ok((name, 1, None, Scalar::I32))
5010            }
5011            Sint8x2 => {
5012                let name = self.namer.call("unpackSint8x2");
5013                writeln!(
5014                    self.out,
5015                    "metal::int2 {name}(metal::uchar b0, \
5016                                        metal::uchar b1) {{"
5017                )?;
5018                writeln!(
5019                    self.out,
5020                    "{}return metal::int2(as_type<char>(b0), \
5021                                          as_type<char>(b1));",
5022                    back::INDENT
5023                )?;
5024                writeln!(self.out, "}}")?;
5025                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
5026            }
5027            Sint8x4 => {
5028                let name = self.namer.call("unpackSint8x4");
5029                writeln!(
5030                    self.out,
5031                    "metal::int4 {name}(metal::uchar b0, \
5032                                        metal::uchar b1, \
5033                                        metal::uchar b2, \
5034                                        metal::uchar b3) {{"
5035                )?;
5036                writeln!(
5037                    self.out,
5038                    "{}return metal::int4(as_type<char>(b0), \
5039                                          as_type<char>(b1), \
5040                                          as_type<char>(b2), \
5041                                          as_type<char>(b3));",
5042                    back::INDENT
5043                )?;
5044                writeln!(self.out, "}}")?;
5045                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
5046            }
5047            Unorm8 => {
5048                let name = self.namer.call("unpackUnorm8");
5049                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5050                writeln!(
5051                    self.out,
5052                    "{}return float(float(b0) / 255.0f);",
5053                    back::INDENT
5054                )?;
5055                writeln!(self.out, "}}")?;
5056                Ok((name, 1, None, Scalar::F32))
5057            }
5058            Unorm8x2 => {
5059                let name = self.namer.call("unpackUnorm8x2");
5060                writeln!(
5061                    self.out,
5062                    "metal::float2 {name}(metal::uchar b0, \
5063                                          metal::uchar b1) {{"
5064                )?;
5065                writeln!(
5066                    self.out,
5067                    "{}return metal::float2(float(b0) / 255.0f, \
5068                                            float(b1) / 255.0f);",
5069                    back::INDENT
5070                )?;
5071                writeln!(self.out, "}}")?;
5072                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5073            }
5074            Unorm8x4 => {
5075                let name = self.namer.call("unpackUnorm8x4");
5076                writeln!(
5077                    self.out,
5078                    "metal::float4 {name}(metal::uchar b0, \
5079                                          metal::uchar b1, \
5080                                          metal::uchar b2, \
5081                                          metal::uchar b3) {{"
5082                )?;
5083                writeln!(
5084                    self.out,
5085                    "{}return metal::float4(float(b0) / 255.0f, \
5086                                            float(b1) / 255.0f, \
5087                                            float(b2) / 255.0f, \
5088                                            float(b3) / 255.0f);",
5089                    back::INDENT
5090                )?;
5091                writeln!(self.out, "}}")?;
5092                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5093            }
5094            Snorm8 => {
5095                let name = self.namer.call("unpackSnorm8");
5096                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5097                writeln!(
5098                    self.out,
5099                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5100                    back::INDENT
5101                )?;
5102                writeln!(self.out, "}}")?;
5103                Ok((name, 1, None, Scalar::F32))
5104            }
5105            Snorm8x2 => {
5106                let name = self.namer.call("unpackSnorm8x2");
5107                writeln!(
5108                    self.out,
5109                    "metal::float2 {name}(metal::uchar b0, \
5110                                          metal::uchar b1) {{"
5111                )?;
5112                writeln!(
5113                    self.out,
5114                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5115                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5116                    back::INDENT
5117                )?;
5118                writeln!(self.out, "}}")?;
5119                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5120            }
5121            Snorm8x4 => {
5122                let name = self.namer.call("unpackSnorm8x4");
5123                writeln!(
5124                    self.out,
5125                    "metal::float4 {name}(metal::uchar b0, \
5126                                          metal::uchar b1, \
5127                                          metal::uchar b2, \
5128                                          metal::uchar b3) {{"
5129                )?;
5130                writeln!(
5131                    self.out,
5132                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5133                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5134                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5135                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5136                    back::INDENT
5137                )?;
5138                writeln!(self.out, "}}")?;
5139                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5140            }
5141            Uint16 => {
5142                let name = self.namer.call("unpackUint16");
5143                writeln!(
5144                    self.out,
5145                    "metal::uint {name}(metal::uint b0, \
5146                                        metal::uint b1) {{"
5147                )?;
5148                writeln!(
5149                    self.out,
5150                    "{}return metal::uint(b1 << 8 | b0);",
5151                    back::INDENT
5152                )?;
5153                writeln!(self.out, "}}")?;
5154                Ok((name, 2, None, Scalar::U32))
5155            }
5156            Uint16x2 => {
5157                let name = self.namer.call("unpackUint16x2");
5158                writeln!(
5159                    self.out,
5160                    "metal::uint2 {name}(metal::uint b0, \
5161                                         metal::uint b1, \
5162                                         metal::uint b2, \
5163                                         metal::uint b3) {{"
5164                )?;
5165                writeln!(
5166                    self.out,
5167                    "{}return metal::uint2(b1 << 8 | b0, \
5168                                           b3 << 8 | b2);",
5169                    back::INDENT
5170                )?;
5171                writeln!(self.out, "}}")?;
5172                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5173            }
5174            Uint16x4 => {
5175                let name = self.namer.call("unpackUint16x4");
5176                writeln!(
5177                    self.out,
5178                    "metal::uint4 {name}(metal::uint b0, \
5179                                         metal::uint b1, \
5180                                         metal::uint b2, \
5181                                         metal::uint b3, \
5182                                         metal::uint b4, \
5183                                         metal::uint b5, \
5184                                         metal::uint b6, \
5185                                         metal::uint b7) {{"
5186                )?;
5187                writeln!(
5188                    self.out,
5189                    "{}return metal::uint4(b1 << 8 | b0, \
5190                                           b3 << 8 | b2, \
5191                                           b5 << 8 | b4, \
5192                                           b7 << 8 | b6);",
5193                    back::INDENT
5194                )?;
5195                writeln!(self.out, "}}")?;
5196                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5197            }
5198            Sint16 => {
5199                let name = self.namer.call("unpackSint16");
5200                writeln!(
5201                    self.out,
5202                    "int {name}(metal::ushort b0, \
5203                                metal::ushort b1) {{"
5204                )?;
5205                writeln!(
5206                    self.out,
5207                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5208                    back::INDENT
5209                )?;
5210                writeln!(self.out, "}}")?;
5211                Ok((name, 2, None, Scalar::I32))
5212            }
5213            Sint16x2 => {
5214                let name = self.namer.call("unpackSint16x2");
5215                writeln!(
5216                    self.out,
5217                    "metal::int2 {name}(metal::ushort b0, \
5218                                        metal::ushort b1, \
5219                                        metal::ushort b2, \
5220                                        metal::ushort b3) {{"
5221                )?;
5222                writeln!(
5223                    self.out,
5224                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5225                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5226                    back::INDENT
5227                )?;
5228                writeln!(self.out, "}}")?;
5229                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5230            }
5231            Sint16x4 => {
5232                let name = self.namer.call("unpackSint16x4");
5233                writeln!(
5234                    self.out,
5235                    "metal::int4 {name}(metal::ushort b0, \
5236                                        metal::ushort b1, \
5237                                        metal::ushort b2, \
5238                                        metal::ushort b3, \
5239                                        metal::ushort b4, \
5240                                        metal::ushort b5, \
5241                                        metal::ushort b6, \
5242                                        metal::ushort b7) {{"
5243                )?;
5244                writeln!(
5245                    self.out,
5246                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5247                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5248                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5249                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5250                    back::INDENT
5251                )?;
5252                writeln!(self.out, "}}")?;
5253                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5254            }
5255            Unorm16 => {
5256                let name = self.namer.call("unpackUnorm16");
5257                writeln!(
5258                    self.out,
5259                    "float {name}(metal::ushort b0, \
5260                                  metal::ushort b1) {{"
5261                )?;
5262                writeln!(
5263                    self.out,
5264                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5265                    back::INDENT
5266                )?;
5267                writeln!(self.out, "}}")?;
5268                Ok((name, 2, None, Scalar::F32))
5269            }
5270            Unorm16x2 => {
5271                let name = self.namer.call("unpackUnorm16x2");
5272                writeln!(
5273                    self.out,
5274                    "metal::float2 {name}(metal::ushort b0, \
5275                                          metal::ushort b1, \
5276                                          metal::ushort b2, \
5277                                          metal::ushort b3) {{"
5278                )?;
5279                writeln!(
5280                    self.out,
5281                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5282                                            float(b3 << 8 | b2) / 65535.0f);",
5283                    back::INDENT
5284                )?;
5285                writeln!(self.out, "}}")?;
5286                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5287            }
5288            Unorm16x4 => {
5289                let name = self.namer.call("unpackUnorm16x4");
5290                writeln!(
5291                    self.out,
5292                    "metal::float4 {name}(metal::ushort b0, \
5293                                          metal::ushort b1, \
5294                                          metal::ushort b2, \
5295                                          metal::ushort b3, \
5296                                          metal::ushort b4, \
5297                                          metal::ushort b5, \
5298                                          metal::ushort b6, \
5299                                          metal::ushort b7) {{"
5300                )?;
5301                writeln!(
5302                    self.out,
5303                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5304                                            float(b3 << 8 | b2) / 65535.0f, \
5305                                            float(b5 << 8 | b4) / 65535.0f, \
5306                                            float(b7 << 8 | b6) / 65535.0f);",
5307                    back::INDENT
5308                )?;
5309                writeln!(self.out, "}}")?;
5310                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5311            }
5312            Snorm16 => {
5313                let name = self.namer.call("unpackSnorm16");
5314                writeln!(
5315                    self.out,
5316                    "float {name}(metal::ushort b0, \
5317                                  metal::ushort b1) {{"
5318                )?;
5319                writeln!(
5320                    self.out,
5321                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5322                    back::INDENT
5323                )?;
5324                writeln!(self.out, "}}")?;
5325                Ok((name, 2, None, Scalar::F32))
5326            }
5327            Snorm16x2 => {
5328                let name = self.namer.call("unpackSnorm16x2");
5329                writeln!(
5330                    self.out,
5331                    "metal::float2 {name}(uint b0, \
5332                                          uint b1, \
5333                                          uint b2, \
5334                                          uint b3) {{"
5335                )?;
5336                writeln!(
5337                    self.out,
5338                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5339                    back::INDENT
5340                )?;
5341                writeln!(self.out, "}}")?;
5342                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5343            }
5344            Snorm16x4 => {
5345                let name = self.namer.call("unpackSnorm16x4");
5346                writeln!(
5347                    self.out,
5348                    "metal::float4 {name}(uint b0, \
5349                                          uint b1, \
5350                                          uint b2, \
5351                                          uint b3, \
5352                                          uint b4, \
5353                                          uint b5, \
5354                                          uint b6, \
5355                                          uint b7) {{"
5356                )?;
5357                writeln!(
5358                    self.out,
5359                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5360                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5361                    back::INDENT
5362                )?;
5363                writeln!(self.out, "}}")?;
5364                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5365            }
5366            Float16 => {
5367                let name = self.namer.call("unpackFloat16");
5368                writeln!(
5369                    self.out,
5370                    "float {name}(metal::ushort b0, \
5371                                  metal::ushort b1) {{"
5372                )?;
5373                writeln!(
5374                    self.out,
5375                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5376                    back::INDENT
5377                )?;
5378                writeln!(self.out, "}}")?;
5379                Ok((name, 2, None, Scalar::F32))
5380            }
5381            Float16x2 => {
5382                let name = self.namer.call("unpackFloat16x2");
5383                writeln!(
5384                    self.out,
5385                    "metal::float2 {name}(metal::ushort b0, \
5386                                          metal::ushort b1, \
5387                                          metal::ushort b2, \
5388                                          metal::ushort b3) {{"
5389                )?;
5390                writeln!(
5391                    self.out,
5392                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5393                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5394                    back::INDENT
5395                )?;
5396                writeln!(self.out, "}}")?;
5397                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5398            }
5399            Float16x4 => {
5400                let name = self.namer.call("unpackFloat16x4");
5401                writeln!(
5402                    self.out,
5403                    "metal::float4 {name}(metal::ushort b0, \
5404                                        metal::ushort b1, \
5405                                        metal::ushort b2, \
5406                                        metal::ushort b3, \
5407                                        metal::ushort b4, \
5408                                        metal::ushort b5, \
5409                                        metal::ushort b6, \
5410                                        metal::ushort b7) {{"
5411                )?;
5412                writeln!(
5413                    self.out,
5414                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5415                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5416                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5417                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5418                    back::INDENT
5419                )?;
5420                writeln!(self.out, "}}")?;
5421                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5422            }
5423            Float32 => {
5424                let name = self.namer.call("unpackFloat32");
5425                writeln!(
5426                    self.out,
5427                    "float {name}(uint b0, \
5428                                  uint b1, \
5429                                  uint b2, \
5430                                  uint b3) {{"
5431                )?;
5432                writeln!(
5433                    self.out,
5434                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5435                    back::INDENT
5436                )?;
5437                writeln!(self.out, "}}")?;
5438                Ok((name, 4, None, Scalar::F32))
5439            }
5440            Float32x2 => {
5441                let name = self.namer.call("unpackFloat32x2");
5442                writeln!(
5443                    self.out,
5444                    "metal::float2 {name}(uint b0, \
5445                                          uint b1, \
5446                                          uint b2, \
5447                                          uint b3, \
5448                                          uint b4, \
5449                                          uint b5, \
5450                                          uint b6, \
5451                                          uint b7) {{"
5452                )?;
5453                writeln!(
5454                    self.out,
5455                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5456                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5457                    back::INDENT
5458                )?;
5459                writeln!(self.out, "}}")?;
5460                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5461            }
5462            Float32x3 => {
5463                let name = self.namer.call("unpackFloat32x3");
5464                writeln!(
5465                    self.out,
5466                    "metal::float3 {name}(uint b0, \
5467                                          uint b1, \
5468                                          uint b2, \
5469                                          uint b3, \
5470                                          uint b4, \
5471                                          uint b5, \
5472                                          uint b6, \
5473                                          uint b7, \
5474                                          uint b8, \
5475                                          uint b9, \
5476                                          uint b10, \
5477                                          uint b11) {{"
5478                )?;
5479                writeln!(
5480                    self.out,
5481                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5482                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5483                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5484                    back::INDENT
5485                )?;
5486                writeln!(self.out, "}}")?;
5487                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5488            }
5489            Float32x4 => {
5490                let name = self.namer.call("unpackFloat32x4");
5491                writeln!(
5492                    self.out,
5493                    "metal::float4 {name}(uint b0, \
5494                                          uint b1, \
5495                                          uint b2, \
5496                                          uint b3, \
5497                                          uint b4, \
5498                                          uint b5, \
5499                                          uint b6, \
5500                                          uint b7, \
5501                                          uint b8, \
5502                                          uint b9, \
5503                                          uint b10, \
5504                                          uint b11, \
5505                                          uint b12, \
5506                                          uint b13, \
5507                                          uint b14, \
5508                                          uint b15) {{"
5509                )?;
5510                writeln!(
5511                    self.out,
5512                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5513                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5514                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5515                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5516                    back::INDENT
5517                )?;
5518                writeln!(self.out, "}}")?;
5519                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5520            }
5521            Uint32 => {
5522                let name = self.namer.call("unpackUint32");
5523                writeln!(
5524                    self.out,
5525                    "uint {name}(uint b0, \
5526                                 uint b1, \
5527                                 uint b2, \
5528                                 uint b3) {{"
5529                )?;
5530                writeln!(
5531                    self.out,
5532                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5533                    back::INDENT
5534                )?;
5535                writeln!(self.out, "}}")?;
5536                Ok((name, 4, None, Scalar::U32))
5537            }
5538            Uint32x2 => {
5539                let name = self.namer.call("unpackUint32x2");
5540                writeln!(
5541                    self.out,
5542                    "uint2 {name}(uint b0, \
5543                                  uint b1, \
5544                                  uint b2, \
5545                                  uint b3, \
5546                                  uint b4, \
5547                                  uint b5, \
5548                                  uint b6, \
5549                                  uint b7) {{"
5550                )?;
5551                writeln!(
5552                    self.out,
5553                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5554                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5555                    back::INDENT
5556                )?;
5557                writeln!(self.out, "}}")?;
5558                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5559            }
5560            Uint32x3 => {
5561                let name = self.namer.call("unpackUint32x3");
5562                writeln!(
5563                    self.out,
5564                    "uint3 {name}(uint b0, \
5565                                  uint b1, \
5566                                  uint b2, \
5567                                  uint b3, \
5568                                  uint b4, \
5569                                  uint b5, \
5570                                  uint b6, \
5571                                  uint b7, \
5572                                  uint b8, \
5573                                  uint b9, \
5574                                  uint b10, \
5575                                  uint b11) {{"
5576                )?;
5577                writeln!(
5578                    self.out,
5579                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5580                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5581                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5582                    back::INDENT
5583                )?;
5584                writeln!(self.out, "}}")?;
5585                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5586            }
5587            Uint32x4 => {
5588                let name = self.namer.call("unpackUint32x4");
5589                writeln!(
5590                    self.out,
5591                    "{NAMESPACE}::uint4 {name}(uint b0, \
5592                                  uint b1, \
5593                                  uint b2, \
5594                                  uint b3, \
5595                                  uint b4, \
5596                                  uint b5, \
5597                                  uint b6, \
5598                                  uint b7, \
5599                                  uint b8, \
5600                                  uint b9, \
5601                                  uint b10, \
5602                                  uint b11, \
5603                                  uint b12, \
5604                                  uint b13, \
5605                                  uint b14, \
5606                                  uint b15) {{"
5607                )?;
5608                writeln!(
5609                    self.out,
5610                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5611                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5612                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5613                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5614                    back::INDENT
5615                )?;
5616                writeln!(self.out, "}}")?;
5617                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5618            }
5619            Sint32 => {
5620                let name = self.namer.call("unpackSint32");
5621                writeln!(
5622                    self.out,
5623                    "int {name}(uint b0, \
5624                                uint b1, \
5625                                uint b2, \
5626                                uint b3) {{"
5627                )?;
5628                writeln!(
5629                    self.out,
5630                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5631                    back::INDENT
5632                )?;
5633                writeln!(self.out, "}}")?;
5634                Ok((name, 4, None, Scalar::I32))
5635            }
5636            Sint32x2 => {
5637                let name = self.namer.call("unpackSint32x2");
5638                writeln!(
5639                    self.out,
5640                    "metal::int2 {name}(uint b0, \
5641                                        uint b1, \
5642                                        uint b2, \
5643                                        uint b3, \
5644                                        uint b4, \
5645                                        uint b5, \
5646                                        uint b6, \
5647                                        uint b7) {{"
5648                )?;
5649                writeln!(
5650                    self.out,
5651                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5652                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5653                    back::INDENT
5654                )?;
5655                writeln!(self.out, "}}")?;
5656                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5657            }
5658            Sint32x3 => {
5659                let name = self.namer.call("unpackSint32x3");
5660                writeln!(
5661                    self.out,
5662                    "metal::int3 {name}(uint b0, \
5663                                        uint b1, \
5664                                        uint b2, \
5665                                        uint b3, \
5666                                        uint b4, \
5667                                        uint b5, \
5668                                        uint b6, \
5669                                        uint b7, \
5670                                        uint b8, \
5671                                        uint b9, \
5672                                        uint b10, \
5673                                        uint b11) {{"
5674                )?;
5675                writeln!(
5676                    self.out,
5677                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5678                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5679                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5680                    back::INDENT
5681                )?;
5682                writeln!(self.out, "}}")?;
5683                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5684            }
5685            Sint32x4 => {
5686                let name = self.namer.call("unpackSint32x4");
5687                writeln!(
5688                    self.out,
5689                    "metal::int4 {name}(uint b0, \
5690                                        uint b1, \
5691                                        uint b2, \
5692                                        uint b3, \
5693                                        uint b4, \
5694                                        uint b5, \
5695                                        uint b6, \
5696                                        uint b7, \
5697                                        uint b8, \
5698                                        uint b9, \
5699                                        uint b10, \
5700                                        uint b11, \
5701                                        uint b12, \
5702                                        uint b13, \
5703                                        uint b14, \
5704                                        uint b15) {{"
5705                )?;
5706                writeln!(
5707                    self.out,
5708                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5709                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5710                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5711                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5712                    back::INDENT
5713                )?;
5714                writeln!(self.out, "}}")?;
5715                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5716            }
5717            Unorm10_10_10_2 => {
5718                let name = self.namer.call("unpackUnorm10_10_10_2");
5719                writeln!(
5720                    self.out,
5721                    "metal::float4 {name}(uint b0, \
5722                                          uint b1, \
5723                                          uint b2, \
5724                                          uint b3) {{"
5725                )?;
5726                writeln!(
5727                    self.out,
5728                    // The following is correct for RGBA packing, but our format seems to
5729                    // match ABGR, which can be fed into the Metal builtin function
5730                    // unpack_unorm10a2_to_float.
5731                    /*
5732                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5733                       uint r = (v & 0xFFC00000) >> 22; \
5734                       uint g = (v & 0x003FF000) >> 12; \
5735                       uint b = (v & 0x00000FFC) >> 2; \
5736                       uint a = (v & 0x00000003); \
5737                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5738                    */
5739                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5740                    back::INDENT
5741                )?;
5742                writeln!(self.out, "}}")?;
5743                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5744            }
5745            Unorm8x4Bgra => {
5746                let name = self.namer.call("unpackUnorm8x4Bgra");
5747                writeln!(
5748                    self.out,
5749                    "metal::float4 {name}(metal::uchar b0, \
5750                                          metal::uchar b1, \
5751                                          metal::uchar b2, \
5752                                          metal::uchar b3) {{"
5753                )?;
5754                writeln!(
5755                    self.out,
5756                    "{}return metal::float4(float(b2) / 255.0f, \
5757                                            float(b1) / 255.0f, \
5758                                            float(b0) / 255.0f, \
5759                                            float(b3) / 255.0f);",
5760                    back::INDENT
5761                )?;
5762                writeln!(self.out, "}}")?;
5763                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5764            }
5765        }
5766    }
5767
5768    fn write_wrapped_unary_op(
5769        &mut self,
5770        module: &crate::Module,
5771        func_ctx: &back::FunctionCtx,
5772        op: crate::UnaryOperator,
5773        operand: Handle<crate::Expression>,
5774    ) -> BackendResult {
5775        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5776        match op {
5777            // Negating the TYPE_MIN of a two's complement signed integer
5778            // type causes overflow, which is undefined behaviour in MSL. To
5779            // avoid this we bitcast the value to unsigned and negate it,
5780            // then bitcast back to signed.
5781            // This adheres to the WGSL spec in that the negative of the
5782            // type's minimum value should equal to the minimum value.
5783            crate::UnaryOperator::Negate
5784                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5785            {
5786                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5787                    return Ok(());
5788                };
5789                let wrapped = WrappedFunction::UnaryOp {
5790                    op,
5791                    ty: (vector_size, scalar),
5792                };
5793                if !self.wrapped_functions.insert(wrapped) {
5794                    return Ok(());
5795                }
5796
5797                let unsigned_scalar = crate::Scalar {
5798                    kind: crate::ScalarKind::Uint,
5799                    ..scalar
5800                };
5801                let mut type_name = String::new();
5802                let mut unsigned_type_name = String::new();
5803                match vector_size {
5804                    None => {
5805                        put_numeric_type(&mut type_name, scalar, &[])?;
5806                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5807                    }
5808                    Some(size) => {
5809                        put_numeric_type(&mut type_name, scalar, &[size])?;
5810                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5811                    }
5812                };
5813
5814                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5815                let level = back::Level(1);
5816                writeln!(
5817                    self.out,
5818                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5819                )?;
5820                writeln!(self.out, "}}")?;
5821                writeln!(self.out)?;
5822            }
5823            _ => {}
5824        }
5825        Ok(())
5826    }
5827
5828    fn write_wrapped_binary_op(
5829        &mut self,
5830        module: &crate::Module,
5831        func_ctx: &back::FunctionCtx,
5832        expr: Handle<crate::Expression>,
5833        op: crate::BinaryOperator,
5834        left: Handle<crate::Expression>,
5835        right: Handle<crate::Expression>,
5836    ) -> BackendResult {
5837        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5838        let left_ty = func_ctx.resolve_type(left, &module.types);
5839        let right_ty = func_ctx.resolve_type(right, &module.types);
5840        match (op, expr_ty.scalar_kind()) {
5841            // Signed integer division of TYPE_MIN / -1, or signed or
5842            // unsigned division by zero, gives an unspecified value in MSL.
5843            // We override the divisor to 1 in these cases.
5844            // This adheres to the WGSL spec in that:
5845            // * TYPE_MIN / -1 == TYPE_MIN
5846            // * x / 0 == x
5847            (
5848                crate::BinaryOperator::Divide,
5849                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5850            ) => {
5851                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5852                    return Ok(());
5853                };
5854                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5855                    return Ok(());
5856                };
5857                let wrapped = WrappedFunction::BinaryOp {
5858                    op,
5859                    left_ty: left_wrapped_ty,
5860                    right_ty: right_wrapped_ty,
5861                };
5862                if !self.wrapped_functions.insert(wrapped) {
5863                    return Ok(());
5864                }
5865
5866                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5867                    return Ok(());
5868                };
5869                let mut type_name = String::new();
5870                match vector_size {
5871                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5872                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5873                };
5874                writeln!(
5875                    self.out,
5876                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5877                )?;
5878                let level = back::Level(1);
5879                match scalar.kind {
5880                    crate::ScalarKind::Sint => {
5881                        let min_val = match scalar.width {
5882                            4 => crate::Literal::I32(i32::MIN),
5883                            8 => crate::Literal::I64(i64::MIN),
5884                            _ => {
5885                                return Err(Error::GenericValidation(format!(
5886                                    "Unexpected width for scalar {scalar:?}"
5887                                )));
5888                            }
5889                        };
5890                        write!(
5891                            self.out,
5892                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5893                        )?;
5894                        self.put_literal(min_val)?;
5895                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5896                    }
5897                    crate::ScalarKind::Uint => writeln!(
5898                        self.out,
5899                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5900                    )?,
5901                    _ => unreachable!(),
5902                }
5903                writeln!(self.out, "}}")?;
5904                writeln!(self.out)?;
5905            }
5906            // Integer modulo where one or both operands are negative, or the
5907            // divisor is zero, is undefined behaviour in MSL. To avoid this
5908            // we use the following equation:
5909            //
5910            // dividend - (dividend / divisor) * divisor
5911            //
5912            // overriding the divisor to 1 if either it is 0, or it is -1
5913            // and the dividend is TYPE_MIN.
5914            //
5915            // This adheres to the WGSL spec in that:
5916            // * TYPE_MIN % -1 == 0
5917            // * x % 0 == 0
5918            (
5919                crate::BinaryOperator::Modulo,
5920                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5921            ) => {
5922                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5923                    return Ok(());
5924                };
5925                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5926                else {
5927                    return Ok(());
5928                };
5929                let wrapped = WrappedFunction::BinaryOp {
5930                    op,
5931                    left_ty: left_wrapped_ty,
5932                    right_ty: (right_vector_size, right_scalar),
5933                };
5934                if !self.wrapped_functions.insert(wrapped) {
5935                    return Ok(());
5936                }
5937
5938                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5939                    return Ok(());
5940                };
5941                let mut type_name = String::new();
5942                match vector_size {
5943                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5944                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5945                };
5946                let mut rhs_type_name = String::new();
5947                match right_vector_size {
5948                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5949                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5950                };
5951
5952                writeln!(
5953                    self.out,
5954                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5955                )?;
5956                let level = back::Level(1);
5957                match scalar.kind {
5958                    crate::ScalarKind::Sint => {
5959                        let min_val = match scalar.width {
5960                            4 => crate::Literal::I32(i32::MIN),
5961                            8 => crate::Literal::I64(i64::MIN),
5962                            _ => {
5963                                return Err(Error::GenericValidation(format!(
5964                                    "Unexpected width for scalar {scalar:?}"
5965                                )));
5966                            }
5967                        };
5968                        write!(
5969                            self.out,
5970                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5971                        )?;
5972                        self.put_literal(min_val)?;
5973                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5974                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5975                    }
5976                    crate::ScalarKind::Uint => writeln!(
5977                        self.out,
5978                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5979                    )?,
5980                    _ => unreachable!(),
5981                }
5982                writeln!(self.out, "}}")?;
5983                writeln!(self.out)?;
5984            }
5985            _ => {}
5986        }
5987        Ok(())
5988    }
5989
5990    /// Build the mangled helper name for integer vector dot products.
5991    ///
5992    /// `scalar` must be a concrete integer scalar type.
5993    ///
5994    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5995    fn get_dot_wrapper_function_helper_name(
5996        &self,
5997        scalar: crate::Scalar,
5998        size: crate::VectorSize,
5999    ) -> String {
6000        // Check for consistency with [`super::keywords::RESERVED_SET`]
6001        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
6002
6003        let type_name = scalar.to_msl_name();
6004        let size_suffix = common::vector_size_str(size);
6005        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
6006    }
6007
6008    #[allow(clippy::too_many_arguments)]
6009    fn write_wrapped_math_function(
6010        &mut self,
6011        module: &crate::Module,
6012        func_ctx: &back::FunctionCtx,
6013        fun: crate::MathFunction,
6014        arg: Handle<crate::Expression>,
6015        _arg1: Option<Handle<crate::Expression>>,
6016        _arg2: Option<Handle<crate::Expression>>,
6017        _arg3: Option<Handle<crate::Expression>>,
6018    ) -> BackendResult {
6019        let arg_ty = func_ctx.resolve_type(arg, &module.types);
6020        match fun {
6021            // Taking the absolute value of the TYPE_MIN of a two's
6022            // complement signed integer type causes overflow, which is
6023            // undefined behaviour in MSL. To avoid this, when the value is
6024            // negative we bitcast the value to unsigned and negate it, then
6025            // bitcast back to signed.
6026            // This adheres to the WGSL spec in that the absolute of the
6027            // type's minimum value should equal to the minimum value.
6028            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
6029                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
6030                    return Ok(());
6031                };
6032                let wrapped = WrappedFunction::Math {
6033                    fun,
6034                    arg_ty: (vector_size, scalar),
6035                };
6036                if !self.wrapped_functions.insert(wrapped) {
6037                    return Ok(());
6038                }
6039
6040                let unsigned_scalar = crate::Scalar {
6041                    kind: crate::ScalarKind::Uint,
6042                    ..scalar
6043                };
6044                let mut type_name = String::new();
6045                let mut unsigned_type_name = String::new();
6046                match vector_size {
6047                    None => {
6048                        put_numeric_type(&mut type_name, scalar, &[])?;
6049                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
6050                    }
6051                    Some(size) => {
6052                        put_numeric_type(&mut type_name, scalar, &[size])?;
6053                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
6054                    }
6055                };
6056
6057                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
6058                let level = back::Level(1);
6059                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
6060                writeln!(self.out, "}}")?;
6061                writeln!(self.out)?;
6062            }
6063
6064            crate::MathFunction::Dot => match *arg_ty {
6065                crate::TypeInner::Vector { size, scalar }
6066                    if matches!(
6067                        scalar.kind,
6068                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
6069                    ) =>
6070                {
6071                    // De-duplicate per (fun, arg type) like other wrapped math functions
6072                    let wrapped = WrappedFunction::Math {
6073                        fun,
6074                        arg_ty: (Some(size), scalar),
6075                    };
6076                    if !self.wrapped_functions.insert(wrapped) {
6077                        return Ok(());
6078                    }
6079
6080                    let mut vec_ty = String::new();
6081                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
6082                    let mut ret_ty = String::new();
6083                    put_numeric_type(&mut ret_ty, scalar, &[])?;
6084
6085                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6086
6087                    // Emit function signature and body using put_dot_product for the expression
6088                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6089                    let level = back::Level(1);
6090                    write!(self.out, "{level}return ")?;
6091                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6092                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6093                        Ok(())
6094                    })?;
6095                    writeln!(self.out, ";")?;
6096                    writeln!(self.out, "}}")?;
6097                    writeln!(self.out)?;
6098                }
6099                _ => {}
6100            },
6101
6102            _ => {}
6103        }
6104        Ok(())
6105    }
6106
6107    fn write_wrapped_cast(
6108        &mut self,
6109        module: &crate::Module,
6110        func_ctx: &back::FunctionCtx,
6111        expr: Handle<crate::Expression>,
6112        kind: crate::ScalarKind,
6113        convert: Option<crate::Bytes>,
6114    ) -> BackendResult {
6115        // Avoid undefined behaviour when casting from a float to integer
6116        // when the value is out of range for the target type. Additionally
6117        // ensure we clamp to the correct value as per the WGSL spec.
6118        //
6119        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6120        // * If X is exactly representable in the target type T, then the
6121        //   result is that value.
6122        // * Otherwise, the result is the value in T closest to
6123        //   truncate(X) and also exactly representable in the original
6124        //   floating point type.
6125        let src_ty = func_ctx.resolve_type(expr, &module.types);
6126        let Some(width) = convert else {
6127            return Ok(());
6128        };
6129        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6130            return Ok(());
6131        };
6132        let dst_scalar = crate::Scalar { kind, width };
6133        if src_scalar.kind != crate::ScalarKind::Float
6134            || (dst_scalar.kind != crate::ScalarKind::Sint
6135                && dst_scalar.kind != crate::ScalarKind::Uint)
6136        {
6137            return Ok(());
6138        }
6139        let wrapped = WrappedFunction::Cast {
6140            src_scalar,
6141            vector_size,
6142            dst_scalar,
6143        };
6144        if !self.wrapped_functions.insert(wrapped) {
6145            return Ok(());
6146        }
6147        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6148
6149        let mut src_type_name = String::new();
6150        match vector_size {
6151            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6152            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6153        };
6154        let mut dst_type_name = String::new();
6155        match vector_size {
6156            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6157            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6158        };
6159        let fun_name = match dst_scalar {
6160            crate::Scalar::I32 => F2I32_FUNCTION,
6161            crate::Scalar::U32 => F2U32_FUNCTION,
6162            crate::Scalar::I64 => F2I64_FUNCTION,
6163            crate::Scalar::U64 => F2U64_FUNCTION,
6164            _ => unreachable!(),
6165        };
6166
6167        writeln!(
6168            self.out,
6169            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6170        )?;
6171        let level = back::Level(1);
6172        write!(
6173            self.out,
6174            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6175        )?;
6176        self.put_literal(min)?;
6177        write!(self.out, ", ")?;
6178        self.put_literal(max)?;
6179        writeln!(self.out, "));")?;
6180        writeln!(self.out, "}}")?;
6181        writeln!(self.out)?;
6182        Ok(())
6183    }
6184
6185    /// Helper function used by [`Self::write_wrapped_image_load`] and
6186    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6187    /// conversion code for external textures. Expects the preceding code to
6188    /// declare the Y component as a `float` variable of name `y`, the UV
6189    /// components as a `float2` variable of name `uv`, and the external
6190    /// texture params as a variable of name `params`. The emitted code will
6191    /// return the result.
6192    fn write_convert_yuv_to_rgb_and_return(
6193        &mut self,
6194        level: back::Level,
6195        y: &str,
6196        uv: &str,
6197        params: &str,
6198    ) -> BackendResult {
6199        let l1 = level;
6200        let l2 = l1.next();
6201
6202        // Convert from YUV to non-linear RGB in the source color space.
6203        writeln!(
6204            self.out,
6205            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6206        )?;
6207
6208        // Apply the inverse of the source transfer function to convert to
6209        // linear RGB in the source color space.
6210        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6211        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6212        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6213        writeln!(
6214            self.out,
6215            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6216        )?;
6217
6218        // Multiply by the gamut conversion matrix to convert to linear RGB in
6219        // the destination color space.
6220        writeln!(
6221            self.out,
6222            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6223        )?;
6224
6225        // Finally, apply the dest transfer function to convert to non-linear
6226        // RGB in the destination color space, and return the result.
6227        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6228        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6229        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6230        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6231
6232        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6233        Ok(())
6234    }
6235
6236    #[allow(clippy::too_many_arguments)]
6237    fn write_wrapped_image_load(
6238        &mut self,
6239        module: &crate::Module,
6240        func_ctx: &back::FunctionCtx,
6241        image: Handle<crate::Expression>,
6242        _coordinate: Handle<crate::Expression>,
6243        _array_index: Option<Handle<crate::Expression>>,
6244        _sample: Option<Handle<crate::Expression>>,
6245        _level: Option<Handle<crate::Expression>>,
6246    ) -> BackendResult {
6247        // We currently only need to wrap image loads for external textures
6248        let class = match *func_ctx.resolve_type(image, &module.types) {
6249            crate::TypeInner::Image { class, .. } => class,
6250            _ => unreachable!(),
6251        };
6252        if class != crate::ImageClass::External {
6253            return Ok(());
6254        }
6255        let wrapped = WrappedFunction::ImageLoad { class };
6256        if !self.wrapped_functions.insert(wrapped) {
6257            return Ok(());
6258        }
6259
6260        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6261        let l1 = back::Level(1);
6262        let l2 = l1.next();
6263        let l3 = l2.next();
6264        writeln!(
6265            self.out,
6266            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6267        )?;
6268        // Clamp coords to provided size of external texture to prevent OOB
6269        // read. If params.size is zero then clamp to the actual size of the
6270        // texture.
6271        writeln!(
6272            self.out,
6273            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6274        )?;
6275        writeln!(
6276            self.out,
6277            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6278        )?;
6279
6280        // Apply load transformation
6281        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6282        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6283        // For single plane, simply read from plane0
6284        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6285        writeln!(self.out, "{l1}}} else {{")?;
6286
6287        // Chroma planes may be subsampled so we must scale the coords accordingly.
6288        writeln!(
6289            self.out,
6290            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6291        )?;
6292        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6293
6294        // For multi-plane, read the Y value from plane 0
6295        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6296
6297        writeln!(self.out, "{l2}float2 uv;")?;
6298        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6299        // For 2 planes, read UV from interleaved plane 1
6300        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6301        writeln!(self.out, "{l2}}} else {{")?;
6302        // For 3 planes, read U and V from planes 1 and 2 respectively
6303        writeln!(
6304            self.out,
6305            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6306        )?;
6307        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6308        writeln!(
6309            self.out,
6310            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6311        )?;
6312        writeln!(self.out, "{l2}}}")?;
6313
6314        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6315
6316        writeln!(self.out, "{l1}}}")?;
6317        writeln!(self.out, "}}")?;
6318        writeln!(self.out)?;
6319        Ok(())
6320    }
6321
6322    #[allow(clippy::too_many_arguments)]
6323    fn write_wrapped_image_sample(
6324        &mut self,
6325        module: &crate::Module,
6326        func_ctx: &back::FunctionCtx,
6327        image: Handle<crate::Expression>,
6328        _sampler: Handle<crate::Expression>,
6329        _gather: Option<crate::SwizzleComponent>,
6330        _coordinate: Handle<crate::Expression>,
6331        _array_index: Option<Handle<crate::Expression>>,
6332        _offset: Option<Handle<crate::Expression>>,
6333        _level: crate::SampleLevel,
6334        _depth_ref: Option<Handle<crate::Expression>>,
6335        clamp_to_edge: bool,
6336    ) -> BackendResult {
6337        // We currently only need to wrap textureSampleBaseClampToEdge, for
6338        // both sampled and external textures.
6339        if !clamp_to_edge {
6340            return Ok(());
6341        }
6342        let class = match *func_ctx.resolve_type(image, &module.types) {
6343            crate::TypeInner::Image { class, .. } => class,
6344            _ => unreachable!(),
6345        };
6346        let wrapped = WrappedFunction::ImageSample {
6347            class,
6348            clamp_to_edge: true,
6349        };
6350        if !self.wrapped_functions.insert(wrapped) {
6351            return Ok(());
6352        }
6353        match class {
6354            crate::ImageClass::External => {
6355                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6356                let l1 = back::Level(1);
6357                let l2 = l1.next();
6358                let l3 = l2.next();
6359                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6360                writeln!(
6361                    self.out,
6362                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6363                )?;
6364
6365                // Calculate the sample bounds. The purported size of the texture
6366                // (params.size) is irrelevant here as we are dealing with normalized
6367                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6368                // apply the sample transformation to that, also bearing in mind that it
6369                // may contain a flip on either axis. We calculate and adjust for the
6370                // half-texel separately for each plane as it depends on the actual
6371                // texture size which may vary between planes.
6372                writeln!(
6373                    self.out,
6374                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6375                )?;
6376                writeln!(
6377                    self.out,
6378                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6379                )?;
6380                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6381                writeln!(
6382                    self.out,
6383                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6384                )?;
6385                writeln!(
6386                    self.out,
6387                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6388                )?;
6389                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6390                // For single plane, simply sample from plane0
6391                writeln!(
6392                    self.out,
6393                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6394                )?;
6395                writeln!(self.out, "{l1}}} else {{")?;
6396                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6397                writeln!(
6398                    self.out,
6399                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6400                )?;
6401                writeln!(
6402                    self.out,
6403                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6404                )?;
6405
6406                // For multi-plane, sample the Y value from plane 0
6407                writeln!(
6408                    self.out,
6409                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6410                )?;
6411                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6412                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6413                // For 2 planes, sample UV from interleaved plane 1
6414                writeln!(
6415                    self.out,
6416                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6417                )?;
6418                writeln!(self.out, "{l2}}} else {{")?;
6419                // For 3 planes, sample U and V from planes 1 and 2 respectively
6420                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6421                writeln!(
6422                    self.out,
6423                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6424                )?;
6425                writeln!(
6426                    self.out,
6427                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6428                )?;
6429                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6430                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6431                writeln!(self.out, "{l2}}}")?;
6432
6433                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6434
6435                writeln!(self.out, "{l1}}}")?;
6436                writeln!(self.out, "}}")?;
6437                writeln!(self.out)?;
6438            }
6439            _ => {
6440                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6441                let l1 = back::Level(1);
6442                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6443                writeln!(
6444                    self.out,
6445                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6446                )?;
6447                writeln!(self.out, "}}")?;
6448                writeln!(self.out)?;
6449            }
6450        }
6451        Ok(())
6452    }
6453
6454    fn write_wrapped_image_query(
6455        &mut self,
6456        module: &crate::Module,
6457        func_ctx: &back::FunctionCtx,
6458        image: Handle<crate::Expression>,
6459        query: crate::ImageQuery,
6460    ) -> BackendResult {
6461        // We currently only need to wrap size image queries for external textures
6462        if !matches!(query, crate::ImageQuery::Size { .. }) {
6463            return Ok(());
6464        }
6465        let class = match *func_ctx.resolve_type(image, &module.types) {
6466            crate::TypeInner::Image { class, .. } => class,
6467            _ => unreachable!(),
6468        };
6469        if class != crate::ImageClass::External {
6470            return Ok(());
6471        }
6472        let wrapped = WrappedFunction::ImageQuerySize { class };
6473        if !self.wrapped_functions.insert(wrapped) {
6474            return Ok(());
6475        }
6476        writeln!(
6477            self.out,
6478            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6479        )?;
6480        let l1 = back::Level(1);
6481        let l2 = l1.next();
6482        writeln!(
6483            self.out,
6484            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6485        )?;
6486        writeln!(self.out, "{l2}return tex.params.size;")?;
6487        writeln!(self.out, "{l1}}} else {{")?;
6488        // params.size == (0, 0) indicates to query and return plane 0's actual size
6489        writeln!(
6490            self.out,
6491            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6492        )?;
6493        writeln!(self.out, "{l1}}}")?;
6494        writeln!(self.out, "}}")?;
6495        writeln!(self.out)?;
6496        Ok(())
6497    }
6498
6499    fn write_wrapped_cooperative_load(
6500        &mut self,
6501        module: &crate::Module,
6502        func_ctx: &back::FunctionCtx,
6503        columns: crate::CooperativeSize,
6504        rows: crate::CooperativeSize,
6505        pointer: Handle<crate::Expression>,
6506    ) -> BackendResult {
6507        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6508        let space = ptr_ty.pointer_space().unwrap();
6509        let space_name = space.to_msl_name().unwrap_or_default();
6510        let scalar = ptr_ty
6511            .pointer_base_type()
6512            .unwrap()
6513            .inner_with(&module.types)
6514            .scalar()
6515            .unwrap();
6516        let wrapped = WrappedFunction::CooperativeLoad {
6517            space_name,
6518            columns,
6519            rows,
6520            scalar,
6521        };
6522        if !self.wrapped_functions.insert(wrapped) {
6523            return Ok(());
6524        }
6525        let scalar_name = scalar.to_msl_name();
6526        writeln!(
6527            self.out,
6528            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6529            columns as u32, rows as u32,
6530        )?;
6531        let l1 = back::Level(1);
6532        writeln!(
6533            self.out,
6534            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6535            columns as u32, rows as u32
6536        )?;
6537        let matrix_origin = "0";
6538        writeln!(
6539            self.out,
6540            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6541        )?;
6542        writeln!(self.out, "{l1}return m;")?;
6543        writeln!(self.out, "}}")?;
6544        writeln!(self.out)?;
6545        Ok(())
6546    }
6547
6548    fn write_wrapped_cooperative_multiply_add(
6549        &mut self,
6550        module: &crate::Module,
6551        func_ctx: &back::FunctionCtx,
6552        space: crate::AddressSpace,
6553        a: Handle<crate::Expression>,
6554        b: Handle<crate::Expression>,
6555    ) -> BackendResult {
6556        let space_name = space.to_msl_name().unwrap_or_default();
6557        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6558            crate::TypeInner::CooperativeMatrix {
6559                columns,
6560                rows,
6561                scalar,
6562                ..
6563            } => (columns, rows, scalar),
6564            _ => unreachable!(),
6565        };
6566        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6567            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6568            _ => unreachable!(),
6569        };
6570        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6571            space_name,
6572            columns: b_c,
6573            rows: a_r,
6574            intermediate: a_c,
6575            scalar,
6576        };
6577        if !self.wrapped_functions.insert(wrapped) {
6578            return Ok(());
6579        }
6580        let scalar_name = scalar.to_msl_name();
6581        writeln!(
6582            self.out,
6583            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6584            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6585        )?;
6586        let l1 = back::Level(1);
6587        writeln!(
6588            self.out,
6589            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6590            b_c as u32, a_r as u32
6591        )?;
6592        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6593        writeln!(self.out, "{l1}return d;")?;
6594        writeln!(self.out, "}}")?;
6595        writeln!(self.out)?;
6596        Ok(())
6597    }
6598
6599    pub(super) fn write_wrapped_functions(
6600        &mut self,
6601        module: &crate::Module,
6602        func_ctx: &back::FunctionCtx,
6603    ) -> BackendResult {
6604        for (expr_handle, expr) in func_ctx.expressions.iter() {
6605            match *expr {
6606                crate::Expression::Unary { op, expr: operand } => {
6607                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6608                }
6609                crate::Expression::Binary { op, left, right } => {
6610                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6611                }
6612                crate::Expression::Math {
6613                    fun,
6614                    arg,
6615                    arg1,
6616                    arg2,
6617                    arg3,
6618                } => {
6619                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6620                }
6621                crate::Expression::As {
6622                    expr,
6623                    kind,
6624                    convert,
6625                } => {
6626                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6627                }
6628                crate::Expression::ImageLoad {
6629                    image,
6630                    coordinate,
6631                    array_index,
6632                    sample,
6633                    level,
6634                } => {
6635                    self.write_wrapped_image_load(
6636                        module,
6637                        func_ctx,
6638                        image,
6639                        coordinate,
6640                        array_index,
6641                        sample,
6642                        level,
6643                    )?;
6644                }
6645                crate::Expression::ImageSample {
6646                    image,
6647                    sampler,
6648                    gather,
6649                    coordinate,
6650                    array_index,
6651                    offset,
6652                    level,
6653                    depth_ref,
6654                    clamp_to_edge,
6655                } => {
6656                    self.write_wrapped_image_sample(
6657                        module,
6658                        func_ctx,
6659                        image,
6660                        sampler,
6661                        gather,
6662                        coordinate,
6663                        array_index,
6664                        offset,
6665                        level,
6666                        depth_ref,
6667                        clamp_to_edge,
6668                    )?;
6669                }
6670                crate::Expression::ImageQuery { image, query } => {
6671                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6672                }
6673                crate::Expression::CooperativeLoad {
6674                    columns,
6675                    rows,
6676                    role: _,
6677                    ref data,
6678                } => {
6679                    self.write_wrapped_cooperative_load(
6680                        module,
6681                        func_ctx,
6682                        columns,
6683                        rows,
6684                        data.pointer,
6685                    )?;
6686                }
6687                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6688                    let space = crate::AddressSpace::Private;
6689                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6690                }
6691                _ => {}
6692            }
6693        }
6694
6695        Ok(())
6696    }
6697
6698    // Returns the array of mapped entry point names.
6699    fn write_functions(
6700        &mut self,
6701        module: &crate::Module,
6702        mod_info: &valid::ModuleInfo,
6703        options: &Options,
6704        pipeline_options: &PipelineOptions,
6705    ) -> Result<TranslationInfo, Error> {
6706        use back::msl::VertexFormat;
6707
6708        // Define structs to hold resolved/generated data for vertex buffers and
6709        // their attributes.
6710        struct AttributeMappingResolved {
6711            ty_name: String,
6712            dimension: Option<crate::VectorSize>,
6713            scalar: crate::Scalar,
6714            name: String,
6715        }
6716        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6717
6718        struct VertexBufferMappingResolved<'a> {
6719            id: u32,
6720            stride: u32,
6721            step_mode: back::msl::VertexBufferStepMode,
6722            ty_name: String,
6723            param_name: String,
6724            elem_name: String,
6725            attributes: &'a Vec<back::msl::AttributeMapping>,
6726        }
6727        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6728
6729        // Define a struct to hold a named reference to a byte-unpacking function.
6730        struct UnpackingFunction {
6731            name: String,
6732            byte_count: u32,
6733            dimension: Option<crate::VectorSize>,
6734            scalar: crate::Scalar,
6735        }
6736        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6737
6738        // Check if we are attempting vertex pulling. If we are, generate some
6739        // names we'll need, and iterate the vertex buffer mappings to output
6740        // all the conversion functions we'll need to unpack the attribute data.
6741        // We can re-use these names for all entry points that need them, since
6742        // those entry points also use self.namer.
6743        let mut needs_vertex_id = false;
6744        let v_id = self.namer.call("v_id");
6745
6746        let mut needs_instance_id = false;
6747        let i_id = self.namer.call("i_id");
6748        if pipeline_options.vertex_pulling_transform {
6749            for vbm in &pipeline_options.vertex_buffer_mappings {
6750                let buffer_id = vbm.id;
6751                let buffer_stride = vbm.stride;
6752
6753                assert!(
6754                    buffer_stride > 0,
6755                    "Vertex pulling requires a non-zero buffer stride."
6756                );
6757
6758                match vbm.step_mode {
6759                    back::msl::VertexBufferStepMode::Constant => {}
6760                    back::msl::VertexBufferStepMode::ByVertex => {
6761                        needs_vertex_id = true;
6762                    }
6763                    back::msl::VertexBufferStepMode::ByInstance => {
6764                        needs_instance_id = true;
6765                    }
6766                }
6767
6768                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6769                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6770                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6771
6772                vbm_resolved.push(VertexBufferMappingResolved {
6773                    id: buffer_id,
6774                    stride: buffer_stride,
6775                    step_mode: vbm.step_mode,
6776                    ty_name: buffer_ty,
6777                    param_name: buffer_param,
6778                    elem_name: buffer_elem,
6779                    attributes: &vbm.attributes,
6780                });
6781
6782                // Iterate the attributes and generate needed unpacking functions.
6783                for attribute in &vbm.attributes {
6784                    if unpacking_functions.contains_key(&attribute.format) {
6785                        continue;
6786                    }
6787                    let (name, byte_count, dimension, scalar) =
6788                        match self.write_unpacking_function(attribute.format) {
6789                            Ok((name, byte_count, dimension, scalar)) => {
6790                                (name, byte_count, dimension, scalar)
6791                            }
6792                            _ => {
6793                                continue;
6794                            }
6795                        };
6796                    unpacking_functions.insert(
6797                        attribute.format,
6798                        UnpackingFunction {
6799                            name,
6800                            byte_count,
6801                            dimension,
6802                            scalar,
6803                        },
6804                    );
6805                }
6806            }
6807        }
6808
6809        let mut pass_through_globals = Vec::new();
6810        for (fun_handle, fun) in module.functions.iter() {
6811            log::trace!(
6812                "function {:?}, handle {:?}",
6813                fun.name.as_deref().unwrap_or("(anonymous)"),
6814                fun_handle
6815            );
6816
6817            let ctx = back::FunctionCtx {
6818                ty: back::FunctionType::Function(fun_handle),
6819                info: &mod_info[fun_handle],
6820                expressions: &fun.expressions,
6821                named_expressions: &fun.named_expressions,
6822            };
6823
6824            writeln!(self.out)?;
6825            self.write_wrapped_functions(module, &ctx)?;
6826
6827            let fun_info = &mod_info[fun_handle];
6828            pass_through_globals.clear();
6829            let mut needs_buffer_sizes = false;
6830            for (handle, var) in module.global_variables.iter() {
6831                if !fun_info[handle].is_empty() {
6832                    if var.space.needs_pass_through() {
6833                        pass_through_globals.push(handle);
6834                    }
6835                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6836                }
6837            }
6838
6839            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6840            match fun.result {
6841                Some(ref result) => {
6842                    let ty_name = TypeContext {
6843                        handle: result.ty,
6844                        gctx: module.to_ctx(),
6845                        names: &self.names,
6846                        access: crate::StorageAccess::empty(),
6847                        first_time: false,
6848                    };
6849                    write!(self.out, "{ty_name}")?;
6850                }
6851                None => {
6852                    write!(self.out, "void")?;
6853                }
6854            }
6855            writeln!(self.out, " {fun_name}(")?;
6856
6857            for (index, arg) in fun.arguments.iter().enumerate() {
6858                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6859                let param_type_name = TypeContext {
6860                    handle: arg.ty,
6861                    gctx: module.to_ctx(),
6862                    names: &self.names,
6863                    access: crate::StorageAccess::empty(),
6864                    first_time: false,
6865                };
6866                let separator = separate(
6867                    !pass_through_globals.is_empty()
6868                        || index + 1 != fun.arguments.len()
6869                        || needs_buffer_sizes,
6870                );
6871                writeln!(
6872                    self.out,
6873                    "{}{} {}{}",
6874                    back::INDENT,
6875                    param_type_name,
6876                    name,
6877                    separator
6878                )?;
6879            }
6880            for (index, &handle) in pass_through_globals.iter().enumerate() {
6881                let tyvar = TypedGlobalVariable {
6882                    module,
6883                    names: &self.names,
6884                    handle,
6885                    usage: fun_info[handle],
6886                    reference: true,
6887                };
6888                let separator =
6889                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6890                write!(self.out, "{}", back::INDENT)?;
6891                tyvar.try_fmt(&mut self.out)?;
6892                writeln!(self.out, "{separator}")?;
6893            }
6894
6895            if needs_buffer_sizes {
6896                writeln!(
6897                    self.out,
6898                    "{}constant _mslBufferSizes& _buffer_sizes",
6899                    back::INDENT
6900                )?;
6901            }
6902
6903            writeln!(self.out, ") {{")?;
6904
6905            let guarded_indices =
6906                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6907
6908            let context = StatementContext {
6909                expression: ExpressionContext {
6910                    function: fun,
6911                    origin: FunctionOrigin::Handle(fun_handle),
6912                    info: fun_info,
6913                    lang_version: options.lang_version,
6914                    policies: options.bounds_check_policies,
6915                    guarded_indices,
6916                    module,
6917                    mod_info,
6918                    pipeline_options,
6919                    force_loop_bounding: options.force_loop_bounding,
6920                },
6921                result_struct: None,
6922            };
6923
6924            self.put_locals(&context.expression)?;
6925            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6926            self.put_block(back::Level(1), &fun.body, &context)?;
6927            writeln!(self.out, "}}")?;
6928            self.named_expressions.clear();
6929        }
6930
6931        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6932            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6933
6934        let mut info = TranslationInfo {
6935            entry_point_names: Vec::with_capacity(ep_range.len()),
6936        };
6937
6938        for ep_index in ep_range {
6939            let ep = &module.entry_points[ep_index];
6940            let fun = &ep.function;
6941            let fun_info = mod_info.get_entry_point(ep_index);
6942            let mut ep_error = None;
6943
6944            // For vertex_id and instance_id arguments, presume that we'll
6945            // use our generated names, but switch to the name of an
6946            // existing @builtin param, if we find one.
6947            let mut v_existing_id = None;
6948            let mut i_existing_id = None;
6949
6950            log::trace!(
6951                "entry point {:?}, index {:?}",
6952                fun.name.as_deref().unwrap_or("(anonymous)"),
6953                ep_index
6954            );
6955
6956            let ctx = back::FunctionCtx {
6957                ty: back::FunctionType::EntryPoint(ep_index as u16),
6958                info: fun_info,
6959                expressions: &fun.expressions,
6960                named_expressions: &fun.named_expressions,
6961            };
6962
6963            self.write_wrapped_functions(module, &ctx)?;
6964
6965            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6966                crate::ShaderStage::Vertex => (
6967                    Some("vertex"),
6968                    LocationMode::VertexInput,
6969                    LocationMode::VertexOutput,
6970                    true,
6971                ),
6972                crate::ShaderStage::Fragment => (
6973                    Some("fragment"),
6974                    LocationMode::FragmentInput,
6975                    LocationMode::FragmentOutput,
6976                    false,
6977                ),
6978                crate::ShaderStage::Compute => (
6979                    Some("kernel"),
6980                    LocationMode::Uniform,
6981                    LocationMode::Uniform,
6982                    false,
6983                ),
6984                crate::ShaderStage::Task => {
6985                    (None, LocationMode::Uniform, LocationMode::Uniform, false)
6986                }
6987                crate::ShaderStage::Mesh => {
6988                    (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6989                }
6990                crate::ShaderStage::RayGeneration
6991                | crate::ShaderStage::AnyHit
6992                | crate::ShaderStage::ClosestHit
6993                | crate::ShaderStage::Miss => unimplemented!(),
6994            };
6995
6996            // Should this entry point be modified to do vertex pulling?
6997            let do_vertex_pulling = can_vertex_pull
6998                && pipeline_options.vertex_pulling_transform
6999                && !pipeline_options.vertex_buffer_mappings.is_empty();
7000
7001            // Is any global variable used by this entry point dynamically sized?
7002            let needs_buffer_sizes = do_vertex_pulling
7003                || module
7004                    .global_variables
7005                    .iter()
7006                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
7007                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
7008
7009            // skip this entry point if any global bindings are missing,
7010            // or their types are incompatible.
7011            if !options.fake_missing_bindings {
7012                for (var_handle, var) in module.global_variables.iter() {
7013                    if fun_info[var_handle].is_empty() {
7014                        continue;
7015                    }
7016                    match var.space {
7017                        crate::AddressSpace::Uniform
7018                        | crate::AddressSpace::Storage { .. }
7019                        | crate::AddressSpace::Handle => {
7020                            let br = match var.binding {
7021                                Some(ref br) => br,
7022                                None => {
7023                                    let var_name = var.name.clone().unwrap_or_default();
7024                                    ep_error =
7025                                        Some(super::EntryPointError::MissingBinding(var_name));
7026                                    break;
7027                                }
7028                            };
7029                            let target = options.get_resource_binding_target(ep, br);
7030                            let good = match target {
7031                                Some(target) => {
7032                                    // We intentionally don't dereference binding_arrays here,
7033                                    // so that binding arrays fall to the buffer location.
7034
7035                                    match module.types[var.ty].inner {
7036                                        crate::TypeInner::Image {
7037                                            class: crate::ImageClass::External,
7038                                            ..
7039                                        } => target.external_texture.is_some(),
7040                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
7041                                        crate::TypeInner::Sampler { .. } => {
7042                                            target.sampler.is_some()
7043                                        }
7044                                        _ => target.buffer.is_some(),
7045                                    }
7046                                }
7047                                None => false,
7048                            };
7049                            if !good {
7050                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
7051                                break;
7052                            }
7053                        }
7054                        crate::AddressSpace::Immediate => {
7055                            if let Err(e) = options.resolve_immediates(ep) {
7056                                ep_error = Some(e);
7057                                break;
7058                            }
7059                        }
7060                        crate::AddressSpace::Function
7061                        | crate::AddressSpace::Private
7062                        | crate::AddressSpace::WorkGroup
7063                        | crate::AddressSpace::TaskPayload => {}
7064                        crate::AddressSpace::RayPayload
7065                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
7066                    }
7067                }
7068                if needs_buffer_sizes {
7069                    if let Err(err) = options.resolve_sizes_buffer(ep) {
7070                        ep_error = Some(err);
7071                    }
7072                }
7073            }
7074
7075            if let Some(err) = ep_error {
7076                info.entry_point_names.push(Err(err));
7077                continue;
7078            }
7079            let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7080            info.entry_point_names.push(Ok(fun_name.clone()));
7081
7082            writeln!(self.out)?;
7083
7084            // Since `Namer.reset` wasn't expecting struct members to be
7085            // suddenly injected into another namespace like this,
7086            // `self.names` doesn't keep them distinct from other variables.
7087            // Generate fresh names for these arguments, and remember the
7088            // mapping.
7089            let mut flattened_member_names = FastHashMap::default();
7090            // Varyings' members get their own namespace
7091            let mut varyings_namer = proc::Namer::default();
7092
7093            let mut empty_names = FastHashMap::default(); // Create a throwaway map
7094            varyings_namer.reset(
7095                module,
7096                &super::keywords::RESERVED_SET,
7097                proc::KeywordSet::empty(),
7098                proc::CaseInsensitiveKeywordSet::empty(),
7099                &[CLAMPED_LOD_LOAD_PREFIX],
7100                &mut empty_names,
7101            );
7102
7103            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7104            // flattening structs into their members. In Metal, we will pass
7105            // each of these values to the entry point as a separate argument—
7106            // except for the varyings, handled next.
7107            let mut flattened_arguments = Vec::new();
7108            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7109                match module.types[arg.ty].inner {
7110                    crate::TypeInner::Struct { ref members, .. } => {
7111                        for (member_index, member) in members.iter().enumerate() {
7112                            let member_index = member_index as u32;
7113                            flattened_arguments.push((
7114                                NameKey::StructMember(arg.ty, member_index),
7115                                member.ty,
7116                                member.binding.as_ref(),
7117                            ));
7118                            let name_key = NameKey::StructMember(arg.ty, member_index);
7119                            let name = match member.binding {
7120                                Some(crate::Binding::Location { .. }) => {
7121                                    if do_vertex_pulling {
7122                                        self.namer.call(&self.names[&name_key])
7123                                    } else {
7124                                        varyings_namer.call(&self.names[&name_key])
7125                                    }
7126                                }
7127                                _ => self.namer.call(&self.names[&name_key]),
7128                            };
7129                            flattened_member_names.insert(name_key, name);
7130                        }
7131                    }
7132                    _ => flattened_arguments.push((
7133                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7134                        arg.ty,
7135                        arg.binding.as_ref(),
7136                    )),
7137                }
7138            }
7139
7140            // Identify the varyings among the argument values, and maybe emit
7141            // a struct type named `<fun>Input` to hold them. If we are doing
7142            // vertex pulling, we instead update our attribute mapping to
7143            // note the types, names, and zero values of the attributes.
7144            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7145            let varyings_member_name = self.namer.call("varyings");
7146            let mut has_varyings = false;
7147
7148            if !flattened_arguments.is_empty() {
7149                if !do_vertex_pulling {
7150                    writeln!(self.out, "struct {stage_in_name} {{")?;
7151                }
7152                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7153                    let Some(binding) = binding else {
7154                        continue;
7155                    };
7156                    let name = match *name_key {
7157                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7158                        _ => &self.names[name_key],
7159                    };
7160                    let ty_name = TypeContext {
7161                        handle: ty,
7162                        gctx: module.to_ctx(),
7163                        names: &self.names,
7164                        access: crate::StorageAccess::empty(),
7165                        first_time: false,
7166                    };
7167                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7168                    let location = match *binding {
7169                        crate::Binding::Location { location, .. } => Some(location),
7170                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7171                        crate::Binding::BuiltIn(_) => continue,
7172                    };
7173                    if do_vertex_pulling {
7174                        let Some(location) = location else {
7175                            continue;
7176                        };
7177                        // Update our attribute mapping.
7178                        am_resolved.insert(
7179                            location,
7180                            AttributeMappingResolved {
7181                                ty_name: ty_name.to_string(),
7182                                dimension: ty_name.vector_size(),
7183                                scalar: ty_name.scalar().unwrap(),
7184                                name: name.to_string(),
7185                            },
7186                        );
7187                    } else {
7188                        has_varyings = true;
7189                        if let super::ResolvedBinding::User {
7190                            prefix,
7191                            index,
7192                            interpolation: Some(super::ResolvedInterpolation::PerVertex),
7193                        } = resolved
7194                        {
7195                            if options.lang_version < (4, 0) {
7196                                return Err(Error::PerVertexNotSupported);
7197                            }
7198                            write!(
7199                                self.out,
7200                                "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7201                                back::INDENT,
7202                                ty_name.unwrap_array()
7203                            )?;
7204                        } else {
7205                            write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7206                            resolved.try_fmt(&mut self.out)?;
7207                        }
7208                        writeln!(self.out, ";")?;
7209                    }
7210                }
7211                if !do_vertex_pulling {
7212                    writeln!(self.out, "}};")?;
7213                }
7214            }
7215
7216            // Define a struct type to hold the return value, if any, named
7217            // `<fun>Output`.
7218            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7219            let result_member_name = self.namer.call("member");
7220            let result_type_name = match fun.result {
7221                Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7222                    let mut result_members = Vec::new();
7223                    if let crate::TypeInner::Struct { ref members, .. } =
7224                        module.types[result.ty].inner
7225                    {
7226                        for (member_index, member) in members.iter().enumerate() {
7227                            result_members.push((
7228                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7229                                member.ty,
7230                                member.binding.as_ref(),
7231                            ));
7232                        }
7233                    } else {
7234                        result_members.push((
7235                            &result_member_name,
7236                            result.ty,
7237                            result.binding.as_ref(),
7238                        ));
7239                    }
7240
7241                    writeln!(self.out, "struct {stage_out_name} {{")?;
7242                    let mut has_point_size = false;
7243                    for (name, ty, binding) in result_members {
7244                        let ty_name = TypeContext {
7245                            handle: ty,
7246                            gctx: module.to_ctx(),
7247                            names: &self.names,
7248                            access: crate::StorageAccess::empty(),
7249                            first_time: true,
7250                        };
7251                        let binding = binding.ok_or_else(|| {
7252                            Error::GenericValidation("Expected binding, got None".into())
7253                        })?;
7254
7255                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7256                            has_point_size = true;
7257                            if !pipeline_options.allow_and_force_point_size {
7258                                continue;
7259                            }
7260                        }
7261
7262                        let array_len = match module.types[ty].inner {
7263                            crate::TypeInner::Array {
7264                                size: crate::ArraySize::Constant(size),
7265                                ..
7266                            } => Some(size),
7267                            _ => None,
7268                        };
7269                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7270                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7271                        resolved.try_fmt(&mut self.out)?;
7272                        if let Some(array_len) = array_len {
7273                            write!(self.out, " [{array_len}]")?;
7274                        }
7275                        writeln!(self.out, ";")?;
7276                    }
7277
7278                    if pipeline_options.allow_and_force_point_size
7279                        && ep.stage == crate::ShaderStage::Vertex
7280                        && !has_point_size
7281                    {
7282                        // inject the point size output last
7283                        writeln!(
7284                            self.out,
7285                            "{}float _point_size [[point_size]];",
7286                            back::INDENT
7287                        )?;
7288                    }
7289                    writeln!(self.out, "}};")?;
7290                    &stage_out_name
7291                }
7292                Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7293                    assert_eq!(
7294                        module.types[result.ty].inner,
7295                        crate::TypeInner::Vector {
7296                            size: crate::VectorSize::Tri,
7297                            scalar: crate::Scalar::U32
7298                        }
7299                    );
7300
7301                    "metal::uint3"
7302                }
7303                _ => "void",
7304            };
7305
7306            let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7307                Some(self.write_mesh_output_types(
7308                    mesh_info,
7309                    &fun_name,
7310                    module,
7311                    pipeline_options.allow_and_force_point_size,
7312                    options,
7313                )?)
7314            } else {
7315                None
7316            };
7317
7318            // If we're doing a vertex pulling transform, define the buffer
7319            // structure types.
7320            if do_vertex_pulling {
7321                for vbm in &vbm_resolved {
7322                    let buffer_stride = vbm.stride;
7323                    let buffer_ty = &vbm.ty_name;
7324
7325                    // Define a structure of bytes of the appropriate size.
7326                    // When we access the attributes, we'll be unpacking these
7327                    // bytes at some offset.
7328                    writeln!(
7329                        self.out,
7330                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7331                    )?;
7332                }
7333            }
7334
7335            let is_wrapped = matches!(
7336                ep.stage,
7337                crate::ShaderStage::Task | crate::ShaderStage::Mesh
7338            );
7339            let fun_name = fun_name.clone();
7340            let nested_fun_name = if is_wrapped {
7341                self.namer.call(&format!("_{fun_name}"))
7342            } else {
7343                fun_name.clone()
7344            };
7345
7346            // https://web.archive.org/web/20181029003926/https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
7347            if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7348                let total_threads =
7349                    ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7350                write!(
7351                    self.out,
7352                    "[[max_total_threads_per_threadgroup({total_threads})]] "
7353                )?;
7354            }
7355
7356            // Write the entry point function's name, and begin its argument list.
7357            if let Some(em_str) = em_str {
7358                write!(self.out, "{em_str} ")?;
7359            }
7360            writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7361
7362            let mut args = Vec::new();
7363
7364            // If we have produced a struct holding the `EntryPoint`'s
7365            // `Function`'s arguments' varyings, pass that struct first.
7366            if has_varyings {
7367                args.push(EntryPointArgument {
7368                    ty_name: stage_in_name,
7369                    name: varyings_member_name.clone(),
7370                    binding: " [[stage_in]]".to_string(),
7371                    init: None,
7372                });
7373            }
7374
7375            let mut local_invocation_index = None;
7376
7377            // Then pass the remaining arguments not included in the varyings
7378            // struct.
7379            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7380                let binding = match binding {
7381                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7382                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7383                    _ => continue,
7384                };
7385                let name = match *name_key {
7386                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7387                    _ => &self.names[name_key],
7388                };
7389
7390                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7391                    local_invocation_index = Some(name_key);
7392                }
7393
7394                let ty_name = TypeContext {
7395                    handle: ty,
7396                    gctx: module.to_ctx(),
7397                    names: &self.names,
7398                    access: crate::StorageAccess::empty(),
7399                    first_time: false,
7400                };
7401
7402                match *binding {
7403                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7404                        v_existing_id = Some(name.clone());
7405                    }
7406                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7407                        i_existing_id = Some(name.clone());
7408                    }
7409                    _ => {}
7410                };
7411
7412                let resolved = options.resolve_local_binding(binding, in_mode)?;
7413                let mut binding = String::new();
7414                resolved.try_fmt(&mut binding)?;
7415
7416                args.push(EntryPointArgument {
7417                    ty_name: format!("{ty_name}"),
7418                    name: name.clone(),
7419                    binding,
7420                    init: None,
7421                });
7422            }
7423
7424            let need_workgroup_variables_initialization =
7425                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7426
7427            if local_invocation_index.is_none()
7428                && (need_workgroup_variables_initialization
7429                    || ep.stage == crate::ShaderStage::Task
7430                    || ep.stage == crate::ShaderStage::Mesh)
7431            {
7432                args.push(EntryPointArgument {
7433                    ty_name: "uint".to_string(),
7434                    name: "__local_invocation_index".to_string(),
7435                    binding: " [[thread_index_in_threadgroup]]".to_string(),
7436                    init: None,
7437                });
7438            }
7439
7440            // Those global variables used by this entry point and its callees
7441            // get passed as arguments. `Private` globals are an exception, they
7442            // don't outlive this invocation, so we declare them below as locals
7443            // within the entry point.
7444            for (handle, var) in module.global_variables.iter() {
7445                let usage = fun_info[handle];
7446                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7447                    continue;
7448                }
7449
7450                if options.lang_version < (1, 2) {
7451                    match var.space {
7452                        // This restriction is not documented in the MSL spec
7453                        // but validation will fail if it is not upheld.
7454                        //
7455                        // We infer the required version from the "Function
7456                        // Buffer Read-Writes" section of [what's new], where
7457                        // the feature sets listed correspond with the ones
7458                        // supporting MSL 1.2.
7459                        //
7460                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7461                        crate::AddressSpace::Storage { access }
7462                            if access.contains(crate::StorageAccess::STORE)
7463                                && ep.stage == crate::ShaderStage::Fragment =>
7464                        {
7465                            return Err(Error::UnsupportedWritableStorageBuffer)
7466                        }
7467                        crate::AddressSpace::Handle => {
7468                            match module.types[var.ty].inner {
7469                                crate::TypeInner::Image {
7470                                    class: crate::ImageClass::Storage { access, .. },
7471                                    ..
7472                                } => {
7473                                    // This restriction is not documented in the MSL spec
7474                                    // but validation will fail if it is not upheld.
7475                                    //
7476                                    // We infer the required version from the "Function
7477                                    // Texture Read-Writes" section of [what's new], where
7478                                    // the feature sets listed correspond with the ones
7479                                    // supporting MSL 1.2.
7480                                    //
7481                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7482                                    if access.contains(crate::StorageAccess::STORE)
7483                                        && (ep.stage == crate::ShaderStage::Vertex
7484                                            || ep.stage == crate::ShaderStage::Fragment)
7485                                    {
7486                                        return Err(Error::UnsupportedWritableStorageTexture(
7487                                            ep.stage,
7488                                        ));
7489                                    }
7490
7491                                    if access.contains(
7492                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7493                                    ) {
7494                                        return Err(Error::UnsupportedRWStorageTexture);
7495                                    }
7496                                }
7497                                _ => {}
7498                            }
7499                        }
7500                        _ => {}
7501                    }
7502                }
7503
7504                // Check min MSL version for binding arrays
7505                match var.space {
7506                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7507                        crate::TypeInner::BindingArray { base, .. } => {
7508                            match module.types[base].inner {
7509                                crate::TypeInner::Sampler { .. } => {
7510                                    if options.lang_version < (2, 0) {
7511                                        return Err(Error::UnsupportedArrayOf(
7512                                            "samplers".to_string(),
7513                                        ));
7514                                    }
7515                                }
7516                                crate::TypeInner::Image { class, .. } => match class {
7517                                    crate::ImageClass::Sampled { .. }
7518                                    | crate::ImageClass::Depth { .. }
7519                                    | crate::ImageClass::Storage {
7520                                        access: crate::StorageAccess::LOAD,
7521                                        ..
7522                                    } => {
7523                                        // Array of textures since:
7524                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7525                                        // - macOS: Metal 2
7526
7527                                        if options.lang_version < (2, 0) {
7528                                            return Err(Error::UnsupportedArrayOf(
7529                                                "textures".to_string(),
7530                                            ));
7531                                        }
7532                                    }
7533                                    crate::ImageClass::Storage {
7534                                        access: crate::StorageAccess::STORE,
7535                                        ..
7536                                    } => {
7537                                        // Array of write-only textures since:
7538                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7539                                        // - macOS: Metal 2
7540
7541                                        if options.lang_version < (2, 0) {
7542                                            return Err(Error::UnsupportedArrayOf(
7543                                                "write-only textures".to_string(),
7544                                            ));
7545                                        }
7546                                    }
7547                                    crate::ImageClass::Storage { .. } => {
7548                                        if options.lang_version < (3, 0) {
7549                                            return Err(Error::UnsupportedArrayOf(
7550                                                "read-write textures".to_string(),
7551                                            ));
7552                                        }
7553                                    }
7554                                    crate::ImageClass::External => {
7555                                        return Err(Error::UnsupportedArrayOf(
7556                                            "external textures".to_string(),
7557                                        ));
7558                                    }
7559                                },
7560                                _ => {
7561                                    return Err(Error::UnsupportedArrayOfType(base));
7562                                }
7563                            }
7564                        }
7565                        _ => {}
7566                    },
7567                    _ => {}
7568                }
7569
7570                // the resolves have already been checked for `!fake_missing_bindings` case
7571                let resolved = match var.space {
7572                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7573                    crate::AddressSpace::WorkGroup => None,
7574                    crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7575                    _ => options
7576                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7577                        .ok(),
7578                };
7579                if let Some(ref resolved) = resolved {
7580                    // Inline samplers are defined in the EP body
7581                    if resolved.as_inline_sampler(options).is_some() {
7582                        continue;
7583                    }
7584                }
7585
7586                match module.types[var.ty].inner {
7587                    crate::TypeInner::Image {
7588                        class: crate::ImageClass::External,
7589                        ..
7590                    } => {
7591                        // External texture global variables get lowered to 3 textures
7592                        // and a constant buffer. We must emit a separate argument for
7593                        // each of these.
7594                        let target = match resolved {
7595                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7596                                target.external_texture
7597                            }
7598                            _ => None,
7599                        };
7600
7601                        for i in 0..3 {
7602                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7603                                handle,
7604                                ExternalTextureNameKey::Plane(i),
7605                            )];
7606                            let ty_name = format!(
7607                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7608                            );
7609                            let name = plane_name.clone();
7610                            let binding = if let Some(ref target) = target {
7611                                format!(" [[texture({})]]", target.planes[i])
7612                            } else {
7613                                String::new()
7614                            };
7615                            args.push(EntryPointArgument {
7616                                ty_name,
7617                                name,
7618                                binding,
7619                                init: None,
7620                            });
7621                        }
7622                        let params_ty_name = &self.names
7623                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7624                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7625                            handle,
7626                            ExternalTextureNameKey::Params,
7627                        )];
7628                        let binding = if let Some(ref target) = target {
7629                            format!(" [[buffer({})]]", target.params)
7630                        } else {
7631                            String::new()
7632                        };
7633
7634                        args.push(EntryPointArgument {
7635                            ty_name: format!("constant {params_ty_name}&"),
7636                            name: params_name.clone(),
7637                            binding,
7638                            init: None,
7639                        });
7640                    }
7641                    _ => {
7642                        if var.space == crate::AddressSpace::WorkGroup
7643                            && ep.stage == crate::ShaderStage::Mesh
7644                        {
7645                            continue;
7646                        }
7647                        let tyvar = TypedGlobalVariable {
7648                            module,
7649                            names: &self.names,
7650                            handle,
7651                            usage,
7652                            reference: true,
7653                        };
7654                        let parts = tyvar.to_parts()?;
7655                        let mut binding = String::new();
7656                        if let Some(resolved) = resolved {
7657                            resolved.try_fmt(&mut binding)?;
7658                        }
7659                        args.push(EntryPointArgument {
7660                            ty_name: parts.ty_name,
7661                            name: parts.var_name,
7662                            binding,
7663                            init: var.init,
7664                        });
7665                    }
7666                }
7667            }
7668
7669            if do_vertex_pulling {
7670                if needs_vertex_id && v_existing_id.is_none() {
7671                    // Write the [[vertex_id]] argument.
7672                    args.push(EntryPointArgument {
7673                        ty_name: "uint".to_string(),
7674                        name: v_id.clone(),
7675                        binding: " [[vertex_id]]".to_string(),
7676                        init: None,
7677                    });
7678                }
7679
7680                if needs_instance_id && i_existing_id.is_none() {
7681                    args.push(EntryPointArgument {
7682                        ty_name: "uint".to_string(),
7683                        name: i_id.clone(),
7684                        binding: " [[instance_id]]".to_string(),
7685                        init: None,
7686                    });
7687                }
7688
7689                // Iterate vbm_resolved, output one argument for every vertex buffer,
7690                // using the names we generated earlier.
7691                for vbm in &vbm_resolved {
7692                    let id = &vbm.id;
7693                    let ty_name = &vbm.ty_name;
7694                    let param_name = &vbm.param_name;
7695                    args.push(EntryPointArgument {
7696                        ty_name: format!("const device {ty_name}*"),
7697                        name: param_name.clone(),
7698                        binding: format!(" [[buffer({id})]]"),
7699                        init: None,
7700                    });
7701                }
7702            }
7703
7704            // If this entry uses any variable-length arrays, their sizes are
7705            // passed as a final struct-typed argument.
7706            if needs_buffer_sizes {
7707                // this is checked earlier
7708                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7709                let mut binding = String::new();
7710                resolved.try_fmt(&mut binding)?;
7711                args.push(EntryPointArgument {
7712                    ty_name: "constant _mslBufferSizes&".to_string(),
7713                    name: "_buffer_sizes".to_string(),
7714                    binding,
7715                    init: None,
7716                });
7717            }
7718
7719            let mut is_first_arg = true;
7720            for arg in &args {
7721                if is_first_arg {
7722                    write!(self.out, "  ")?;
7723                } else {
7724                    write!(self.out, ", ")?;
7725                }
7726                is_first_arg = false;
7727                write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7728                if !is_wrapped {
7729                    write!(self.out, "{}", arg.binding)?;
7730                    if let Some(init) = arg.init {
7731                        write!(self.out, " = ")?;
7732                        self.put_const_expression(
7733                            init,
7734                            module,
7735                            mod_info,
7736                            &module.global_expressions,
7737                        )?;
7738                    }
7739                }
7740                writeln!(self.out)?;
7741            }
7742            if ep.stage == crate::ShaderStage::Mesh {
7743                for (handle, var) in module.global_variables.iter() {
7744                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7745                        continue;
7746                    }
7747                    if is_first_arg {
7748                        write!(self.out, "  ")?;
7749                    } else {
7750                        write!(self.out, ", ")?;
7751                    }
7752                    let ty_context = TypeContext {
7753                        handle: module.global_variables[handle].ty,
7754                        gctx: module.to_ctx(),
7755                        names: &self.names,
7756                        access: crate::StorageAccess::empty(),
7757                        first_time: false,
7758                    };
7759                    writeln!(
7760                        self.out,
7761                        "threadgroup {ty_context}& {}",
7762                        self.names[&NameKey::GlobalVariable(handle)]
7763                    )?;
7764                }
7765            }
7766
7767            // end of the entry point argument list
7768            writeln!(self.out, ") {{")?;
7769
7770            // Starting the function body.
7771            if do_vertex_pulling {
7772                // Provide zero values for all the attributes, which we will overwrite with
7773                // real data from the vertex attribute buffers, if the indices are in-bounds.
7774                for vbm in &vbm_resolved {
7775                    for attribute in vbm.attributes {
7776                        let location = attribute.shader_location;
7777                        let am_option = am_resolved.get(&location);
7778                        if am_option.is_none() {
7779                            // This bound attribute isn't used in this entry point, so
7780                            // don't bother zero-initializing it.
7781                            continue;
7782                        }
7783                        let am = am_option.unwrap();
7784                        let attribute_ty_name = &am.ty_name;
7785                        let attribute_name = &am.name;
7786
7787                        writeln!(
7788                            self.out,
7789                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7790                            back::Level(1)
7791                        )?;
7792                    }
7793
7794                    // Output a bounds check block that will set real values for the
7795                    // attributes, if the bounds are satisfied.
7796                    write!(self.out, "{}if (", back::Level(1))?;
7797
7798                    let idx = &vbm.id;
7799                    let stride = &vbm.stride;
7800                    let index_name = match vbm.step_mode {
7801                        back::msl::VertexBufferStepMode::Constant => "0",
7802                        back::msl::VertexBufferStepMode::ByVertex => {
7803                            if let Some(ref name) = v_existing_id {
7804                                name
7805                            } else {
7806                                &v_id
7807                            }
7808                        }
7809                        back::msl::VertexBufferStepMode::ByInstance => {
7810                            if let Some(ref name) = i_existing_id {
7811                                name
7812                            } else {
7813                                &i_id
7814                            }
7815                        }
7816                    };
7817                    write!(
7818                        self.out,
7819                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7820                    )?;
7821
7822                    writeln!(self.out, ") {{")?;
7823
7824                    // Pull the bytes out of the vertex buffer.
7825                    let ty_name = &vbm.ty_name;
7826                    let elem_name = &vbm.elem_name;
7827                    let param_name = &vbm.param_name;
7828
7829                    writeln!(
7830                        self.out,
7831                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7832                        back::Level(2),
7833                    )?;
7834
7835                    // Now set real values for each of the attributes, by unpacking the data
7836                    // from the buffer elements.
7837                    for attribute in vbm.attributes {
7838                        let location = attribute.shader_location;
7839                        let Some(am) = am_resolved.get(&location) else {
7840                            // This bound attribute isn't used in this entry point, so
7841                            // don't bother extracting the data. Too bad we emitted the
7842                            // unpacking function earlier -- it might not get used.
7843                            continue;
7844                        };
7845                        let attribute_name = &am.name;
7846                        let attribute_ty_name = &am.ty_name;
7847
7848                        let offset = attribute.offset;
7849                        let func = unpacking_functions
7850                            .get(&attribute.format)
7851                            .expect("Should have generated this unpacking function earlier.");
7852                        let func_name = &func.name;
7853
7854                        // Check dimensionality of the attribute compared to the unpacking
7855                        // function. If attribute dimension > unpack dimension, we have to
7856                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7857                        // scalar type. Otherwise, if attribute dimension is < unpack
7858                        // dimension, then we need to explicitly truncate the result.
7859                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7860
7861                        // We need an extra type conversion if the shader type does not
7862                        // match the type returned from the unpacking function.
7863                        let needs_conversion = am.scalar != func.scalar;
7864
7865                        if needs_padding_or_truncation != Ordering::Equal {
7866                            // Emit a comment flagging that a conversion is happening,
7867                            // since the actual logic can be at the end of a long line.
7868                            writeln!(
7869                                self.out,
7870                                "{}// {attribute_ty_name} <- {:?}",
7871                                back::Level(2),
7872                                attribute.format
7873                            )?;
7874                        }
7875
7876                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7877
7878                        if needs_padding_or_truncation == Ordering::Greater {
7879                            // Needs padding: emit constructor call for wider type
7880                            write!(self.out, "{attribute_ty_name}(")?;
7881                        }
7882
7883                        // Emit call to unpacking function
7884                        if needs_conversion {
7885                            put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7886                            write!(self.out, "(")?;
7887                        }
7888                        write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7889                        for i in (offset + 1)..(offset + func.byte_count) {
7890                            write!(self.out, ", {elem_name}.data[{i}]")?;
7891                        }
7892                        write!(self.out, ")")?;
7893                        if needs_conversion {
7894                            write!(self.out, ")")?;
7895                        }
7896
7897                        match needs_padding_or_truncation {
7898                            Ordering::Greater => {
7899                                // Padding
7900                                let ty_is_int = scalar_is_int(am.scalar);
7901                                let zero_value = if ty_is_int { "0" } else { "0.0" };
7902                                let one_value = if ty_is_int { "1" } else { "1.0" };
7903                                for i in func.dimension.map_or(1, u8::from)
7904                                    ..am.dimension.map_or(1, u8::from)
7905                                {
7906                                    write!(
7907                                        self.out,
7908                                        ", {}",
7909                                        if i == 3 { one_value } else { zero_value }
7910                                    )?;
7911                                }
7912                            }
7913                            Ordering::Less => {
7914                                // Truncate to the first `am.dimension` components
7915                                write!(
7916                                    self.out,
7917                                    ".{}",
7918                                    &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7919                                )?;
7920                            }
7921                            Ordering::Equal => {}
7922                        }
7923
7924                        if needs_padding_or_truncation == Ordering::Greater {
7925                            write!(self.out, ")")?;
7926                        }
7927
7928                        writeln!(self.out, ";")?;
7929                    }
7930
7931                    // End the bounds check / attribute setting block.
7932                    writeln!(self.out, "{}}}", back::Level(1))?;
7933                }
7934            }
7935
7936            // Metal doesn't support private mutable variables outside of functions,
7937            // so we put them here, just like the locals.
7938            for (handle, var) in module.global_variables.iter() {
7939                let usage = fun_info[handle];
7940                if usage.is_empty() {
7941                    continue;
7942                }
7943                if var.space == crate::AddressSpace::Private {
7944                    let tyvar = TypedGlobalVariable {
7945                        module,
7946                        names: &self.names,
7947                        handle,
7948                        usage,
7949
7950                        reference: false,
7951                    };
7952                    write!(self.out, "{}", back::INDENT)?;
7953                    tyvar.try_fmt(&mut self.out)?;
7954                    match var.init {
7955                        Some(value) => {
7956                            write!(self.out, " = ")?;
7957                            self.put_const_expression(
7958                                value,
7959                                module,
7960                                mod_info,
7961                                &module.global_expressions,
7962                            )?;
7963                            writeln!(self.out, ";")?;
7964                        }
7965                        None => {
7966                            writeln!(self.out, " = {{}};")?;
7967                        }
7968                    };
7969                } else if let Some(ref binding) = var.binding {
7970                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7971                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7972                        // write an inline sampler
7973                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7974                        writeln!(
7975                            self.out,
7976                            "{}constexpr {}::sampler {}(",
7977                            back::INDENT,
7978                            NAMESPACE,
7979                            name
7980                        )?;
7981                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7982                        writeln!(self.out, "{});", back::INDENT)?;
7983                    } else if let crate::TypeInner::Image {
7984                        class: crate::ImageClass::External,
7985                        ..
7986                    } = module.types[var.ty].inner
7987                    {
7988                        // Wrap the individual arguments for each external texture global
7989                        // in a struct which can be easily passed around.
7990                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7991                        let l1 = back::Level(1);
7992                        let l2 = l1.next();
7993                        writeln!(
7994                            self.out,
7995                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7996                        )?;
7997                        for i in 0..3 {
7998                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7999                                handle,
8000                                ExternalTextureNameKey::Plane(i),
8001                            )];
8002                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
8003                        }
8004                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8005                            handle,
8006                            ExternalTextureNameKey::Params,
8007                        )];
8008                        writeln!(self.out, "{l2}.params = {params_name},")?;
8009                        writeln!(self.out, "{l1}}};")?;
8010                    }
8011                }
8012            }
8013
8014            if need_workgroup_variables_initialization {
8015                self.write_workgroup_variables_initialization(
8016                    module,
8017                    mod_info,
8018                    fun_info,
8019                    local_invocation_index,
8020                    ep.stage,
8021                )?;
8022            }
8023
8024            // Now take the arguments that we gathered into structs, and the
8025            // structs that we flattened into arguments, and emit local
8026            // variables with initializers that put everything back the way the
8027            // body code expects.
8028            //
8029            // If we had to generate fresh names for struct members passed as
8030            // arguments, be sure to use those names when rebuilding the struct.
8031            //
8032            // "Each day, I change some zeros to ones, and some ones to zeros.
8033            // The rest, I leave alone."
8034            for (arg_index, arg) in fun.arguments.iter().enumerate() {
8035                let arg_name =
8036                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
8037                match module.types[arg.ty].inner {
8038                    crate::TypeInner::Struct { ref members, .. } => {
8039                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
8040                        write!(
8041                            self.out,
8042                            "{}const {} {} = {{ ",
8043                            back::INDENT,
8044                            struct_name,
8045                            arg_name
8046                        )?;
8047                        for (member_index, member) in members.iter().enumerate() {
8048                            let key = NameKey::StructMember(arg.ty, member_index as u32);
8049                            let name = &flattened_member_names[&key];
8050                            if member_index != 0 {
8051                                write!(self.out, ", ")?;
8052                            }
8053                            // insert padding initialization, if needed
8054                            if self
8055                                .struct_member_pads
8056                                .contains(&(arg.ty, member_index as u32))
8057                            {
8058                                write!(self.out, "{{}}, ")?;
8059                            }
8060                            match member.binding {
8061                                Some(crate::Binding::Location {
8062                                    interpolation: Some(crate::Interpolation::PerVertex),
8063                                    ..
8064                                }) => {
8065                                    writeln!(
8066                                        self.out,
8067                                        "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
8068                                        back::INDENT,
8069                                        varyings_member_name,
8070                                        arg_name,
8071                                    )?;
8072                                    continue;
8073                                }
8074                                Some(crate::Binding::Location { .. }) => {
8075                                    if has_varyings {
8076                                        write!(self.out, "{varyings_member_name}.")?;
8077                                    }
8078                                }
8079                                _ => (),
8080                            }
8081                            write!(self.out, "{name}")?;
8082                        }
8083                        writeln!(self.out, " }};")?;
8084                    }
8085                    _ => match arg.binding {
8086                        Some(crate::Binding::Location {
8087                            interpolation: Some(crate::Interpolation::PerVertex),
8088                            ..
8089                        }) => {
8090                            let ty_name = TypeContext {
8091                                handle: arg.ty,
8092                                gctx: module.to_ctx(),
8093                                names: &self.names,
8094                                access: crate::StorageAccess::empty(),
8095                                first_time: false,
8096                            };
8097                            writeln!(
8098                                self.out,
8099                                "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8100                                back::INDENT,
8101                                varyings_member_name,
8102                                arg_name,
8103                            )?;
8104                        }
8105                        Some(crate::Binding::Location { .. })
8106                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8107                            if has_varyings {
8108                                writeln!(
8109                                    self.out,
8110                                    "{}const auto {} = {}.{};",
8111                                    back::INDENT,
8112                                    arg_name,
8113                                    varyings_member_name,
8114                                    arg_name
8115                                )?;
8116                            }
8117                        }
8118                        _ => {}
8119                    },
8120                }
8121            }
8122
8123            let guarded_indices =
8124                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8125
8126            let context = StatementContext {
8127                expression: ExpressionContext {
8128                    function: fun,
8129                    origin: FunctionOrigin::EntryPoint(ep_index as _),
8130                    info: fun_info,
8131                    lang_version: options.lang_version,
8132                    policies: options.bounds_check_policies,
8133                    guarded_indices,
8134                    module,
8135                    mod_info,
8136                    pipeline_options,
8137                    force_loop_bounding: options.force_loop_bounding,
8138                },
8139                result_struct: if ep.stage == crate::ShaderStage::Task {
8140                    None
8141                } else {
8142                    Some(&stage_out_name)
8143                },
8144            };
8145
8146            // Finally, declare all the local variables that we need
8147            //TODO: we can postpone this till the relevant expressions are emitted
8148            self.put_locals(&context.expression)?;
8149            self.update_expressions_to_bake(fun, fun_info, &context.expression);
8150            self.put_block(back::Level(1), &fun.body, &context)?;
8151            writeln!(self.out, "}}")?;
8152            if ep_index + 1 != module.entry_points.len() {
8153                writeln!(self.out)?;
8154            }
8155            self.named_expressions.clear();
8156
8157            if is_wrapped {
8158                self.write_wrapper_function(NestedFunctionInfo {
8159                    options,
8160                    ep,
8161                    module,
8162                    mod_info,
8163                    fun_info,
8164                    args,
8165                    local_invocation_index,
8166                    nested_name: &nested_fun_name,
8167                    outer_name: &fun_name,
8168                    out_mesh_info,
8169                })?;
8170            }
8171        }
8172
8173        Ok(info)
8174    }
8175
8176    pub(super) fn write_barrier(
8177        &mut self,
8178        flags: crate::Barrier,
8179        level: back::Level,
8180    ) -> BackendResult {
8181        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
8182        // so we try to avoid it here.
8183        if flags.is_empty() {
8184            writeln!(
8185                self.out,
8186                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8187            )?;
8188        }
8189        if flags.contains(crate::Barrier::STORAGE) {
8190            writeln!(
8191                self.out,
8192                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8193            )?;
8194        }
8195        if flags.contains(crate::Barrier::WORK_GROUP) {
8196            writeln!(
8197                self.out,
8198                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8199            )?;
8200            if self.needs_object_memory_barriers {
8201                writeln!(
8202                    self.out,
8203                    "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8204                )?;
8205            }
8206        }
8207        if flags.contains(crate::Barrier::SUB_GROUP) {
8208            writeln!(
8209                self.out,
8210                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8211            )?;
8212        }
8213        if flags.contains(crate::Barrier::TEXTURE) {
8214            writeln!(
8215                self.out,
8216                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8217            )?;
8218        }
8219        Ok(())
8220    }
8221}
8222
8223/// Initializing workgroup variables is more tricky for Metal because we have to deal
8224/// with atomics at the type-level (which don't have a copy constructor).
8225mod workgroup_mem_init {
8226    use crate::EntryPoint;
8227
8228    use super::*;
8229
8230    enum Access {
8231        GlobalVariable(Handle<crate::GlobalVariable>),
8232        StructMember(Handle<crate::Type>, u32),
8233        Array(usize),
8234    }
8235
8236    impl Access {
8237        fn write<W: Write>(
8238            &self,
8239            writer: &mut W,
8240            names: &FastHashMap<NameKey, String>,
8241        ) -> Result<(), core::fmt::Error> {
8242            match *self {
8243                Access::GlobalVariable(handle) => {
8244                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8245                }
8246                Access::StructMember(handle, index) => {
8247                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8248                }
8249                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8250            }
8251        }
8252    }
8253
8254    struct AccessStack {
8255        stack: Vec<Access>,
8256        array_depth: usize,
8257    }
8258
8259    impl AccessStack {
8260        const fn new() -> Self {
8261            Self {
8262                stack: Vec::new(),
8263                array_depth: 0,
8264            }
8265        }
8266
8267        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8268            let array_depth = self.array_depth;
8269            self.stack.push(Access::Array(array_depth));
8270            self.array_depth += 1;
8271            let res = cb(self, array_depth);
8272            self.stack.pop();
8273            self.array_depth -= 1;
8274            res
8275        }
8276
8277        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8278            self.stack.push(new);
8279            let res = cb(self);
8280            self.stack.pop();
8281            res
8282        }
8283
8284        fn write<W: Write>(
8285            &self,
8286            writer: &mut W,
8287            names: &FastHashMap<NameKey, String>,
8288        ) -> Result<(), core::fmt::Error> {
8289            for next in self.stack.iter() {
8290                next.write(writer, names)?;
8291            }
8292            Ok(())
8293        }
8294    }
8295
8296    impl<W: Write> Writer<W> {
8297        pub(super) fn need_workgroup_variables_initialization(
8298            &mut self,
8299            options: &Options,
8300            ep: &EntryPoint,
8301            module: &crate::Module,
8302            fun_info: &valid::FunctionInfo,
8303        ) -> bool {
8304            let is_task = ep.stage == crate::ShaderStage::Task;
8305            options.zero_initialize_workgroup_memory
8306                && ep.stage.compute_like()
8307                && module.global_variables.iter().any(|(handle, var)| {
8308                    let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8309                        || (var.space == crate::AddressSpace::TaskPayload && is_task);
8310                    !fun_info[handle].is_empty() && is_right_address_space
8311                })
8312        }
8313
8314        pub fn write_workgroup_variables_initialization(
8315            &mut self,
8316            module: &crate::Module,
8317            module_info: &valid::ModuleInfo,
8318            fun_info: &valid::FunctionInfo,
8319            local_invocation_index: Option<&NameKey>,
8320            stage: crate::ShaderStage,
8321        ) -> BackendResult {
8322            let level = back::Level(1);
8323
8324            writeln!(
8325                self.out,
8326                "{}if ({} == 0u) {{",
8327                level,
8328                local_invocation_index
8329                    .map(|name_key| self.names[name_key].as_str())
8330                    .unwrap_or("__local_invocation_index"),
8331            )?;
8332
8333            let mut access_stack = AccessStack::new();
8334
8335            let is_task = stage == crate::ShaderStage::Task;
8336            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8337                let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8338                    || (var.space == crate::AddressSpace::TaskPayload && is_task);
8339                !fun_info[handle].is_empty() && is_right_address_space
8340            });
8341
8342            for (handle, var) in vars {
8343                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8344                    self.write_workgroup_variable_initialization(
8345                        module,
8346                        module_info,
8347                        var.ty,
8348                        access_stack,
8349                        level.next(),
8350                    )
8351                })?;
8352            }
8353
8354            writeln!(self.out, "{level}}}")?;
8355            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8356        }
8357
8358        fn write_workgroup_variable_initialization(
8359            &mut self,
8360            module: &crate::Module,
8361            module_info: &valid::ModuleInfo,
8362            ty: Handle<crate::Type>,
8363            access_stack: &mut AccessStack,
8364            level: back::Level,
8365        ) -> BackendResult {
8366            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8367                write!(self.out, "{level}")?;
8368                access_stack.write(&mut self.out, &self.names)?;
8369                writeln!(self.out, " = {{}};")?;
8370            } else {
8371                match module.types[ty].inner {
8372                    crate::TypeInner::Atomic { .. } => {
8373                        write!(
8374                            self.out,
8375                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8376                        )?;
8377                        access_stack.write(&mut self.out, &self.names)?;
8378                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8379                    }
8380                    crate::TypeInner::Array { base, size, .. } => {
8381                        let count = match size.resolve(module.to_ctx())? {
8382                            proc::IndexableLength::Known(count) => count,
8383                            proc::IndexableLength::Dynamic => unreachable!(),
8384                        };
8385
8386                        access_stack.enter_array(|access_stack, array_depth| {
8387                            writeln!(
8388                                self.out,
8389                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8390                            )?;
8391                            self.write_workgroup_variable_initialization(
8392                                module,
8393                                module_info,
8394                                base,
8395                                access_stack,
8396                                level.next(),
8397                            )?;
8398                            writeln!(self.out, "{level}}}")?;
8399                            BackendResult::Ok(())
8400                        })?;
8401                    }
8402                    crate::TypeInner::Struct { ref members, .. } => {
8403                        for (index, member) in members.iter().enumerate() {
8404                            access_stack.enter(
8405                                Access::StructMember(ty, index as u32),
8406                                |access_stack| {
8407                                    self.write_workgroup_variable_initialization(
8408                                        module,
8409                                        module_info,
8410                                        member.ty,
8411                                        access_stack,
8412                                        level,
8413                                    )
8414                                },
8415                            )?;
8416                        }
8417                    }
8418                    _ => unreachable!(),
8419                }
8420            }
8421
8422            Ok(())
8423        }
8424    }
8425}
8426
8427impl crate::AtomicFunction {
8428    const fn to_msl(self) -> &'static str {
8429        match self {
8430            Self::Add => "fetch_add",
8431            Self::Subtract => "fetch_sub",
8432            Self::And => "fetch_and",
8433            Self::InclusiveOr => "fetch_or",
8434            Self::ExclusiveOr => "fetch_xor",
8435            Self::Min => "fetch_min",
8436            Self::Max => "fetch_max",
8437            Self::Exchange { compare: None } => "exchange",
8438            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8439        }
8440    }
8441
8442    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8443        Ok(match self {
8444            Self::Min => "min",
8445            Self::Max => "max",
8446            _ => Err(Error::FeatureNotImplemented(
8447                "64-bit atomic operation other than min/max".to_string(),
8448            ))?,
8449        })
8450    }
8451}