naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18    arena::{Handle, HandleSet},
19    back::{self, get_entry_points, Baked},
20    common,
21    proc::{
22        self, concrete_int_scalars,
23        index::{self, BoundsCheck},
24        ExternalTextureNameKey, NameKey, TypeResolution,
25    },
26    valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
32/// Shorthand result used internally by the backend
33type BackendResult = Result<(), Error>;
34
35const NAMESPACE: &str = "metal";
36// The name of the array member of the Metal struct types we generate to
37// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
38// for details.
39const WRAPPED_ARRAY_FIELD: &str = "inner";
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// A more general way of handling pointers still needs to be implemented here.
43const ATOMIC_REFERENCE: &str = "&";
44
45const RT_NAMESPACE: &str = "metal::raytracing";
46const RAY_QUERY_TYPE: &str = "_RayQuery";
47const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
48const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
49const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
50const RAY_QUERY_FIELD_READY: &str = "ready";
51const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
52
53pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
54pub(crate) const MODF_FUNCTION: &str = "naga_modf";
55pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
56pub(crate) const ABS_FUNCTION: &str = "naga_abs";
57pub(crate) const DIV_FUNCTION: &str = "naga_div";
58pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
59pub(crate) const MOD_FUNCTION: &str = "naga_mod";
60pub(crate) const NEG_FUNCTION: &str = "naga_neg";
61pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
62pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
63pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
64pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
65pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
66pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
67pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
68    "nagaTextureSampleBaseClampToEdge";
69/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
70/// However, if you put that texture inside a struct, everything is totally fine. This
71/// baffles me to no end.
72///
73/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
74/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
75/// you have noticed that this should be exactly the same to the compiler, and you're correct.
76pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
77/// Name of the struct that is declared to wrap the 3 textures and parameters
78/// buffer that [`crate::ImageClass::External`] variables are lowered to,
79/// allowing them to be conveniently passed to user-defined or wrapper
80/// functions. The struct is declared in [`Writer::write_type_defs`].
81pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
82
83/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
84///
85/// The `sizes` slice determines whether this function writes a
86/// scalar, vector, or matrix type:
87///
88/// - An empty slice produces a scalar type.
89/// - A one-element slice produces a vector type.
90/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
91fn put_numeric_type(
92    out: &mut impl Write,
93    scalar: crate::Scalar,
94    sizes: &[crate::VectorSize],
95) -> Result<(), FmtError> {
96    match (scalar, sizes) {
97        (scalar, &[]) => {
98            write!(out, "{}", scalar.to_msl_name())
99        }
100        (scalar, &[rows]) => {
101            write!(
102                out,
103                "{}::{}{}",
104                NAMESPACE,
105                scalar.to_msl_name(),
106                common::vector_size_str(rows)
107            )
108        }
109        (scalar, &[rows, columns]) => {
110            write!(
111                out,
112                "{}::{}{}x{}",
113                NAMESPACE,
114                scalar.to_msl_name(),
115                common::vector_size_str(columns),
116                common::vector_size_str(rows)
117            )
118        }
119        (_, _) => Ok(()), // not meaningful
120    }
121}
122
123const fn scalar_is_int(scalar: crate::Scalar) -> bool {
124    use crate::ScalarKind::*;
125    match scalar.kind {
126        Sint | Uint | AbstractInt | Bool => true,
127        Float | AbstractFloat => false,
128    }
129}
130
131/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
132const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
133
134/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
135const REINTERPRET_PREFIX: &str = "reinterpreted_";
136
137/// Wrapper for identifier names for clamped level-of-detail values
138///
139/// Values of this type implement [`core::fmt::Display`], formatting as
140/// the name of the variable used to hold the cached clamped
141/// level-of-detail value for an `ImageLoad` expression.
142struct ClampedLod(Handle<crate::Expression>);
143
144impl Display for ClampedLod {
145    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
146        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
147    }
148}
149
150/// Wrapper for generating `struct _mslBufferSizes` member names for
151/// runtime-sized array lengths.
152///
153/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
154/// as an argument to the entry point. This argument's type in the MSL is
155/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
156/// each global variable containing a runtime-sized array.
157///
158/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
159/// runtime-sized array, then the value `ArraySize(global)` implements
160/// [`core::fmt::Display`], formatting as the name of the struct member carrying
161/// the number of elements in that runtime-sized array.
162///
163/// [`GlobalVariable`]: crate::GlobalVariable
164struct ArraySizeMember(Handle<crate::GlobalVariable>);
165
166impl Display for ArraySizeMember {
167    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
168        self.0.write_prefixed(f, "size")
169    }
170}
171
172/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
173///
174/// Implements [`core::fmt::Display`], formatting as a name derived from
175/// `target_type` and the variable name of `orig`.
176#[derive(Clone, Copy)]
177struct Reinterpreted<'a> {
178    target_type: &'a str,
179    orig: Handle<crate::Expression>,
180}
181
182impl<'a> Reinterpreted<'a> {
183    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
184        Self { target_type, orig }
185    }
186}
187
188impl Display for Reinterpreted<'_> {
189    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
190        f.write_str(REINTERPRET_PREFIX)?;
191        f.write_str(self.target_type)?;
192        self.orig.write_prefixed(f, "_e")
193    }
194}
195
196struct TypeContext<'a> {
197    handle: Handle<crate::Type>,
198    gctx: proc::GlobalCtx<'a>,
199    names: &'a FastHashMap<NameKey, String>,
200    access: crate::StorageAccess,
201    first_time: bool,
202}
203
204impl TypeContext<'_> {
205    fn scalar(&self) -> Option<crate::Scalar> {
206        let ty = &self.gctx.types[self.handle];
207        ty.inner.scalar()
208    }
209
210    fn vertex_input_dimension(&self) -> u32 {
211        let ty = &self.gctx.types[self.handle];
212        match ty.inner {
213            crate::TypeInner::Scalar(_) => 1,
214            crate::TypeInner::Vector { size, .. } => size as u32,
215            _ => unreachable!(),
216        }
217    }
218}
219
220impl Display for TypeContext<'_> {
221    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
222        let ty = &self.gctx.types[self.handle];
223        if ty.needs_alias() && !self.first_time {
224            let name = &self.names[&NameKey::Type(self.handle)];
225            return write!(out, "{name}");
226        }
227
228        match ty.inner {
229            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
230            crate::TypeInner::Atomic(scalar) => {
231                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
232            }
233            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
234            crate::TypeInner::Matrix {
235                columns,
236                rows,
237                scalar,
238            } => put_numeric_type(out, scalar, &[rows, columns]),
239            crate::TypeInner::Pointer { base, space } => {
240                let sub = Self {
241                    handle: base,
242                    first_time: false,
243                    ..*self
244                };
245                let space_name = match space.to_msl_name() {
246                    Some(name) => name,
247                    None => return Ok(()),
248                };
249                write!(out, "{space_name} {sub}&")
250            }
251            crate::TypeInner::ValuePointer {
252                size,
253                scalar,
254                space,
255            } => {
256                match space.to_msl_name() {
257                    Some(name) => write!(out, "{name} ")?,
258                    None => return Ok(()),
259                };
260                match size {
261                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
262                    None => put_numeric_type(out, scalar, &[])?,
263                };
264
265                write!(out, "&")
266            }
267            crate::TypeInner::Array { base, .. } => {
268                let sub = Self {
269                    handle: base,
270                    first_time: false,
271                    ..*self
272                };
273                // Array lengths go at the end of the type definition,
274                // so just print the element type here.
275                write!(out, "{sub}")
276            }
277            crate::TypeInner::Struct { .. } => unreachable!(),
278            crate::TypeInner::Image {
279                dim,
280                arrayed,
281                class,
282            } => {
283                let dim_str = match dim {
284                    crate::ImageDimension::D1 => "1d",
285                    crate::ImageDimension::D2 => "2d",
286                    crate::ImageDimension::D3 => "3d",
287                    crate::ImageDimension::Cube => "cube",
288                };
289                let (texture_str, msaa_str, scalar, access) = match class {
290                    crate::ImageClass::Sampled { kind, multi } => {
291                        let (msaa_str, access) = if multi {
292                            ("_ms", "read")
293                        } else {
294                            ("", "sample")
295                        };
296                        let scalar = crate::Scalar { kind, width: 4 };
297                        ("texture", msaa_str, scalar, access)
298                    }
299                    crate::ImageClass::Depth { multi } => {
300                        let (msaa_str, access) = if multi {
301                            ("_ms", "read")
302                        } else {
303                            ("", "sample")
304                        };
305                        let scalar = crate::Scalar {
306                            kind: crate::ScalarKind::Float,
307                            width: 4,
308                        };
309                        ("depth", msaa_str, scalar, access)
310                    }
311                    crate::ImageClass::Storage { format, .. } => {
312                        let access = if self
313                            .access
314                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
315                        {
316                            "read_write"
317                        } else if self.access.contains(crate::StorageAccess::STORE) {
318                            "write"
319                        } else if self.access.contains(crate::StorageAccess::LOAD) {
320                            "read"
321                        } else {
322                            log::warn!(
323                                "Storage access for {:?} (name '{}'): {:?}",
324                                self.handle,
325                                ty.name.as_deref().unwrap_or_default(),
326                                self.access
327                            );
328                            unreachable!("module is not valid");
329                        };
330                        ("texture", "", format.into(), access)
331                    }
332                    crate::ImageClass::External => {
333                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
334                    }
335                };
336                let base_name = scalar.to_msl_name();
337                let array_str = if arrayed { "_array" } else { "" };
338                write!(
339                    out,
340                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
341                )
342            }
343            crate::TypeInner::Sampler { comparison: _ } => {
344                write!(out, "{NAMESPACE}::sampler")
345            }
346            crate::TypeInner::AccelerationStructure { vertex_return } => {
347                if vertex_return {
348                    unimplemented!("metal does not support vertex ray hit return")
349                }
350                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
351            }
352            crate::TypeInner::RayQuery { vertex_return } => {
353                if vertex_return {
354                    unimplemented!("metal does not support vertex ray hit return")
355                }
356                write!(out, "{RAY_QUERY_TYPE}")
357            }
358            crate::TypeInner::BindingArray { base, .. } => {
359                let base_tyname = Self {
360                    handle: base,
361                    first_time: false,
362                    ..*self
363                };
364
365                write!(
366                    out,
367                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
368                )
369            }
370        }
371    }
372}
373
374struct TypedGlobalVariable<'a> {
375    module: &'a crate::Module,
376    names: &'a FastHashMap<NameKey, String>,
377    handle: Handle<crate::GlobalVariable>,
378    usage: valid::GlobalUse,
379    reference: bool,
380}
381
382impl TypedGlobalVariable<'_> {
383    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
384        let var = &self.module.global_variables[self.handle];
385        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
386
387        let storage_access = match var.space {
388            crate::AddressSpace::Storage { access } => access,
389            _ => match self.module.types[var.ty].inner {
390                crate::TypeInner::Image {
391                    class: crate::ImageClass::Storage { access, .. },
392                    ..
393                } => access,
394                crate::TypeInner::BindingArray { base, .. } => {
395                    match self.module.types[base].inner {
396                        crate::TypeInner::Image {
397                            class: crate::ImageClass::Storage { access, .. },
398                            ..
399                        } => access,
400                        _ => crate::StorageAccess::default(),
401                    }
402                }
403                _ => crate::StorageAccess::default(),
404            },
405        };
406        let ty_name = TypeContext {
407            handle: var.ty,
408            gctx: self.module.to_ctx(),
409            names: self.names,
410            access: storage_access,
411            first_time: false,
412        };
413
414        let (space, access, reference) = match var.space.to_msl_name() {
415            Some(space) if self.reference => {
416                let access = if var.space.needs_access_qualifier()
417                    && !self.usage.intersects(valid::GlobalUse::WRITE)
418                {
419                    "const"
420                } else {
421                    ""
422                };
423                (space, access, "&")
424            }
425            _ => ("", "", ""),
426        };
427
428        Ok(write!(
429            out,
430            "{}{}{}{}{}{} {}",
431            space,
432            if space.is_empty() { "" } else { " " },
433            ty_name,
434            if access.is_empty() { "" } else { " " },
435            access,
436            reference,
437            name,
438        )?)
439    }
440}
441
/// Key describing one wrapper function, by the operation it wraps and the
/// types involved.
///
/// Derives `Eq` and `Hash` so that values can be stored in a hash set; see
/// the `wrapped_functions` field of [`Writer`].
// NOTE(review): presumably used to emit each helper function (e.g.
// `naga_div`) at most once — confirm against the code that fills the set.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Wrapper around a unary operator. `ty` is an optional vector size
    /// (presumably `None` for a scalar operand) together with the scalar type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper around a binary operator, keyed by the optional vector size
    /// and scalar type of each operand.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper around a math function, keyed by the optional vector size and
    /// scalar type of its argument.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper around a cast from `src_scalar` to `dst_scalar`, with an
    /// optional vector size.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Wrapper around an image load, keyed by image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Wrapper around an image sample, keyed by image class and by whether
    /// `clamp_to_edge` is set.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Wrapper around an image size query, keyed by image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
473
/// MSL writer state, generic over the output sink `W`.
///
/// Created with `Writer::new`; `Writer::finish` hands the sink back.
pub struct Writer<W> {
    /// The output sink.
    out: W,
    /// Names assigned to module items, keyed by [`NameKey`].
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    // NOTE(review): presumably generates the unique identifiers stored in
    // `names` — confirm.
    namer: proc::Namer,
    /// Set of [`WrappedFunction`] keys.
    // NOTE(review): presumably the wrappers already written — confirm.
    wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only set of raw pointers.
    // NOTE(review): judging by the name, stack addresses observed while
    // writing expressions — confirm.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only set of raw pointers.
    // NOTE(review): judging by the name, stack addresses observed while
    // writing blocks — confirm.
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
490
491impl crate::Scalar {
492    pub(super) fn to_msl_name(self) -> &'static str {
493        use crate::ScalarKind as Sk;
494        match self {
495            Self {
496                kind: Sk::Float,
497                width: 4,
498            } => "float",
499            Self {
500                kind: Sk::Float,
501                width: 2,
502            } => "half",
503            Self {
504                kind: Sk::Sint,
505                width: 4,
506            } => "int",
507            Self {
508                kind: Sk::Uint,
509                width: 4,
510            } => "uint",
511            Self {
512                kind: Sk::Sint,
513                width: 8,
514            } => "long",
515            Self {
516                kind: Sk::Uint,
517                width: 8,
518            } => "ulong",
519            Self {
520                kind: Sk::Bool,
521                width: _,
522            } => "bool",
523            Self {
524                kind: Sk::AbstractInt | Sk::AbstractFloat,
525                width: _,
526            } => unreachable!("Found Abstract scalar kind"),
527            _ => unreachable!("Unsupported scalar kind: {:?}", self),
528        }
529    }
530}
531
/// Returns `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
539
540fn should_pack_struct_member(
541    members: &[crate::StructMember],
542    span: u32,
543    index: usize,
544    module: &crate::Module,
545) -> Option<crate::Scalar> {
546    let member = &members[index];
547
548    let ty_inner = &module.types[member.ty].inner;
549    let last_offset = member.offset + ty_inner.size(module.to_ctx());
550    let next_offset = match members.get(index + 1) {
551        Some(next) => next.offset,
552        None => span,
553    };
554    let is_tight = next_offset == last_offset;
555
556    match *ty_inner {
557        crate::TypeInner::Vector {
558            size: crate::VectorSize::Tri,
559            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
560        } if is_tight => Some(scalar),
561        _ => None,
562    }
563}
564
565fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
566    match arena[ty].inner {
567        crate::TypeInner::Struct { ref members, .. } => {
568            if let Some(member) = members.last() {
569                if let crate::TypeInner::Array {
570                    size: crate::ArraySize::Dynamic,
571                    ..
572                } = arena[member.ty].inner
573                {
574                    return true;
575                }
576            }
577            false
578        }
579        crate::TypeInner::Array {
580            size: crate::ArraySize::Dynamic,
581            ..
582        } => true,
583        _ => false,
584    }
585}
586
587impl crate::AddressSpace {
588    /// Returns true if global variables in this address space are
589    /// passed in function arguments. These arguments need to be
590    /// passed through any functions called from the entry point.
591    const fn needs_pass_through(&self) -> bool {
592        match *self {
593            Self::Uniform
594            | Self::Storage { .. }
595            | Self::Private
596            | Self::WorkGroup
597            | Self::PushConstant
598            | Self::Handle
599            | Self::TaskPayload => true,
600            Self::Function => false,
601        }
602    }
603
604    /// Returns true if the address space may need a "const" qualifier.
605    const fn needs_access_qualifier(&self) -> bool {
606        match *self {
607            //Note: we are ignoring the storage access here, and instead
608            // rely on the actual use of a global by functions. This means we
609            // may end up with "const" even if the binding is read-write,
610            // and that should be OK.
611            Self::Storage { .. } => true,
612            Self::TaskPayload => unimplemented!(),
613            // These should always be read-write.
614            Self::Private | Self::WorkGroup => false,
615            // These translate to `constant` address space, no need for qualifiers.
616            Self::Uniform | Self::PushConstant => false,
617            // Not applicable.
618            Self::Handle | Self::Function => false,
619        }
620    }
621
622    const fn to_msl_name(self) -> Option<&'static str> {
623        match self {
624            Self::Handle => None,
625            Self::Uniform | Self::PushConstant => Some("constant"),
626            Self::Storage { .. } => Some("device"),
627            Self::Private | Self::Function => Some("thread"),
628            Self::WorkGroup => Some("threadgroup"),
629            Self::TaskPayload => Some("object_data"),
630        }
631    }
632}
633
634impl crate::Type {
635    // Returns `true` if we need to emit an alias for this type.
636    const fn needs_alias(&self) -> bool {
637        use crate::TypeInner as Ti;
638
639        match self.inner {
640            // value types are concise enough, we only alias them if they are named
641            Ti::Scalar(_)
642            | Ti::Vector { .. }
643            | Ti::Matrix { .. }
644            | Ti::Atomic(_)
645            | Ti::Pointer { .. }
646            | Ti::ValuePointer { .. } => self.name.is_some(),
647            // composite types are better to be aliased, regardless of the name
648            Ti::Struct { .. } | Ti::Array { .. } => true,
649            // handle types may be different, depending on the global var access, so we always inline them
650            Ti::Image { .. }
651            | Ti::Sampler { .. }
652            | Ti::AccelerationStructure { .. }
653            | Ti::RayQuery { .. }
654            | Ti::BindingArray { .. } => false,
655        }
656    }
657}
658
/// Identifies a function: either a regular function, by its handle, or an
/// entry point, by its index.
///
/// Used by [`NameKeyExt`] to choose between the function and entry-point
/// flavors of a [`NameKey`].
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function, identified by handle.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by index.
    EntryPoint(proc::EntryPointIndex),
}
664
665trait NameKeyExt {
666    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
667        match origin {
668            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
669            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
670        }
671    }
672
673    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
674    /// policy when it needs to produce a pointer-typed result for an OOB access. These
675    /// are unique per accessed type, so the second argument is a type handle. See docs
676    /// for [`crate::back::msl`].
677    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
678        match origin {
679            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
680            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
681        }
682    }
683}
684
685impl NameKeyExt for NameKey {}
686
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the level-of-detail expression directly.
    Direct(Handle<crate::Expression>),
    /// Use the saved, clamped level of detail.
    // NOTE(review): the handle is presumably that of the `ImageLoad`
    // expression the temporary is named after — confirm at the use sites.
    Restricted(Handle<crate::Expression>),
}
701
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Expression giving the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression giving the array index, if any.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression giving the sample index, if any.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any.
    level: Option<LevelOfDetail>,
}
717
/// Context shared by the code that writes the expressions of one function.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; used to resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module `function` belongs to.
    module: &'a crate::Module,
    /// Validation info for the whole module.
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the (major, minor) MSL version — confirm.
    lang_version: (u8, u8),
    /// Bounds check policies to apply.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
735
736impl<'a> ExpressionContext<'a> {
737    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
738        self.info[handle].ty.inner_with(&self.module.types)
739    }
740
741    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
742    ///
743    /// Only mipmapped images need to specify a level of detail. Since 1D
744    /// textures cannot have mipmaps, MSL requires that the level argument to
745    /// texture1d queries and accesses must be a constexpr 0. It's easiest
746    /// just to omit the level entirely for 1D textures.
747    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
748        let image_ty = self.resolve_type(image);
749        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
750            class.is_mipmapped() && dim != crate::ImageDimension::D1
751        } else {
752            false
753        }
754    }
755
756    fn choose_bounds_check_policy(
757        &self,
758        pointer: Handle<crate::Expression>,
759    ) -> index::BoundsCheckPolicy {
760        self.policies
761            .choose_policy(pointer, &self.module.types, self.info)
762    }
763
764    /// See docs for [`proc::index::access_needs_check`].
765    fn access_needs_check(
766        &self,
767        base: Handle<crate::Expression>,
768        index: index::GuardedIndex,
769    ) -> Option<index::IndexableLength> {
770        index::access_needs_check(
771            base,
772            index,
773            self.module,
774            &self.function.expressions,
775            self.info,
776        )
777    }
778
779    /// See docs for [`proc::index::bounds_check_iter`].
780    fn bounds_check_iter(
781        &self,
782        chain: Handle<crate::Expression>,
783    ) -> impl Iterator<Item = BoundsCheck> + '_ {
784        index::bounds_check_iter(chain, self.module, self.function, self.info)
785    }
786
787    /// See docs for [`proc::index::oob_local_types`].
788    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
789        index::oob_local_types(self.module, self.function, self.info, self.policies)
790    }
791
792    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
793        match self.function.expressions[expr_handle] {
794            crate::Expression::AccessIndex { base, index } => {
795                let ty = match *self.resolve_type(base) {
796                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
797                    ref ty => ty,
798                };
799                match *ty {
800                    crate::TypeInner::Struct {
801                        ref members, span, ..
802                    } => should_pack_struct_member(members, span, index as usize, self.module),
803                    _ => None,
804                }
805            }
806            _ => None,
807        }
808    }
809}
810
/// Context for writing statements: an [`ExpressionContext`] plus
/// statement-level state.
struct StatementContext<'a> {
    /// Context used for the expressions appearing in the statements.
    expression: ExpressionContext<'a>,
    // NOTE(review): presumably the name of the struct a `return` must
    // construct, when the function's result is wrapped in one — confirm.
    result_struct: Option<&'a str>,
}
815
816impl<W: Write> Writer<W> {
817    /// Creates a new `Writer` instance.
818    pub fn new(out: W) -> Self {
819        Writer {
820            out,
821            names: FastHashMap::default(),
822            named_expressions: Default::default(),
823            need_bake_expressions: Default::default(),
824            namer: proc::Namer::default(),
825            wrapped_functions: FastHashSet::default(),
826            #[cfg(test)]
827            put_expression_stack_pointers: Default::default(),
828            #[cfg(test)]
829            put_block_stack_pointers: Default::default(),
830            struct_member_pads: FastHashSet::default(),
831        }
832    }
833
834    /// Finishes writing and returns the output.
835    // See https://github.com/rust-lang/rust-clippy/issues/4979.
836    #[allow(clippy::missing_const_for_fn)]
837    pub fn finish(self) -> W {
838        self.out
839    }
840
841    /// Generates statements to be inserted immediately before and at the very
842    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
843    /// The 0th item of the returned tuple should be inserted immediately prior
844    /// to the loop and the 1st item should be inserted at the very start of
845    /// the loop body.
846    ///
847    /// # What is this trying to solve?
848    ///
849    /// In Metal Shading Language, an infinite loop has undefined behavior.
850    /// (This rule is inherited from C++14.) This means that, if the MSL
851    /// compiler determines that a given loop will never exit, it may assume
852    /// that it is never reached. It may thus assume that any conditions
853    /// sufficient to cause the loop to be reached must be false. Like many
854    /// optimizing compilers, MSL uses this kind of analysis to establish limits
855    /// on the range of values variables involved in those conditions might
856    /// hold.
857    ///
858    /// For example, suppose the MSL compiler sees the code:
859    ///
860    /// ```ignore
861    /// if (i >= 10) {
862    ///     while (true) { }
863    /// }
864    /// ```
865    ///
866    /// It will recognize that the `while` loop will never terminate, conclude
867    /// that it must be unreachable, and thus infer that, if this code is
868    /// reached, then `i < 10` at that point.
869    ///
870    /// Now suppose that, at some point where `i` has the same value as above,
871    /// the compiler sees the code:
872    ///
873    /// ```ignore
874    /// if (i < 10) {
875    ///     a[i] = 1;
876    /// }
877    /// ```
878    ///
879    /// Because the compiler is confident that `i < 10`, it will make the
880    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
881    ///
882    /// ```ignore
883    /// a[i] = 1;
884    /// ```
885    ///
886    /// If that `if` condition was injected by Naga to implement a bounds check,
887    /// the MSL compiler's optimizations could allow out-of-bounds array
888    /// accesses to occur.
889    ///
890    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
891    /// that a loop is infinite, so an attacker could craft a Naga module
892    /// containing an infinite loop protected by conditions that cause the Metal
893    /// compiler to remove bounds checks that Naga injected elsewhere in the
894    /// function.
895    ///
896    /// This rewrite could occur even if the conditional assignment appears
897    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
898    /// reached. This would allow the attacker to save the results of
899    /// unauthorized reads somewhere accessible before entering the infinite
900    /// loop. But even worse, the MSL compiler has been observed to simply
901    /// delete the infinite loop entirely, so that even code dominated by the
902    /// loop becomes reachable. This would make the attack even more flexible,
903    /// since shaders that would appear to never terminate would actually exit
904    /// nicely, after having stolen data from elsewhere in the GPU address
905    /// space.
906    ///
907    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
908    /// generates is infinite. One approach would be to add inline assembly to
909    /// each loop that is annotated as potentially branching out of the loop,
910    /// but which in fact generates no instructions. Unfortunately, inline
911    /// assembly is not handled correctly by some Metal device drivers.
912    ///
913    /// A previously used approach was to add the following code to the bottom
914    /// of every loop:
915    ///
916    /// ```ignore
917    /// if (volatile bool unpredictable = false; unpredictable)
918    ///     break;
919    /// ```
920    ///
921    /// Although the `if` condition will always be false in any real execution,
922    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
923    /// it must assume that the `break` might be reached, and hence that the
924    /// loop is not unbounded. This prevents the range analysis impact described
925    /// above. Unfortunately this prevented the compiler from making important,
926    /// and safe, optimizations such as loop unrolling and was observed to
927    /// significantly hurt performance.
928    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
939    ///
940    /// This convinces the compiler that the loop is finite and therefore may
941    /// execute, whilst at the same time allowing optimizations such as loop
942    /// unrolling. Furthermore the 64-bit counter is large enough it seems
943    /// implausible that it would affect the execution of any shader.
944    ///
945    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
946    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
947    fn gen_force_bounded_loop_statements(
948        &mut self,
949        level: back::Level,
950        context: &StatementContext,
951    ) -> Option<(String, String)> {
952        if !context.expression.force_loop_bounding {
953            return None;
954        }
955
956        let loop_bound_name = self.namer.call("loop_bound");
957        // Count down from u32::MAX rather than up from 0 to avoid hang on
958        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
959        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
960        let level = level.next();
961        let break_and_inc = format!(
962            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
963{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
964        );
965
966        Some((decl, break_and_inc))
967    }
968
969    fn put_call_parameters(
970        &mut self,
971        parameters: impl Iterator<Item = Handle<crate::Expression>>,
972        context: &ExpressionContext,
973    ) -> BackendResult {
974        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
975            writer.put_expression(expr, context, true)
976        })
977    }
978
979    fn put_call_parameters_impl<C, E>(
980        &mut self,
981        parameters: impl Iterator<Item = Handle<crate::Expression>>,
982        ctx: &C,
983        put_expression: E,
984    ) -> BackendResult
985    where
986        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
987    {
988        write!(self.out, "(")?;
989        for (i, handle) in parameters.enumerate() {
990            if i != 0 {
991                write!(self.out, ", ")?;
992            }
993            put_expression(self, ctx, handle)?;
994        }
995        write!(self.out, ")")?;
996        Ok(())
997    }
998
999    /// Writes the local variables of the given function, as well as any extra
1000    /// out-of-bounds locals that are needed.
1001    ///
1002    /// The names of the OOB locals are also added to `self.names` at the same
1003    /// time.
1004    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1005        let oob_local_types = context.oob_local_types();
1006        for &ty in oob_local_types.iter() {
1007            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1008            self.names.insert(name_key, self.namer.call("oob"));
1009        }
1010
1011        for (name_key, ty, init) in context
1012            .function
1013            .local_variables
1014            .iter()
1015            .map(|(local_handle, local)| {
1016                let name_key = NameKey::local(context.origin, local_handle);
1017                (name_key, local.ty, local.init)
1018            })
1019            .chain(oob_local_types.iter().map(|&ty| {
1020                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1021                (name_key, ty, None)
1022            }))
1023        {
1024            let ty_name = TypeContext {
1025                handle: ty,
1026                gctx: context.module.to_ctx(),
1027                names: &self.names,
1028                access: crate::StorageAccess::empty(),
1029                first_time: false,
1030            };
1031            write!(
1032                self.out,
1033                "{}{} {}",
1034                back::INDENT,
1035                ty_name,
1036                self.names[&name_key]
1037            )?;
1038            match init {
1039                Some(value) => {
1040                    write!(self.out, " = ")?;
1041                    self.put_expression(value, context, true)?;
1042                }
1043                None => {
1044                    write!(self.out, " = {{}}")?;
1045                }
1046            };
1047            writeln!(self.out, ";")?;
1048        }
1049        Ok(())
1050    }
1051
1052    fn put_level_of_detail(
1053        &mut self,
1054        level: LevelOfDetail,
1055        context: &ExpressionContext,
1056    ) -> BackendResult {
1057        match level {
1058            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1059            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1060        }
1061        Ok(())
1062    }
1063
1064    fn put_image_query(
1065        &mut self,
1066        image: Handle<crate::Expression>,
1067        query: &str,
1068        level: Option<LevelOfDetail>,
1069        context: &ExpressionContext,
1070    ) -> BackendResult {
1071        self.put_expression(image, context, false)?;
1072        write!(self.out, ".get_{query}(")?;
1073        if let Some(level) = level {
1074            self.put_level_of_detail(level, context)?;
1075        }
1076        write!(self.out, ")")?;
1077        Ok(())
1078    }
1079
1080    fn put_image_size_query(
1081        &mut self,
1082        image: Handle<crate::Expression>,
1083        level: Option<LevelOfDetail>,
1084        kind: crate::ScalarKind,
1085        context: &ExpressionContext,
1086    ) -> BackendResult {
1087        if let crate::TypeInner::Image {
1088            class: crate::ImageClass::External,
1089            ..
1090        } = *context.resolve_type(image)
1091        {
1092            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1093            self.put_expression(image, context, true)?;
1094            write!(self.out, ")")?;
1095            return Ok(());
1096        }
1097
1098        //Note: MSL only has separate width/height/depth queries,
1099        // so compose the result of them.
1100        let dim = match *context.resolve_type(image) {
1101            crate::TypeInner::Image { dim, .. } => dim,
1102            ref other => unreachable!("Unexpected type {:?}", other),
1103        };
1104        let scalar = crate::Scalar { kind, width: 4 };
1105        let coordinate_type = scalar.to_msl_name();
1106        match dim {
1107            crate::ImageDimension::D1 => {
1108                // Since 1D textures never have mipmaps, MSL requires that the
1109                // `level` argument be a constexpr 0. It's simplest for us just
1110                // to pass `None` and omit the level entirely.
1111                if kind == crate::ScalarKind::Uint {
1112                    // No need to construct a vector. No cast needed.
1113                    self.put_image_query(image, "width", None, context)?;
1114                } else {
1115                    // There's no definition for `int` in the `metal` namespace.
1116                    write!(self.out, "int(")?;
1117                    self.put_image_query(image, "width", None, context)?;
1118                    write!(self.out, ")")?;
1119                }
1120            }
1121            crate::ImageDimension::D2 => {
1122                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1123                self.put_image_query(image, "width", level, context)?;
1124                write!(self.out, ", ")?;
1125                self.put_image_query(image, "height", level, context)?;
1126                write!(self.out, ")")?;
1127            }
1128            crate::ImageDimension::D3 => {
1129                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1130                self.put_image_query(image, "width", level, context)?;
1131                write!(self.out, ", ")?;
1132                self.put_image_query(image, "height", level, context)?;
1133                write!(self.out, ", ")?;
1134                self.put_image_query(image, "depth", level, context)?;
1135                write!(self.out, ")")?;
1136            }
1137            crate::ImageDimension::Cube => {
1138                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1139                self.put_image_query(image, "width", level, context)?;
1140                write!(self.out, ")")?;
1141            }
1142        }
1143        Ok(())
1144    }
1145
1146    fn put_cast_to_uint_scalar_or_vector(
1147        &mut self,
1148        expr: Handle<crate::Expression>,
1149        context: &ExpressionContext,
1150    ) -> BackendResult {
1151        // coordinates in IR are int, but Metal expects uint
1152        match *context.resolve_type(expr) {
1153            crate::TypeInner::Scalar(_) => {
1154                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1155            }
1156            crate::TypeInner::Vector { size, .. } => {
1157                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1158            }
1159            _ => {
1160                return Err(Error::GenericValidation(
1161                    "Invalid type for image coordinate".into(),
1162                ))
1163            }
1164        };
1165
1166        write!(self.out, "(")?;
1167        self.put_expression(expr, context, true)?;
1168        write!(self.out, ")")?;
1169        Ok(())
1170    }
1171
1172    fn put_image_sample_level(
1173        &mut self,
1174        image: Handle<crate::Expression>,
1175        level: crate::SampleLevel,
1176        context: &ExpressionContext,
1177    ) -> BackendResult {
1178        let has_levels = context.image_needs_lod(image);
1179        match level {
1180            crate::SampleLevel::Auto => {}
1181            crate::SampleLevel::Zero => {
1182                //TODO: do we support Zero on `Sampled` image classes?
1183            }
1184            _ if !has_levels => {
1185                log::warn!("1D image can't be sampled with level {level:?}");
1186            }
1187            crate::SampleLevel::Exact(h) => {
1188                write!(self.out, ", {NAMESPACE}::level(")?;
1189                self.put_expression(h, context, true)?;
1190                write!(self.out, ")")?;
1191            }
1192            crate::SampleLevel::Bias(h) => {
1193                write!(self.out, ", {NAMESPACE}::bias(")?;
1194                self.put_expression(h, context, true)?;
1195                write!(self.out, ")")?;
1196            }
1197            crate::SampleLevel::Gradient { x, y } => {
1198                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1199                self.put_expression(x, context, true)?;
1200                write!(self.out, ", ")?;
1201                self.put_expression(y, context, true)?;
1202                write!(self.out, ")")?;
1203            }
1204        }
1205        Ok(())
1206    }
1207
1208    fn put_image_coordinate_limits(
1209        &mut self,
1210        image: Handle<crate::Expression>,
1211        level: Option<LevelOfDetail>,
1212        context: &ExpressionContext,
1213    ) -> BackendResult {
1214        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1215        write!(self.out, " - 1")?;
1216        Ok(())
1217    }
1218
1219    /// General function for writing restricted image indexes.
1220    ///
1221    /// This is used to produce restricted mip levels, array indices, and sample
1222    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1223    /// [`Restrict`] bounds check policy.
1224    ///
1225    /// This function writes an expression of the form:
1226    ///
1227    /// ```ignore
1228    ///
1229    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1230    ///
1231    /// ```
1232    ///
1233    /// [`ImageLoad`]: crate::Expression::ImageLoad
1234    /// [`ImageStore`]: crate::Statement::ImageStore
1235    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1236    fn put_restricted_scalar_image_index(
1237        &mut self,
1238        image: Handle<crate::Expression>,
1239        index: Handle<crate::Expression>,
1240        limit_method: &str,
1241        context: &ExpressionContext,
1242    ) -> BackendResult {
1243        write!(self.out, "{NAMESPACE}::min(uint(")?;
1244        self.put_expression(index, context, true)?;
1245        write!(self.out, "), ")?;
1246        self.put_expression(image, context, false)?;
1247        write!(self.out, ".{limit_method}() - 1)")?;
1248        Ok(())
1249    }
1250
1251    fn put_restricted_texel_address(
1252        &mut self,
1253        image: Handle<crate::Expression>,
1254        address: &TexelAddress,
1255        context: &ExpressionContext,
1256    ) -> BackendResult {
1257        // Write the coordinate.
1258        write!(self.out, "{NAMESPACE}::min(")?;
1259        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1260        write!(self.out, ", ")?;
1261        self.put_image_coordinate_limits(image, address.level, context)?;
1262        write!(self.out, ")")?;
1263
1264        // Write the array index, if present.
1265        if let Some(array_index) = address.array_index {
1266            write!(self.out, ", ")?;
1267            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1268        }
1269
1270        // Write the sample index, if present.
1271        if let Some(sample) = address.sample {
1272            write!(self.out, ", ")?;
1273            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1274        }
1275
1276        // The level of detail should be clamped and cached by
1277        // `put_cache_restricted_level`, so we don't need to clamp it here.
1278        if let Some(level) = address.level {
1279            write!(self.out, ", ")?;
1280            self.put_level_of_detail(level, context)?;
1281        }
1282
1283        Ok(())
1284    }
1285
1286    /// Write an expression that is true if the given image access is in bounds.
1287    fn put_image_access_bounds_check(
1288        &mut self,
1289        image: Handle<crate::Expression>,
1290        address: &TexelAddress,
1291        context: &ExpressionContext,
1292    ) -> BackendResult {
1293        let mut conjunction = "";
1294
1295        // First, check the level of detail. Only if that is in bounds can we
1296        // use it to find the appropriate bounds for the coordinates.
1297        let level = if let Some(level) = address.level {
1298            write!(self.out, "uint(")?;
1299            self.put_level_of_detail(level, context)?;
1300            write!(self.out, ") < ")?;
1301            self.put_expression(image, context, true)?;
1302            write!(self.out, ".get_num_mip_levels()")?;
1303            conjunction = " && ";
1304            Some(level)
1305        } else {
1306            None
1307        };
1308
1309        // Check sample index, if present.
1310        if let Some(sample) = address.sample {
1311            write!(self.out, "uint(")?;
1312            self.put_expression(sample, context, true)?;
1313            write!(self.out, ") < ")?;
1314            self.put_expression(image, context, true)?;
1315            write!(self.out, ".get_num_samples()")?;
1316            conjunction = " && ";
1317        }
1318
1319        // Check array index, if present.
1320        if let Some(array_index) = address.array_index {
1321            write!(self.out, "{conjunction}uint(")?;
1322            self.put_expression(array_index, context, true)?;
1323            write!(self.out, ") < ")?;
1324            self.put_expression(image, context, true)?;
1325            write!(self.out, ".get_array_size()")?;
1326            conjunction = " && ";
1327        }
1328
1329        // Finally, check if the coordinates are within bounds.
1330        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1331            crate::TypeInner::Vector { .. } => true,
1332            _ => false,
1333        };
1334        write!(self.out, "{conjunction}")?;
1335        if coord_is_vector {
1336            write!(self.out, "{NAMESPACE}::all(")?;
1337        }
1338        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1339        write!(self.out, " < ")?;
1340        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1341        if coord_is_vector {
1342            write!(self.out, ")")?;
1343        }
1344
1345        Ok(())
1346    }
1347
1348    fn put_image_load(
1349        &mut self,
1350        load: Handle<crate::Expression>,
1351        image: Handle<crate::Expression>,
1352        mut address: TexelAddress,
1353        context: &ExpressionContext,
1354    ) -> BackendResult {
1355        if let crate::TypeInner::Image {
1356            class: crate::ImageClass::External,
1357            ..
1358        } = *context.resolve_type(image)
1359        {
1360            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1361            self.put_expression(image, context, true)?;
1362            write!(self.out, ", ")?;
1363            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1364            write!(self.out, ")")?;
1365            return Ok(());
1366        }
1367
1368        match context.policies.image_load {
1369            proc::BoundsCheckPolicy::Restrict => {
1370                // Use the cached restricted level of detail, if any. Omit the
1371                // level altogether for 1D textures.
1372                if address.level.is_some() {
1373                    address.level = if context.image_needs_lod(image) {
1374                        Some(LevelOfDetail::Restricted(load))
1375                    } else {
1376                        None
1377                    }
1378                }
1379
1380                self.put_expression(image, context, false)?;
1381                write!(self.out, ".read(")?;
1382                self.put_restricted_texel_address(image, &address, context)?;
1383                write!(self.out, ")")?;
1384            }
1385            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1386                write!(self.out, "(")?;
1387                self.put_image_access_bounds_check(image, &address, context)?;
1388                write!(self.out, " ? ")?;
1389                self.put_unchecked_image_load(image, &address, context)?;
1390                write!(self.out, ": DefaultConstructible())")?;
1391            }
1392            proc::BoundsCheckPolicy::Unchecked => {
1393                self.put_unchecked_image_load(image, &address, context)?;
1394            }
1395        }
1396
1397        Ok(())
1398    }
1399
1400    fn put_unchecked_image_load(
1401        &mut self,
1402        image: Handle<crate::Expression>,
1403        address: &TexelAddress,
1404        context: &ExpressionContext,
1405    ) -> BackendResult {
1406        self.put_expression(image, context, false)?;
1407        write!(self.out, ".read(")?;
1408        // coordinates in IR are int, but Metal expects uint
1409        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1410        if let Some(expr) = address.array_index {
1411            write!(self.out, ", ")?;
1412            self.put_expression(expr, context, true)?;
1413        }
1414        if let Some(sample) = address.sample {
1415            write!(self.out, ", ")?;
1416            self.put_expression(sample, context, true)?;
1417        }
1418        if let Some(level) = address.level {
1419            if context.image_needs_lod(image) {
1420                write!(self.out, ", ")?;
1421                self.put_level_of_detail(level, context)?;
1422            }
1423        }
1424        write!(self.out, ")")?;
1425
1426        Ok(())
1427    }
1428
1429    fn put_image_atomic(
1430        &mut self,
1431        level: back::Level,
1432        image: Handle<crate::Expression>,
1433        address: &TexelAddress,
1434        fun: crate::AtomicFunction,
1435        value: Handle<crate::Expression>,
1436        context: &StatementContext,
1437    ) -> BackendResult {
1438        write!(self.out, "{level}")?;
1439        self.put_expression(image, &context.expression, false)?;
1440        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1441            fun.to_msl_64_bit()?
1442        } else {
1443            fun.to_msl()
1444        };
1445        write!(self.out, ".atomic_{op}(")?;
1446        // coordinates in IR are int, but Metal expects uint
1447        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1448        write!(self.out, ", ")?;
1449        self.put_expression(value, &context.expression, true)?;
1450        writeln!(self.out, ");")?;
1451
1452        Ok(())
1453    }
1454
1455    fn put_image_store(
1456        &mut self,
1457        level: back::Level,
1458        image: Handle<crate::Expression>,
1459        address: &TexelAddress,
1460        value: Handle<crate::Expression>,
1461        context: &StatementContext,
1462    ) -> BackendResult {
1463        write!(self.out, "{level}")?;
1464        self.put_expression(image, &context.expression, false)?;
1465        write!(self.out, ".write(")?;
1466        self.put_expression(value, &context.expression, true)?;
1467        write!(self.out, ", ")?;
1468        // coordinates in IR are int, but Metal expects uint
1469        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1470        if let Some(expr) = address.array_index {
1471            write!(self.out, ", ")?;
1472            self.put_expression(expr, &context.expression, true)?;
1473        }
1474        writeln!(self.out, ");")?;
1475
1476        Ok(())
1477    }
1478
1479    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1480    ///
1481    /// The 'maximum valid index' is simply one less than the array's length.
1482    ///
1483    /// This emits an expression of the form `a / b`, so the caller must
1484    /// parenthesize its output if it will be applying operators of higher
1485    /// precedence.
1486    ///
1487    /// `handle` must be the handle of a global variable whose final member is a
1488    /// dynamically sized array.
1489    fn put_dynamic_array_max_index(
1490        &mut self,
1491        handle: Handle<crate::GlobalVariable>,
1492        context: &ExpressionContext,
1493    ) -> BackendResult {
1494        let global = &context.module.global_variables[handle];
1495        let (offset, array_ty) = match context.module.types[global.ty].inner {
1496            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1497                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1498                None => return Err(Error::GenericValidation("Struct has no members".into())),
1499            },
1500            crate::TypeInner::Array {
1501                size: crate::ArraySize::Dynamic,
1502                ..
1503            } => (0, global.ty),
1504            ref ty => {
1505                return Err(Error::GenericValidation(format!(
1506                    "Expected type with dynamic array, got {ty:?}"
1507                )))
1508            }
1509        };
1510
1511        let (size, stride) = match context.module.types[array_ty].inner {
1512            crate::TypeInner::Array { base, stride, .. } => (
1513                context.module.types[base]
1514                    .inner
1515                    .size(context.module.to_ctx()),
1516                stride,
1517            ),
1518            ref ty => {
1519                return Err(Error::GenericValidation(format!(
1520                    "Expected array type, got {ty:?}"
1521                )))
1522            }
1523        };
1524
1525        // When the stride length is larger than the size, the final element's stride of
1526        // bytes would have padding following the value. But the buffer size in
1527        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1528        // enough to hold the actual values' bytes.
1529        //
1530        // So subtract off the size to get a byte size that falls at the start or within
1531        // the final element. Then divide by the stride size, to get one less than the
1532        // length, and then add one. This works even if the buffer size does include the
1533        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1534        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1535        // rules, together with draw-time validation when `minBindingSize` is zero,
1536        // prevent that.
1537        write!(
1538            self.out,
1539            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1540            member = ArraySizeMember(handle),
1541            offset = offset,
1542            size = size,
1543            stride = stride,
1544        )?;
1545        Ok(())
1546    }
1547
1548    /// Emit code for the arithmetic expression of the dot product.
1549    ///
1550    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1551    /// an index. It writes out the expression for the vector component at that index.
1552    fn put_dot_product<T: Copy>(
1553        &mut self,
1554        arg: T,
1555        arg1: T,
1556        size: usize,
1557        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1558    ) -> BackendResult {
1559        // Write parentheses around the dot product expression to prevent operators
1560        // with different precedences from applying earlier.
1561        write!(self.out, "(")?;
1562
1563        // Cycle through all the components of the vector
1564        for index in 0..size {
1565            // Write the addition to the previous product
1566            // This will print an extra '+' at the beginning but that is fine in msl
1567            write!(self.out, " + ")?;
1568            extractor(self, arg, index)?;
1569            write!(self.out, " * ")?;
1570            extractor(self, arg1, index)?;
1571        }
1572
1573        write!(self.out, ")")?;
1574        Ok(())
1575    }
1576
1577    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1578    fn put_pack4x8(
1579        &mut self,
1580        arg: Handle<crate::Expression>,
1581        context: &ExpressionContext<'_>,
1582        was_signed: bool,
1583        clamp_bounds: Option<(&str, &str)>,
1584    ) -> Result<(), Error> {
1585        let write_arg = |this: &mut Self| -> BackendResult {
1586            if let Some((min, max)) = clamp_bounds {
1587                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1588                write!(this.out, "{NAMESPACE}::clamp(")?;
1589                this.put_expression(arg, context, true)?;
1590                write!(this.out, ", {min}, {max})")?;
1591            } else {
1592                this.put_expression(arg, context, true)?;
1593            }
1594            Ok(())
1595        };
1596
1597        if context.lang_version >= (2, 1) {
1598            let packed_type = if was_signed {
1599                "packed_char4"
1600            } else {
1601                "packed_uchar4"
1602            };
1603            // Metal uses little endian byte order, which matches what WGSL expects here.
1604            write!(self.out, "as_type<uint>({packed_type}(")?;
1605            write_arg(self)?;
1606            write!(self.out, "))")?;
1607        } else {
1608            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1609            if was_signed {
1610                write!(self.out, "uint(")?;
1611            }
1612            write!(self.out, "(")?;
1613            write_arg(self)?;
1614            write!(self.out, "[0] & 0xFF) | ((")?;
1615            write_arg(self)?;
1616            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1617            write_arg(self)?;
1618            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1619            write_arg(self)?;
1620            write!(self.out, "[3] & 0xFF) << 24)")?;
1621            if was_signed {
1622                write!(self.out, ")")?;
1623            }
1624        }
1625
1626        Ok(())
1627    }
1628
1629    /// Emit code for the isign expression.
1630    ///
1631    fn put_isign(
1632        &mut self,
1633        arg: Handle<crate::Expression>,
1634        context: &ExpressionContext,
1635    ) -> BackendResult {
1636        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1637        let scalar = context
1638            .resolve_type(arg)
1639            .scalar()
1640            .expect("put_isign should only be called for args which have an integer scalar type")
1641            .to_msl_name();
1642        match context.resolve_type(arg) {
1643            &crate::TypeInner::Vector { size, .. } => {
1644                let size = common::vector_size_str(size);
1645                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1646            }
1647            _ => {
1648                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1649            }
1650        }
1651        write!(self.out, ", (")?;
1652        self.put_expression(arg, context, true)?;
1653        write!(self.out, " > 0)), {scalar}(0), (")?;
1654        self.put_expression(arg, context, true)?;
1655        write!(self.out, " == 0))")?;
1656        Ok(())
1657    }
1658
1659    fn put_const_expression(
1660        &mut self,
1661        expr_handle: Handle<crate::Expression>,
1662        module: &crate::Module,
1663        mod_info: &valid::ModuleInfo,
1664        arena: &crate::Arena<crate::Expression>,
1665    ) -> BackendResult {
1666        self.put_possibly_const_expression(
1667            expr_handle,
1668            arena,
1669            module,
1670            mod_info,
1671            &(module, mod_info),
1672            |&(_, mod_info), expr| &mod_info[expr],
1673            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1674        )
1675    }
1676
1677    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1678        match literal {
1679            crate::Literal::F64(_) => {
1680                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1681            }
1682            crate::Literal::F16(value) => {
1683                if value.is_infinite() {
1684                    let sign = if value.is_sign_negative() { "-" } else { "" };
1685                    write!(self.out, "{sign}INFINITY")?;
1686                } else if value.is_nan() {
1687                    write!(self.out, "NAN")?;
1688                } else {
1689                    let suffix = if value.fract() == f16::from_f32(0.0) {
1690                        ".0h"
1691                    } else {
1692                        "h"
1693                    };
1694                    write!(self.out, "{value}{suffix}")?;
1695                }
1696            }
1697            crate::Literal::F32(value) => {
1698                if value.is_infinite() {
1699                    let sign = if value.is_sign_negative() { "-" } else { "" };
1700                    write!(self.out, "{sign}INFINITY")?;
1701                } else if value.is_nan() {
1702                    write!(self.out, "NAN")?;
1703                } else {
1704                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1705                    write!(self.out, "{value}{suffix}")?;
1706                }
1707            }
1708            crate::Literal::U32(value) => {
1709                write!(self.out, "{value}u")?;
1710            }
1711            crate::Literal::I32(value) => {
1712                // `-2147483648` is parsed as unary negation of positive 2147483648.
1713                // 2147483648 is too large for int32_t meaning the expression gets
1714                // promoted to a int64_t which is not our intention. Avoid this by instead
1715                // using `-2147483647 - 1`.
1716                if value == i32::MIN {
1717                    write!(self.out, "({} - 1)", value + 1)?;
1718                } else {
1719                    write!(self.out, "{value}")?;
1720                }
1721            }
1722            crate::Literal::U64(value) => {
1723                write!(self.out, "{value}uL")?;
1724            }
1725            crate::Literal::I64(value) => {
1726                // `-9223372036854775808` is parsed as unary negation of positive
1727                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1728                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1729                // value to `-9223372036854775808`. Which would then be negated, possibly
1730                // causing undefined behaviour. Avoid this by instead using
1731                // `-9223372036854775808L - 1L`.
1732                if value == i64::MIN {
1733                    write!(self.out, "({}L - 1L)", value + 1)?;
1734                } else {
1735                    write!(self.out, "{value}L")?;
1736                }
1737            }
1738            crate::Literal::Bool(value) => {
1739                write!(self.out, "{value}")?;
1740            }
1741            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1742                return Err(Error::GenericValidation(
1743                    "Unsupported abstract literal".into(),
1744                ));
1745            }
1746        }
1747        Ok(())
1748    }
1749
1750    #[allow(clippy::too_many_arguments)]
1751    fn put_possibly_const_expression<C, I, E>(
1752        &mut self,
1753        expr_handle: Handle<crate::Expression>,
1754        expressions: &crate::Arena<crate::Expression>,
1755        module: &crate::Module,
1756        mod_info: &valid::ModuleInfo,
1757        ctx: &C,
1758        get_expr_ty: I,
1759        put_expression: E,
1760    ) -> BackendResult
1761    where
1762        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1763        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1764    {
1765        match expressions[expr_handle] {
1766            crate::Expression::Literal(literal) => {
1767                self.put_literal(literal)?;
1768            }
1769            crate::Expression::Constant(handle) => {
1770                let constant = &module.constants[handle];
1771                if constant.name.is_some() {
1772                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1773                } else {
1774                    self.put_const_expression(
1775                        constant.init,
1776                        module,
1777                        mod_info,
1778                        &module.global_expressions,
1779                    )?;
1780                }
1781            }
1782            crate::Expression::ZeroValue(ty) => {
1783                let ty_name = TypeContext {
1784                    handle: ty,
1785                    gctx: module.to_ctx(),
1786                    names: &self.names,
1787                    access: crate::StorageAccess::empty(),
1788                    first_time: false,
1789                };
1790                write!(self.out, "{ty_name} {{}}")?;
1791            }
1792            crate::Expression::Compose { ty, ref components } => {
1793                let ty_name = TypeContext {
1794                    handle: ty,
1795                    gctx: module.to_ctx(),
1796                    names: &self.names,
1797                    access: crate::StorageAccess::empty(),
1798                    first_time: false,
1799                };
1800                write!(self.out, "{ty_name}")?;
1801                match module.types[ty].inner {
1802                    crate::TypeInner::Scalar(_)
1803                    | crate::TypeInner::Vector { .. }
1804                    | crate::TypeInner::Matrix { .. } => {
1805                        self.put_call_parameters_impl(
1806                            components.iter().copied(),
1807                            ctx,
1808                            put_expression,
1809                        )?;
1810                    }
1811                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1812                        write!(self.out, " {{")?;
1813                        for (index, &component) in components.iter().enumerate() {
1814                            if index != 0 {
1815                                write!(self.out, ", ")?;
1816                            }
1817                            // insert padding initialization, if needed
1818                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1819                                write!(self.out, "{{}}, ")?;
1820                            }
1821                            put_expression(self, ctx, component)?;
1822                        }
1823                        write!(self.out, "}}")?;
1824                    }
1825                    _ => return Err(Error::UnsupportedCompose(ty)),
1826                }
1827            }
1828            crate::Expression::Splat { size, value } => {
1829                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1830                    crate::TypeInner::Scalar(scalar) => scalar,
1831                    ref ty => {
1832                        return Err(Error::GenericValidation(format!(
1833                            "Expected splat value type must be a scalar, got {ty:?}",
1834                        )))
1835                    }
1836                };
1837                put_numeric_type(&mut self.out, scalar, &[size])?;
1838                write!(self.out, "(")?;
1839                put_expression(self, ctx, value)?;
1840                write!(self.out, ")")?;
1841            }
1842            _ => {
1843                return Err(Error::Override);
1844            }
1845        }
1846
1847        Ok(())
1848    }
1849
1850    /// Emit code for the expression `expr_handle`.
1851    ///
1852    /// The `is_scoped` argument is true if the surrounding operators have the
1853    /// precedence of the comma operator, or lower. So, for example:
1854    ///
1855    /// - Pass `true` for `is_scoped` when writing function arguments, an
1856    ///   expression statement, an initializer expression, or anything already
1857    ///   wrapped in parenthesis.
1858    ///
1859    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1860    ///   almost anything else.
1861    fn put_expression(
1862        &mut self,
1863        expr_handle: Handle<crate::Expression>,
1864        context: &ExpressionContext,
1865        is_scoped: bool,
1866    ) -> BackendResult {
1867        // Add to the set in order to track the stack size.
1868        #[cfg(test)]
1869        self.put_expression_stack_pointers
1870            .insert(ptr::from_ref(&expr_handle).cast());
1871
1872        if let Some(name) = self.named_expressions.get(&expr_handle) {
1873            write!(self.out, "{name}")?;
1874            return Ok(());
1875        }
1876
1877        let expression = &context.function.expressions[expr_handle];
1878        match *expression {
1879            crate::Expression::Literal(_)
1880            | crate::Expression::Constant(_)
1881            | crate::Expression::ZeroValue(_)
1882            | crate::Expression::Compose { .. }
1883            | crate::Expression::Splat { .. } => {
1884                self.put_possibly_const_expression(
1885                    expr_handle,
1886                    &context.function.expressions,
1887                    context.module,
1888                    context.mod_info,
1889                    context,
1890                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1891                    |writer, context, expr| writer.put_expression(expr, context, true),
1892                )?;
1893            }
1894            crate::Expression::Override(_) => return Err(Error::Override),
1895            crate::Expression::Access { base, .. }
1896            | crate::Expression::AccessIndex { base, .. } => {
1897                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1898                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1899                // access chain at a time, recursing back through `put_expression` only
1900                // for index expressions and the base object, we will never see intermediate
1901                // `Access` or `AccessIndex` expressions here.
1902                let policy = context.choose_bounds_check_policy(base);
1903                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1904                    && self.put_bounds_checks(
1905                        expr_handle,
1906                        context,
1907                        back::Level(0),
1908                        if is_scoped { "" } else { "(" },
1909                    )?
1910                {
1911                    write!(self.out, " ? ")?;
1912                    self.put_access_chain(expr_handle, policy, context)?;
1913                    write!(self.out, " : ")?;
1914
1915                    if context.resolve_type(base).pointer_space().is_some() {
1916                        // We can't just use `DefaultConstructible` if this is a pointer.
1917                        // Instead, we create a dummy local variable to serve as pointer
1918                        // target if the access is out of bounds.
1919                        let result_ty = context.info[expr_handle]
1920                            .ty
1921                            .inner_with(&context.module.types)
1922                            .pointer_base_type();
1923                        let result_ty_handle = match result_ty {
1924                            Some(TypeResolution::Handle(handle)) => handle,
1925                            Some(TypeResolution::Value(_)) => {
1926                                // As long as the result of a pointer access expression is
1927                                // passed to a function or stored in a let binding, the
1928                                // type will be in the arena. If additional uses of
1929                                // pointers become valid, this assumption might no longer
1930                                // hold. Note that the LHS of a load or store doesn't
1931                                // take this path -- there is dedicated code in `put_load`
1932                                // and `put_store`.
1933                                unreachable!(
1934                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1935                                );
1936                            }
1937                            None => {
1938                                unreachable!(
1939                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1940                                )
1941                            }
1942                        };
1943                        let name_key =
1944                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
1945                        self.out.write_str(&self.names[&name_key])?;
1946                    } else {
1947                        write!(self.out, "DefaultConstructible()")?;
1948                    }
1949
1950                    if !is_scoped {
1951                        write!(self.out, ")")?;
1952                    }
1953                } else {
1954                    self.put_access_chain(expr_handle, policy, context)?;
1955                }
1956            }
1957            crate::Expression::Swizzle {
1958                size,
1959                vector,
1960                pattern,
1961            } => {
1962                self.put_wrapped_expression_for_packed_vec3_access(
1963                    vector,
1964                    context,
1965                    false,
1966                    &Self::put_expression,
1967                )?;
1968                write!(self.out, ".")?;
1969                for &sc in pattern[..size as usize].iter() {
1970                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1971                }
1972            }
1973            crate::Expression::FunctionArgument(index) => {
1974                let name_key = match context.origin {
1975                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1976                    FunctionOrigin::EntryPoint(ep_index) => {
1977                        NameKey::EntryPointArgument(ep_index, index)
1978                    }
1979                };
1980                let name = &self.names[&name_key];
1981                write!(self.out, "{name}")?;
1982            }
1983            crate::Expression::GlobalVariable(handle) => {
1984                let name = &self.names[&NameKey::GlobalVariable(handle)];
1985                write!(self.out, "{name}")?;
1986            }
1987            crate::Expression::LocalVariable(handle) => {
1988                let name_key = NameKey::local(context.origin, handle);
1989                let name = &self.names[&name_key];
1990                write!(self.out, "{name}")?;
1991            }
1992            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1993            crate::Expression::ImageSample {
1994                coordinate,
1995                image,
1996                sampler,
1997                clamp_to_edge: true,
1998                gather: None,
1999                array_index: None,
2000                offset: None,
2001                level: crate::SampleLevel::Zero,
2002                depth_ref: None,
2003            } => {
2004                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2005                self.put_expression(image, context, true)?;
2006                write!(self.out, ", ")?;
2007                self.put_expression(sampler, context, true)?;
2008                write!(self.out, ", ")?;
2009                self.put_expression(coordinate, context, true)?;
2010                write!(self.out, ")")?;
2011            }
2012            crate::Expression::ImageSample {
2013                image,
2014                sampler,
2015                gather,
2016                coordinate,
2017                array_index,
2018                offset,
2019                level,
2020                depth_ref,
2021                clamp_to_edge,
2022            } => {
2023                if clamp_to_edge {
2024                    return Err(Error::GenericValidation(
2025                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2026                    ));
2027                }
2028
2029                let main_op = match gather {
2030                    Some(_) => "gather",
2031                    None => "sample",
2032                };
2033                let comparison_op = match depth_ref {
2034                    Some(_) => "_compare",
2035                    None => "",
2036                };
2037                self.put_expression(image, context, false)?;
2038                write!(self.out, ".{main_op}{comparison_op}(")?;
2039                self.put_expression(sampler, context, true)?;
2040                write!(self.out, ", ")?;
2041                self.put_expression(coordinate, context, true)?;
2042                if let Some(expr) = array_index {
2043                    write!(self.out, ", ")?;
2044                    self.put_expression(expr, context, true)?;
2045                }
2046                if let Some(dref) = depth_ref {
2047                    write!(self.out, ", ")?;
2048                    self.put_expression(dref, context, true)?;
2049                }
2050
2051                self.put_image_sample_level(image, level, context)?;
2052
2053                if let Some(offset) = offset {
2054                    write!(self.out, ", ")?;
2055                    self.put_expression(offset, context, true)?;
2056                }
2057
2058                match gather {
2059                    None | Some(crate::SwizzleComponent::X) => {}
2060                    Some(component) => {
2061                        let is_cube_map = match *context.resolve_type(image) {
2062                            crate::TypeInner::Image {
2063                                dim: crate::ImageDimension::Cube,
2064                                ..
2065                            } => true,
2066                            _ => false,
2067                        };
2068                        // Offset always comes before the gather, except
2069                        // in cube maps where it's not applicable
2070                        if offset.is_none() && !is_cube_map {
2071                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2072                        }
2073                        let letter = back::COMPONENTS[component as usize];
2074                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2075                    }
2076                }
2077                write!(self.out, ")")?;
2078            }
2079            crate::Expression::ImageLoad {
2080                image,
2081                coordinate,
2082                array_index,
2083                sample,
2084                level,
2085            } => {
2086                let address = TexelAddress {
2087                    coordinate,
2088                    array_index,
2089                    sample,
2090                    level: level.map(LevelOfDetail::Direct),
2091                };
2092                self.put_image_load(expr_handle, image, address, context)?;
2093            }
2094            //Note: for all the queries, the signed integers are expected,
2095            // so a conversion is needed.
2096            crate::Expression::ImageQuery { image, query } => match query {
2097                crate::ImageQuery::Size { level } => {
2098                    self.put_image_size_query(
2099                        image,
2100                        level.map(LevelOfDetail::Direct),
2101                        crate::ScalarKind::Uint,
2102                        context,
2103                    )?;
2104                }
2105                crate::ImageQuery::NumLevels => {
2106                    self.put_expression(image, context, false)?;
2107                    write!(self.out, ".get_num_mip_levels()")?;
2108                }
2109                crate::ImageQuery::NumLayers => {
2110                    self.put_expression(image, context, false)?;
2111                    write!(self.out, ".get_array_size()")?;
2112                }
2113                crate::ImageQuery::NumSamples => {
2114                    self.put_expression(image, context, false)?;
2115                    write!(self.out, ".get_num_samples()")?;
2116                }
2117            },
2118            crate::Expression::Unary { op, expr } => {
2119                let op_str = match op {
2120                    crate::UnaryOperator::Negate => {
2121                        match context.resolve_type(expr).scalar_kind() {
2122                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2123                            _ => "-",
2124                        }
2125                    }
2126                    crate::UnaryOperator::LogicalNot => "!",
2127                    crate::UnaryOperator::BitwiseNot => "~",
2128                };
2129                write!(self.out, "{op_str}(")?;
2130                self.put_expression(expr, context, false)?;
2131                write!(self.out, ")")?;
2132            }
2133            crate::Expression::Binary { op, left, right } => {
2134                let kind = context
2135                    .resolve_type(left)
2136                    .scalar_kind()
2137                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2138
2139                if op == crate::BinaryOperator::Divide
2140                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2141                {
2142                    write!(self.out, "{DIV_FUNCTION}(")?;
2143                    self.put_expression(left, context, true)?;
2144                    write!(self.out, ", ")?;
2145                    self.put_expression(right, context, true)?;
2146                    write!(self.out, ")")?;
2147                } else if op == crate::BinaryOperator::Modulo
2148                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2149                {
2150                    write!(self.out, "{MOD_FUNCTION}(")?;
2151                    self.put_expression(left, context, true)?;
2152                    write!(self.out, ", ")?;
2153                    self.put_expression(right, context, true)?;
2154                    write!(self.out, ")")?;
2155                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2156                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2157                    //
2158                    // float:
2159                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2160                    write!(self.out, "{NAMESPACE}::fmod(")?;
2161                    self.put_expression(left, context, true)?;
2162                    write!(self.out, ", ")?;
2163                    self.put_expression(right, context, true)?;
2164                    write!(self.out, ")")?;
2165                } else if (op == crate::BinaryOperator::Add
2166                    || op == crate::BinaryOperator::Subtract
2167                    || op == crate::BinaryOperator::Multiply)
2168                    && kind == crate::ScalarKind::Sint
2169                {
2170                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2171                        crate::TypeInner::Scalar(scalar) => {
2172                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2173                                kind: crate::ScalarKind::Uint,
2174                                ..scalar
2175                            }))
2176                        }
2177                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2178                            size,
2179                            scalar: crate::Scalar {
2180                                kind: crate::ScalarKind::Uint,
2181                                ..scalar
2182                            },
2183                        }),
2184                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2185                    };
2186
2187                    // Avoid undefined behaviour due to overflowing signed
2188                    // integer arithmetic. Cast the operands to unsigned prior
2189                    // to performing the operation, then cast the result back
2190                    // to signed.
2191                    self.put_bitcasted_expression(
2192                        context.resolve_type(expr_handle),
2193                        context,
2194                        &|writer, context, is_scoped| {
2195                            writer.put_binop(
2196                                op,
2197                                left,
2198                                right,
2199                                context,
2200                                is_scoped,
2201                                &|writer, expr, context, _is_scoped| {
2202                                    writer.put_bitcasted_expression(
2203                                        &to_unsigned(context.resolve_type(expr))?,
2204                                        context,
2205                                        &|writer, context, is_scoped| {
2206                                            writer.put_expression(expr, context, is_scoped)
2207                                        },
2208                                    )
2209                                },
2210                            )
2211                        },
2212                    )?;
2213                } else {
2214                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2215                }
2216            }
2217            crate::Expression::Select {
2218                condition,
2219                accept,
2220                reject,
2221            } => match *context.resolve_type(condition) {
2222                crate::TypeInner::Scalar(crate::Scalar {
2223                    kind: crate::ScalarKind::Bool,
2224                    ..
2225                }) => {
2226                    if !is_scoped {
2227                        write!(self.out, "(")?;
2228                    }
2229                    self.put_expression(condition, context, false)?;
2230                    write!(self.out, " ? ")?;
2231                    self.put_expression(accept, context, false)?;
2232                    write!(self.out, " : ")?;
2233                    self.put_expression(reject, context, false)?;
2234                    if !is_scoped {
2235                        write!(self.out, ")")?;
2236                    }
2237                }
2238                crate::TypeInner::Vector {
2239                    scalar:
2240                        crate::Scalar {
2241                            kind: crate::ScalarKind::Bool,
2242                            ..
2243                        },
2244                    ..
2245                } => {
2246                    write!(self.out, "{NAMESPACE}::select(")?;
2247                    self.put_expression(reject, context, true)?;
2248                    write!(self.out, ", ")?;
2249                    self.put_expression(accept, context, true)?;
2250                    write!(self.out, ", ")?;
2251                    self.put_expression(condition, context, true)?;
2252                    write!(self.out, ")")?;
2253                }
2254                ref ty => {
2255                    return Err(Error::GenericValidation(format!(
2256                        "Expected select condition to be a non-bool type, got {ty:?}",
2257                    )))
2258                }
2259            },
2260            crate::Expression::Derivative { axis, expr, .. } => {
2261                use crate::DerivativeAxis as Axis;
2262                let op = match axis {
2263                    Axis::X => "dfdx",
2264                    Axis::Y => "dfdy",
2265                    Axis::Width => "fwidth",
2266                };
2267                write!(self.out, "{NAMESPACE}::{op}")?;
2268                self.put_call_parameters(iter::once(expr), context)?;
2269            }
2270            crate::Expression::Relational { fun, argument } => {
2271                let op = match fun {
2272                    crate::RelationalFunction::Any => "any",
2273                    crate::RelationalFunction::All => "all",
2274                    crate::RelationalFunction::IsNan => "isnan",
2275                    crate::RelationalFunction::IsInf => "isinf",
2276                };
2277                write!(self.out, "{NAMESPACE}::{op}")?;
2278                self.put_call_parameters(iter::once(argument), context)?;
2279            }
2280            crate::Expression::Math {
2281                fun,
2282                arg,
2283                arg1,
2284                arg2,
2285                arg3,
2286            } => {
2287                use crate::MathFunction as Mf;
2288
2289                let arg_type = context.resolve_type(arg);
2290                let scalar_argument = match arg_type {
2291                    &crate::TypeInner::Scalar(_) => true,
2292                    _ => false,
2293                };
2294
2295                let fun_name = match fun {
2296                    // comparison
2297                    Mf::Abs => "abs",
2298                    Mf::Min => "min",
2299                    Mf::Max => "max",
2300                    Mf::Clamp => "clamp",
2301                    Mf::Saturate => "saturate",
2302                    // trigonometry
2303                    Mf::Cos => "cos",
2304                    Mf::Cosh => "cosh",
2305                    Mf::Sin => "sin",
2306                    Mf::Sinh => "sinh",
2307                    Mf::Tan => "tan",
2308                    Mf::Tanh => "tanh",
2309                    Mf::Acos => "acos",
2310                    Mf::Asin => "asin",
2311                    Mf::Atan => "atan",
2312                    Mf::Atan2 => "atan2",
2313                    Mf::Asinh => "asinh",
2314                    Mf::Acosh => "acosh",
2315                    Mf::Atanh => "atanh",
2316                    Mf::Radians => "",
2317                    Mf::Degrees => "",
2318                    // decomposition
2319                    Mf::Ceil => "ceil",
2320                    Mf::Floor => "floor",
2321                    Mf::Round => "rint",
2322                    Mf::Fract => "fract",
2323                    Mf::Trunc => "trunc",
2324                    Mf::Modf => MODF_FUNCTION,
2325                    Mf::Frexp => FREXP_FUNCTION,
2326                    Mf::Ldexp => "ldexp",
2327                    // exponent
2328                    Mf::Exp => "exp",
2329                    Mf::Exp2 => "exp2",
2330                    Mf::Log => "log",
2331                    Mf::Log2 => "log2",
2332                    Mf::Pow => "pow",
2333                    // geometry
2334                    Mf::Dot => match *context.resolve_type(arg) {
2335                        crate::TypeInner::Vector {
2336                            scalar:
2337                                crate::Scalar {
2338                                    // Resolve float values to MSL's builtin dot function.
2339                                    kind: crate::ScalarKind::Float,
2340                                    ..
2341                                },
2342                            ..
2343                        } => "dot",
2344                        crate::TypeInner::Vector {
2345                            size,
2346                            scalar:
2347                                scalar @ crate::Scalar {
2348                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2349                                    ..
2350                                },
2351                        } => {
2352                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2353                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2354                            write!(self.out, "{fun_name}(")?;
2355                            self.put_expression(arg, context, true)?;
2356                            write!(self.out, ", ")?;
2357                            self.put_expression(arg1.unwrap(), context, true)?;
2358                            write!(self.out, ")")?;
2359                            return Ok(());
2360                        }
2361                        _ => unreachable!(
2362                            "Correct TypeInner for dot product should be already validated"
2363                        ),
2364                    },
2365                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2366                        if context.lang_version >= (2, 1) {
2367                            // Write potentially optimizable code using `packed_(u?)char4`.
2368                            // The two function arguments were already reinterpreted as packed (signed
2369                            // or unsigned) chars in `Self::put_block`.
2370                            let packed_type = match fun {
2371                                Mf::Dot4I8Packed => "packed_char4",
2372                                Mf::Dot4U8Packed => "packed_uchar4",
2373                                _ => unreachable!(),
2374                            };
2375
2376                            return self.put_dot_product(
2377                                Reinterpreted::new(packed_type, arg),
2378                                Reinterpreted::new(packed_type, arg1.unwrap()),
2379                                4,
2380                                |writer, arg, index| {
2381                                    // MSL implicitly promotes these (signed or unsigned) chars to
2382                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2383                                    write!(writer.out, "{arg}[{index}]")?;
2384                                    Ok(())
2385                                },
2386                            );
2387                        } else {
2388                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2389                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2390                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2391                            let conversion = match fun {
2392                                Mf::Dot4I8Packed => "int",
2393                                Mf::Dot4U8Packed => "",
2394                                _ => unreachable!(),
2395                            };
2396
2397                            return self.put_dot_product(
2398                                arg,
2399                                arg1.unwrap(),
2400                                4,
2401                                |writer, arg, index| {
2402                                    write!(writer.out, "({conversion}(")?;
2403                                    writer.put_expression(arg, context, true)?;
2404                                    if index == 3 {
2405                                        write!(writer.out, ") >> 24)")?;
2406                                    } else {
2407                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2408                                    }
2409                                    Ok(())
2410                                },
2411                            );
2412                        }
2413                    }
2414                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2415                    Mf::Cross => "cross",
2416                    Mf::Distance => "distance",
2417                    Mf::Length if scalar_argument => "abs",
2418                    Mf::Length => "length",
2419                    Mf::Normalize => "normalize",
2420                    Mf::FaceForward => "faceforward",
2421                    Mf::Reflect => "reflect",
2422                    Mf::Refract => "refract",
2423                    // computational
2424                    Mf::Sign => match arg_type.scalar_kind() {
2425                        Some(crate::ScalarKind::Sint) => {
2426                            return self.put_isign(arg, context);
2427                        }
2428                        _ => "sign",
2429                    },
2430                    Mf::Fma => "fma",
2431                    Mf::Mix => "mix",
2432                    Mf::Step => "step",
2433                    Mf::SmoothStep => "smoothstep",
2434                    Mf::Sqrt => "sqrt",
2435                    Mf::InverseSqrt => "rsqrt",
2436                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2437                    Mf::Transpose => "transpose",
2438                    Mf::Determinant => "determinant",
2439                    Mf::QuantizeToF16 => "",
2440                    // bits
2441                    Mf::CountTrailingZeros => "ctz",
2442                    Mf::CountLeadingZeros => "clz",
2443                    Mf::CountOneBits => "popcount",
2444                    Mf::ReverseBits => "reverse_bits",
2445                    Mf::ExtractBits => "",
2446                    Mf::InsertBits => "",
2447                    Mf::FirstTrailingBit => "",
2448                    Mf::FirstLeadingBit => "",
2449                    // data packing
2450                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2451                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2452                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2453                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2454                    Mf::Pack2x16float => "",
2455                    Mf::Pack4xI8 => "",
2456                    Mf::Pack4xU8 => "",
2457                    Mf::Pack4xI8Clamp => "",
2458                    Mf::Pack4xU8Clamp => "",
2459                    // data unpacking
2460                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2461                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2462                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2463                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2464                    Mf::Unpack2x16float => "",
2465                    Mf::Unpack4xI8 => "",
2466                    Mf::Unpack4xU8 => "",
2467                };
2468
2469                match fun {
2470                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2471                        // reverse_bits is listed as requiring MSL 2.1 but that
2472                        // is a copy/paste error. Looking at previous snapshots
2473                        // on web.archive.org it's present in MSL 1.2.
2474                        //
2475                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2476                        // also talks about MSL 1.2 adding "New integer
2477                        // functions to extract, insert, and reverse bits, as
2478                        // described in Integer Functions."
2479                        if context.lang_version < (1, 2) {
2480                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2481                        }
2482                    }
2483                    _ => {}
2484                }
2485
2486                match fun {
2487                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2488                        write!(self.out, "{ABS_FUNCTION}(")?;
2489                        self.put_expression(arg, context, true)?;
2490                        write!(self.out, ")")?;
2491                    }
2492                    Mf::Distance if scalar_argument => {
2493                        write!(self.out, "{NAMESPACE}::abs(")?;
2494                        self.put_expression(arg, context, false)?;
2495                        write!(self.out, " - ")?;
2496                        self.put_expression(arg1.unwrap(), context, false)?;
2497                        write!(self.out, ")")?;
2498                    }
2499                    Mf::FirstTrailingBit => {
2500                        let scalar = context.resolve_type(arg).scalar().unwrap();
2501                        let constant = scalar.width * 8 + 1;
2502
2503                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2504                        self.put_expression(arg, context, true)?;
2505                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2506                    }
2507                    Mf::FirstLeadingBit => {
2508                        let inner = context.resolve_type(arg);
2509                        let scalar = inner.scalar().unwrap();
2510                        let constant = scalar.width * 8 - 1;
2511
2512                        write!(
2513                            self.out,
2514                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2515                        )?;
2516
2517                        if scalar.kind == crate::ScalarKind::Sint {
2518                            write!(self.out, "{NAMESPACE}::select(")?;
2519                            self.put_expression(arg, context, true)?;
2520                            write!(self.out, ", ~")?;
2521                            self.put_expression(arg, context, true)?;
2522                            write!(self.out, ", ")?;
2523                            self.put_expression(arg, context, true)?;
2524                            write!(self.out, " < 0)")?;
2525                        } else {
2526                            self.put_expression(arg, context, true)?;
2527                        }
2528
2529                        write!(self.out, "), ")?;
2530
2531                        // or metal will complain that select is ambiguous
2532                        match *inner {
2533                            crate::TypeInner::Vector { size, scalar } => {
2534                                let size = common::vector_size_str(size);
2535                                let name = scalar.to_msl_name();
2536                                write!(self.out, "{name}{size}")?;
2537                            }
2538                            crate::TypeInner::Scalar(scalar) => {
2539                                let name = scalar.to_msl_name();
2540                                write!(self.out, "{name}")?;
2541                            }
2542                            _ => (),
2543                        }
2544
2545                        write!(self.out, "(-1), ")?;
2546                        self.put_expression(arg, context, true)?;
2547                        write!(self.out, " == 0 || ")?;
2548                        self.put_expression(arg, context, true)?;
2549                        write!(self.out, " == -1)")?;
2550                    }
2551                    Mf::Unpack2x16float => {
2552                        write!(self.out, "float2(as_type<half2>(")?;
2553                        self.put_expression(arg, context, false)?;
2554                        write!(self.out, "))")?;
2555                    }
2556                    Mf::Pack2x16float => {
2557                        write!(self.out, "as_type<uint>(half2(")?;
2558                        self.put_expression(arg, context, false)?;
2559                        write!(self.out, "))")?;
2560                    }
2561                    Mf::ExtractBits => {
2562                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2563                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2564                        // will return out-of-spec values if the extracted range is not within the bit width.
2565                        //
2566                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2567                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2568                        //
2569                        // w = sizeof(x) * 8
2570                        // o = min(offset, w)
2571                        // tmp = w - o
2572                        // c = min(count, tmp)
2573                        //
2574                        // bitfieldExtract(x, o, c)
2575                        //
2576                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2577
2578                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2579
2580                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2581                        self.put_expression(arg, context, true)?;
2582                        write!(self.out, ", {NAMESPACE}::min(")?;
2583                        self.put_expression(arg1.unwrap(), context, true)?;
2584                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2585                        self.put_expression(arg2.unwrap(), context, true)?;
2586                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2587                        self.put_expression(arg1.unwrap(), context, true)?;
2588                        write!(self.out, ", {scalar_bits}u)))")?;
2589                    }
2590                    Mf::InsertBits => {
2591                        // The behavior of InsertBits has the same issue as ExtractBits.
2592                        //
2593                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2594
2595                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2596
2597                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2598                        self.put_expression(arg, context, true)?;
2599                        write!(self.out, ", ")?;
2600                        self.put_expression(arg1.unwrap(), context, true)?;
2601                        write!(self.out, ", {NAMESPACE}::min(")?;
2602                        self.put_expression(arg2.unwrap(), context, true)?;
2603                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2604                        self.put_expression(arg3.unwrap(), context, true)?;
2605                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2606                        self.put_expression(arg2.unwrap(), context, true)?;
2607                        write!(self.out, ", {scalar_bits}u)))")?;
2608                    }
2609                    Mf::Radians => {
2610                        write!(self.out, "((")?;
2611                        self.put_expression(arg, context, false)?;
2612                        write!(self.out, ") * 0.017453292519943295474)")?;
2613                    }
2614                    Mf::Degrees => {
2615                        write!(self.out, "((")?;
2616                        self.put_expression(arg, context, false)?;
2617                        write!(self.out, ") * 57.295779513082322865)")?;
2618                    }
2619                    Mf::Modf | Mf::Frexp => {
2620                        write!(self.out, "{fun_name}")?;
2621                        self.put_call_parameters(iter::once(arg), context)?;
2622                    }
2623                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2624                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2625                    Mf::Pack4xI8Clamp => {
2626                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2627                    }
2628                    Mf::Pack4xU8Clamp => {
2629                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2630                    }
2631                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2632                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2633                            "u"
2634                        } else {
2635                            ""
2636                        };
2637
2638                        if context.lang_version >= (2, 1) {
2639                            // Metal uses little endian byte order, which matches what WGSL expects here.
2640                            write!(
2641                                self.out,
2642                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2643                            )?;
2644                            self.put_expression(arg, context, true)?;
2645                            write!(self.out, "))")?;
2646                        } else {
2647                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2648                            write!(self.out, "({sign_prefix}int4(")?;
2649                            self.put_expression(arg, context, true)?;
2650                            write!(self.out, ", ")?;
2651                            self.put_expression(arg, context, true)?;
2652                            write!(self.out, " >> 8, ")?;
2653                            self.put_expression(arg, context, true)?;
2654                            write!(self.out, " >> 16, ")?;
2655                            self.put_expression(arg, context, true)?;
2656                            write!(self.out, " >> 24) << 24 >> 24)")?;
2657                        }
2658                    }
2659                    Mf::QuantizeToF16 => {
2660                        match *context.resolve_type(arg) {
2661                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2662                            crate::TypeInner::Vector { size, .. } => write!(
2663                                self.out,
2664                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2665                                size = common::vector_size_str(size),
2666                            )?,
2667                            _ => unreachable!(
2668                                "Correct TypeInner for QuantizeToF16 should be already validated"
2669                            ),
2670                        };
2671
2672                        self.put_expression(arg, context, true)?;
2673                        write!(self.out, "))")?;
2674                    }
2675                    _ => {
2676                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2677                        self.put_call_parameters(
2678                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2679                            context,
2680                        )?;
2681                    }
2682                }
2683            }
2684            crate::Expression::As {
2685                expr,
2686                kind,
2687                convert,
2688            } => match *context.resolve_type(expr) {
2689                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2690                    if src.kind == crate::ScalarKind::Float
2691                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2692                        && convert.is_some()
2693                    {
2694                        // Use helper functions for float to int casts in order to avoid
2695                        // undefined behaviour when value is out of range for the target
2696                        // type.
2697                        let fun_name = match (kind, convert) {
2698                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2699                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2700                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2701                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2702                            _ => unreachable!(),
2703                        };
2704                        write!(self.out, "{fun_name}(")?;
2705                        self.put_expression(expr, context, true)?;
2706                        write!(self.out, ")")?;
2707                    } else {
2708                        let target_scalar = crate::Scalar {
2709                            kind,
2710                            width: convert.unwrap_or(src.width),
2711                        };
2712                        let op = match convert {
2713                            Some(_) => "static_cast",
2714                            None => "as_type",
2715                        };
2716                        write!(self.out, "{op}<")?;
2717                        match *context.resolve_type(expr) {
2718                            crate::TypeInner::Vector { size, .. } => {
2719                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2720                            }
2721                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2722                        };
2723                        write!(self.out, ">(")?;
2724                        self.put_expression(expr, context, true)?;
2725                        write!(self.out, ")")?;
2726                    }
2727                }
2728                crate::TypeInner::Matrix {
2729                    columns,
2730                    rows,
2731                    scalar,
2732                } => {
2733                    let target_scalar = crate::Scalar {
2734                        kind,
2735                        width: convert.unwrap_or(scalar.width),
2736                    };
2737                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2738                    write!(self.out, "(")?;
2739                    self.put_expression(expr, context, true)?;
2740                    write!(self.out, ")")?;
2741                }
2742                ref ty => {
2743                    return Err(Error::GenericValidation(format!(
2744                        "Unsupported type for As: {ty:?}"
2745                    )))
2746                }
2747            },
2748            // has to be a named expression
2749            crate::Expression::CallResult(_)
2750            | crate::Expression::AtomicResult { .. }
2751            | crate::Expression::WorkGroupUniformLoadResult { .. }
2752            | crate::Expression::SubgroupBallotResult
2753            | crate::Expression::SubgroupOperationResult { .. }
2754            | crate::Expression::RayQueryProceedResult => {
2755                unreachable!()
2756            }
2757            crate::Expression::ArrayLength(expr) => {
2758                // Find the global to which the array belongs.
2759                let global = match context.function.expressions[expr] {
2760                    crate::Expression::AccessIndex { base, .. } => {
2761                        match context.function.expressions[base] {
2762                            crate::Expression::GlobalVariable(handle) => handle,
2763                            ref ex => {
2764                                return Err(Error::GenericValidation(format!(
2765                                    "Expected global variable in AccessIndex, got {ex:?}"
2766                                )))
2767                            }
2768                        }
2769                    }
2770                    crate::Expression::GlobalVariable(handle) => handle,
2771                    ref ex => {
2772                        return Err(Error::GenericValidation(format!(
2773                            "Unexpected expression in ArrayLength, got {ex:?}"
2774                        )))
2775                    }
2776                };
2777
2778                if !is_scoped {
2779                    write!(self.out, "(")?;
2780                }
2781                write!(self.out, "1 + ")?;
2782                self.put_dynamic_array_max_index(global, context)?;
2783                if !is_scoped {
2784                    write!(self.out, ")")?;
2785                }
2786            }
2787            crate::Expression::RayQueryVertexPositions { .. } => {
2788                unimplemented!()
2789            }
2790            crate::Expression::RayQueryGetIntersection {
2791                query,
2792                committed: _,
2793            } => {
2794                if context.lang_version < (2, 4) {
2795                    return Err(Error::UnsupportedRayTracing);
2796                }
2797
2798                let ty = context.module.special_types.ray_intersection.unwrap();
2799                let type_name = &self.names[&NameKey::Type(ty)];
2800                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2801                self.put_expression(query, context, true)?;
2802                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2803                let fields = [
2804                    "distance",
2805                    "user_instance_id", // req Metal 2.4
2806                    "instance_id",
2807                    "", // SBT offset
2808                    "geometry_id",
2809                    "primitive_id",
2810                    "triangle_barycentric_coord",
2811                    "triangle_front_facing",
2812                    "",                          // padding
2813                    "object_to_world_transform", // req Metal 2.4
2814                    "world_to_object_transform", // req Metal 2.4
2815                ];
2816                for field in fields {
2817                    write!(self.out, ", ")?;
2818                    if field.is_empty() {
2819                        write!(self.out, "{{}}")?;
2820                    } else {
2821                        self.put_expression(query, context, true)?;
2822                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2823                    }
2824                }
2825                write!(self.out, "}}")?;
2826            }
2827        }
2828        Ok(())
2829    }
2830
2831    /// Emits code for a binary operation, using the provided callback to emit
2832    /// the left and right operands.
2833    fn put_binop<F>(
2834        &mut self,
2835        op: crate::BinaryOperator,
2836        left: Handle<crate::Expression>,
2837        right: Handle<crate::Expression>,
2838        context: &ExpressionContext,
2839        is_scoped: bool,
2840        put_expression: &F,
2841    ) -> BackendResult
2842    where
2843        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2844    {
2845        let op_str = back::binary_operation_str(op);
2846
2847        if !is_scoped {
2848            write!(self.out, "(")?;
2849        }
2850
2851        // Cast packed vector if necessary
2852        // Packed vector - matrix multiplications are not supported in MSL
2853        if op == crate::BinaryOperator::Multiply
2854            && matches!(
2855                context.resolve_type(right),
2856                &crate::TypeInner::Matrix { .. }
2857            )
2858        {
2859            self.put_wrapped_expression_for_packed_vec3_access(
2860                left,
2861                context,
2862                false,
2863                put_expression,
2864            )?;
2865        } else {
2866            put_expression(self, left, context, false)?;
2867        }
2868
2869        write!(self.out, " {op_str} ")?;
2870
2871        // See comment above
2872        if op == crate::BinaryOperator::Multiply
2873            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2874        {
2875            self.put_wrapped_expression_for_packed_vec3_access(
2876                right,
2877                context,
2878                false,
2879                put_expression,
2880            )?;
2881        } else {
2882            put_expression(self, right, context, false)?;
2883        }
2884
2885        if !is_scoped {
2886            write!(self.out, ")")?;
2887        }
2888
2889        Ok(())
2890    }
2891
2892    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2893    fn put_wrapped_expression_for_packed_vec3_access<F>(
2894        &mut self,
2895        expr_handle: Handle<crate::Expression>,
2896        context: &ExpressionContext,
2897        is_scoped: bool,
2898        put_expression: &F,
2899    ) -> BackendResult
2900    where
2901        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2902    {
2903        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2904            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2905            put_expression(self, expr_handle, context, is_scoped)?;
2906            write!(self.out, ")")?;
2907        } else {
2908            put_expression(self, expr_handle, context, is_scoped)?;
2909        }
2910        Ok(())
2911    }
2912
2913    /// Emits code for an expression using the provided callback, wrapping the
2914    /// result in a bitcast to the type `cast_to`.
2915    fn put_bitcasted_expression<F>(
2916        &mut self,
2917        cast_to: &crate::TypeInner,
2918        context: &ExpressionContext,
2919        put_expression: &F,
2920    ) -> BackendResult
2921    where
2922        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2923    {
2924        write!(self.out, "as_type<")?;
2925        match *cast_to {
2926            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2927            crate::TypeInner::Vector { size, scalar } => {
2928                put_numeric_type(&mut self.out, scalar, &[size])?
2929            }
2930            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2931        };
2932        write!(self.out, ">(")?;
2933        put_expression(self, context, true)?;
2934        write!(self.out, ")")?;
2935
2936        Ok(())
2937    }
2938
2939    /// Write a `GuardedIndex` as a Metal expression.
2940    fn put_index(
2941        &mut self,
2942        index: index::GuardedIndex,
2943        context: &ExpressionContext,
2944        is_scoped: bool,
2945    ) -> BackendResult {
2946        match index {
2947            index::GuardedIndex::Expression(expr) => {
2948                self.put_expression(expr, context, is_scoped)?
2949            }
2950            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2951        }
2952        Ok(())
2953    }
2954
2955    /// Emit an index bounds check condition for `chain`, if required.
2956    ///
2957    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
2958    /// operating either on a pointer to a value, or on a value directly. If we cannot
2959    /// statically determine that all indexing operations in `chain` are within
2960    /// bounds, then write a conditional expression to check them dynamically,
2961    /// and return true. All accesses in the chain are checked by the generated
2962    /// expression.
2963    ///
2964    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
2965    ///
2966    /// The text written is of the form:
2967    ///
2968    /// ```ignore
2969    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
2970    /// ```
2971    ///
2972    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
2973    /// statements, presumably these arguments start an indented `if` statement; for
2974    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
2975    /// expression. In either case, what is written is not a complete syntactic structure
2976    /// in its own right, and the caller will have to finish it off if we return `true`.
2977    ///
2978    /// If no expression is written, return false.
2979    ///
2980    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
2981    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
2982    /// [`Store`]: crate::Statement::Store
2983    /// [`Load`]: crate::Expression::Load
2984    #[allow(unused_variables)]
2985    fn put_bounds_checks(
2986        &mut self,
2987        chain: Handle<crate::Expression>,
2988        context: &ExpressionContext,
2989        level: back::Level,
2990        prefix: &'static str,
2991    ) -> Result<bool, Error> {
2992        let mut check_written = false;
2993
2994        // Iterate over the access chain, handling each required bounds check.
2995        for item in context.bounds_check_iter(chain) {
2996            let BoundsCheck {
2997                base,
2998                index,
2999                length,
3000            } = item;
3001
3002            if check_written {
3003                write!(self.out, " && ")?;
3004            } else {
3005                write!(self.out, "{level}{prefix}")?;
3006                check_written = true;
3007            }
3008
3009            // Check that the index falls within bounds. Do this with a single
3010            // comparison, by casting the index to `uint` first, so that negative
3011            // indices become large positive values.
3012            write!(self.out, "uint(")?;
3013            self.put_index(index, context, true)?;
3014            self.out.write_str(") < ")?;
3015            match length {
3016                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3017                index::IndexableLength::Dynamic => {
3018                    let global = context.function.originating_global(base).ok_or_else(|| {
3019                        Error::GenericValidation("Could not find originating global".into())
3020                    })?;
3021                    write!(self.out, "1 + ")?;
3022                    self.put_dynamic_array_max_index(global, context)?
3023                }
3024            }
3025        }
3026
3027        Ok(check_written)
3028    }
3029
3030    /// Write the access chain `chain`.
3031    ///
3032    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3033    /// operating either on a pointer to a value, or on a value directly.
3034    ///
3035    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3036    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3037    /// that must be handled in the caller.
3038    ///
3039    /// Handle the entire chain, recursing back into `put_expression` only for index
3040    /// expressions and the base expression that originates the pointer or composite value
3041    /// being accessed. This allows `put_expression` to assume that any `Access` or
3042    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3043    /// `ReadZeroSkipWrite` checks.
3044    ///
3045    /// [`Access`]: crate::Expression::Access
3046    /// [`AccessIndex`]: crate::Expression::AccessIndex
3047    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3048    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3049    fn put_access_chain(
3050        &mut self,
3051        chain: Handle<crate::Expression>,
3052        policy: index::BoundsCheckPolicy,
3053        context: &ExpressionContext,
3054    ) -> BackendResult {
3055        match context.function.expressions[chain] {
3056            crate::Expression::Access { base, index } => {
3057                let mut base_ty = context.resolve_type(base);
3058
3059                // Look through any pointers to see what we're really indexing.
3060                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3061                    base_ty = &context.module.types[base].inner;
3062                }
3063
3064                self.put_subscripted_access_chain(
3065                    base,
3066                    base_ty,
3067                    index::GuardedIndex::Expression(index),
3068                    policy,
3069                    context,
3070                )?;
3071            }
3072            crate::Expression::AccessIndex { base, index } => {
3073                let base_resolution = &context.info[base].ty;
3074                let mut base_ty = base_resolution.inner_with(&context.module.types);
3075                let mut base_ty_handle = base_resolution.handle();
3076
3077                // Look through any pointers to see what we're really indexing.
3078                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3079                    base_ty = &context.module.types[base].inner;
3080                    base_ty_handle = Some(base);
3081                }
3082
3083                // Handle structs and anything else that can use `.x` syntax here, so
3084                // `put_subscripted_access_chain` won't have to handle the absurd case of
3085                // indexing a struct with an expression.
3086                match *base_ty {
3087                    crate::TypeInner::Struct { .. } => {
3088                        let base_ty = base_ty_handle.unwrap();
3089                        self.put_access_chain(base, policy, context)?;
3090                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3091                        write!(self.out, ".{name}")?;
3092                    }
3093                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3094                        self.put_access_chain(base, policy, context)?;
3095                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3096                        // however array indexing is
3097                        if context.get_packed_vec_kind(base).is_some() {
3098                            write!(self.out, "[{index}]")?;
3099                        } else {
3100                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3101                        }
3102                    }
3103                    _ => {
3104                        self.put_subscripted_access_chain(
3105                            base,
3106                            base_ty,
3107                            index::GuardedIndex::Known(index),
3108                            policy,
3109                            context,
3110                        )?;
3111                    }
3112                }
3113            }
3114            _ => self.put_expression(chain, context, false)?,
3115        }
3116
3117        Ok(())
3118    }
3119
3120    /// Write a `[]`-style access of `base` by `index`.
3121    ///
3122    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3123    /// values within bounds.
3124    ///
3125    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3126    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3127    /// removed. Our callers often already have this handy.
3128    ///
3129    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3130    /// referencing vector components by name.
3131    ///
3132    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3133    /// [`Array`]: crate::TypeInner::Array
3134    /// [`Vector`]: crate::TypeInner::Vector
3135    /// [`Pointer`]: crate::TypeInner::Pointer
3136    fn put_subscripted_access_chain(
3137        &mut self,
3138        base: Handle<crate::Expression>,
3139        base_ty: &crate::TypeInner,
3140        index: index::GuardedIndex,
3141        policy: index::BoundsCheckPolicy,
3142        context: &ExpressionContext,
3143    ) -> BackendResult {
3144        let accessing_wrapped_array = match *base_ty {
3145            crate::TypeInner::Array {
3146                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3147                ..
3148            } => true,
3149            _ => false,
3150        };
3151        let accessing_wrapped_binding_array =
3152            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3153
3154        self.put_access_chain(base, policy, context)?;
3155        if accessing_wrapped_array {
3156            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3157        }
3158        write!(self.out, "[")?;
3159
3160        // Decide whether this index needs to be clamped to fall within range.
3161        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3162            context.access_needs_check(base, index)
3163        } else {
3164            None
3165        };
3166        if let Some(limit) = restriction_needed {
3167            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3168            self.put_index(index, context, true)?;
3169            write!(self.out, "), ")?;
3170            match limit {
3171                index::IndexableLength::Known(limit) => {
3172                    write!(self.out, "{}u", limit - 1)?;
3173                }
3174                index::IndexableLength::Dynamic => {
3175                    let global = context.function.originating_global(base).ok_or_else(|| {
3176                        Error::GenericValidation("Could not find originating global".into())
3177                    })?;
3178                    self.put_dynamic_array_max_index(global, context)?;
3179                }
3180            }
3181            write!(self.out, ")")?;
3182        } else {
3183            self.put_index(index, context, true)?;
3184        }
3185
3186        write!(self.out, "]")?;
3187
3188        if accessing_wrapped_binding_array {
3189            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3190        }
3191
3192        Ok(())
3193    }
3194
3195    fn put_load(
3196        &mut self,
3197        pointer: Handle<crate::Expression>,
3198        context: &ExpressionContext,
3199        is_scoped: bool,
3200    ) -> BackendResult {
3201        // Since access chains never cross between address spaces, we can just
3202        // check the index bounds check policy once at the top.
3203        let policy = context.choose_bounds_check_policy(pointer);
3204        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3205            && self.put_bounds_checks(
3206                pointer,
3207                context,
3208                back::Level(0),
3209                if is_scoped { "" } else { "(" },
3210            )?
3211        {
3212            write!(self.out, " ? ")?;
3213            self.put_unchecked_load(pointer, policy, context)?;
3214            write!(self.out, " : DefaultConstructible()")?;
3215
3216            if !is_scoped {
3217                write!(self.out, ")")?;
3218            }
3219        } else {
3220            self.put_unchecked_load(pointer, policy, context)?;
3221        }
3222
3223        Ok(())
3224    }
3225
3226    fn put_unchecked_load(
3227        &mut self,
3228        pointer: Handle<crate::Expression>,
3229        policy: index::BoundsCheckPolicy,
3230        context: &ExpressionContext,
3231    ) -> BackendResult {
3232        let is_atomic_pointer = context
3233            .resolve_type(pointer)
3234            .is_atomic_pointer(&context.module.types);
3235
3236        if is_atomic_pointer {
3237            write!(
3238                self.out,
3239                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3240            )?;
3241            self.put_access_chain(pointer, policy, context)?;
3242            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3243        } else {
3244            // We don't do any dereferencing with `*` here as pointer arguments to functions
3245            // are done by `&` references and not `*` pointers. These do not need to be
3246            // dereferenced.
3247            self.put_access_chain(pointer, policy, context)?;
3248        }
3249
3250        Ok(())
3251    }
3252
3253    fn put_return_value(
3254        &mut self,
3255        level: back::Level,
3256        expr_handle: Handle<crate::Expression>,
3257        result_struct: Option<&str>,
3258        context: &ExpressionContext,
3259    ) -> BackendResult {
3260        match result_struct {
3261            Some(struct_name) => {
3262                let mut has_point_size = false;
3263                let result_ty = context.function.result.as_ref().unwrap().ty;
3264                match context.module.types[result_ty].inner {
3265                    crate::TypeInner::Struct { ref members, .. } => {
3266                        let tmp = "_tmp";
3267                        write!(self.out, "{level}const auto {tmp} = ")?;
3268                        self.put_expression(expr_handle, context, true)?;
3269                        writeln!(self.out, ";")?;
3270                        write!(self.out, "{level}return {struct_name} {{")?;
3271
3272                        let mut is_first = true;
3273
3274                        for (index, member) in members.iter().enumerate() {
3275                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3276                                member.binding
3277                            {
3278                                has_point_size = true;
3279                                if !context.pipeline_options.allow_and_force_point_size {
3280                                    continue;
3281                                }
3282                            }
3283
3284                            let comma = if is_first { "" } else { "," };
3285                            is_first = false;
3286                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3287                            // HACK: we are forcefully deduplicating the expression here
3288                            // to convert from a wrapped struct to a raw array, e.g.
3289                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3290                            if let crate::TypeInner::Array {
3291                                size: crate::ArraySize::Constant(size),
3292                                ..
3293                            } = context.module.types[member.ty].inner
3294                            {
3295                                write!(self.out, "{comma} {{")?;
3296                                for j in 0..size.get() {
3297                                    if j != 0 {
3298                                        write!(self.out, ",")?;
3299                                    }
3300                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3301                                }
3302                                write!(self.out, "}}")?;
3303                            } else {
3304                                write!(self.out, "{comma} {tmp}.{name}")?;
3305                            }
3306                        }
3307                    }
3308                    _ => {
3309                        write!(self.out, "{level}return {struct_name} {{ ")?;
3310                        self.put_expression(expr_handle, context, true)?;
3311                    }
3312                }
3313
3314                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3315                    let stage = context.module.entry_points[ep_index as usize].stage;
3316                    if context.pipeline_options.allow_and_force_point_size
3317                        && stage == crate::ShaderStage::Vertex
3318                        && !has_point_size
3319                    {
3320                        // point size was injected and comes last
3321                        write!(self.out, ", 1.0")?;
3322                    }
3323                }
3324                write!(self.out, " }}")?;
3325            }
3326            None => {
3327                write!(self.out, "{level}return ")?;
3328                self.put_expression(expr_handle, context, true)?;
3329            }
3330        }
3331        writeln!(self.out, ";")?;
3332        Ok(())
3333    }
3334
3335    /// Helper method used to find which expressions of a given function require baking
3336    ///
3337    /// # Notes
3338    /// This function overwrites the contents of `self.need_bake_expressions`
3339    fn update_expressions_to_bake(
3340        &mut self,
3341        func: &crate::Function,
3342        info: &valid::FunctionInfo,
3343        context: &ExpressionContext,
3344    ) {
3345        use crate::Expression;
3346        self.need_bake_expressions.clear();
3347
3348        for (expr_handle, expr) in func.expressions.iter() {
3349            // Expressions whose reference count is above the
3350            // threshold should always be stored in temporaries.
3351            let expr_info = &info[expr_handle];
3352            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3353            if min_ref_count <= expr_info.ref_count {
3354                self.need_bake_expressions.insert(expr_handle);
3355            } else {
3356                match expr_info.ty {
3357                    // force ray desc to be baked: it's used multiple times internally
3358                    TypeResolution::Handle(h)
3359                        if Some(h) == context.module.special_types.ray_desc =>
3360                    {
3361                        self.need_bake_expressions.insert(expr_handle);
3362                    }
3363                    _ => {}
3364                }
3365            }
3366
3367            if let Expression::Math {
3368                fun,
3369                arg,
3370                arg1,
3371                arg2,
3372                ..
3373            } = *expr
3374            {
3375                match fun {
3376                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3377                    // works on floating-point vectors, so we emit inline code for
3378                    // integer vector `dot` calls. But that code uses each argument `N`
3379                    // times, once for each component (see `put_dot_product`), so to
3380                    // avoid duplicated evaluation, we must bake integer operands.
3381                    // This applies both when using the polyfill (because of the duplicate
3382                    // evaluation issue) and when we don't use the polyfill (because we
3383                    // need them to be emitted before casting to packed chars -- see the
3384                    // comment at the call to `put_casting_to_packed_chars`).
3385                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3386                        self.need_bake_expressions.insert(arg);
3387                        self.need_bake_expressions.insert(arg1.unwrap());
3388                    }
3389                    crate::MathFunction::FirstLeadingBit => {
3390                        self.need_bake_expressions.insert(arg);
3391                    }
3392                    crate::MathFunction::Pack4xI8
3393                    | crate::MathFunction::Pack4xU8
3394                    | crate::MathFunction::Pack4xI8Clamp
3395                    | crate::MathFunction::Pack4xU8Clamp
3396                    | crate::MathFunction::Unpack4xI8
3397                    | crate::MathFunction::Unpack4xU8 => {
3398                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3399                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3400                        if context.lang_version < (2, 1) {
3401                            self.need_bake_expressions.insert(arg);
3402                        }
3403                    }
3404                    crate::MathFunction::ExtractBits => {
3405                        // Only argument 1 is re-used.
3406                        self.need_bake_expressions.insert(arg1.unwrap());
3407                    }
3408                    crate::MathFunction::InsertBits => {
3409                        // Only argument 2 is re-used.
3410                        self.need_bake_expressions.insert(arg2.unwrap());
3411                    }
3412                    crate::MathFunction::Sign => {
3413                        // WGSL's `sign` function works also on signed ints, but Metal's only
3414                        // works on floating points, so we emit inline code for integer `sign`
3415                        // calls. But that code uses each argument 2 times (see `put_isign`),
3416                        // so to avoid duplicated evaluation, we must bake the argument.
3417                        let inner = context.resolve_type(expr_handle);
3418                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3419                            self.need_bake_expressions.insert(arg);
3420                        }
3421                    }
3422                    _ => {}
3423                }
3424            }
3425        }
3426    }
3427
3428    fn start_baking_expression(
3429        &mut self,
3430        handle: Handle<crate::Expression>,
3431        context: &ExpressionContext,
3432        name: &str,
3433    ) -> BackendResult {
3434        match context.info[handle].ty {
3435            TypeResolution::Handle(ty_handle) => {
3436                let ty_name = TypeContext {
3437                    handle: ty_handle,
3438                    gctx: context.module.to_ctx(),
3439                    names: &self.names,
3440                    access: crate::StorageAccess::empty(),
3441                    first_time: false,
3442                };
3443                write!(self.out, "{ty_name}")?;
3444            }
3445            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3446                put_numeric_type(&mut self.out, scalar, &[])?;
3447            }
3448            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3449                put_numeric_type(&mut self.out, scalar, &[size])?;
3450            }
3451            TypeResolution::Value(crate::TypeInner::Matrix {
3452                columns,
3453                rows,
3454                scalar,
3455            }) => {
3456                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3457            }
3458            TypeResolution::Value(ref other) => {
3459                log::warn!("Type {other:?} isn't a known local"); //TEMP!
3460                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3461            }
3462        }
3463
3464        //TODO: figure out the naming scheme that wouldn't collide with user names.
3465        write!(self.out, " {name} = ")?;
3466
3467        Ok(())
3468    }
3469
3470    /// Cache a clamped level of detail value, if necessary.
3471    ///
3472    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3473    /// properly clamped level of detail value both in the access itself, and
3474    /// for fetching the size of the requested MIP level, needed to clamp the
3475    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3476    /// it in a temporary variable, as part of the [`Emit`] statement covering
3477    /// the [`ImageLoad`] expression.
3478    ///
3479    /// [`ImageLoad`]: crate::Expression::ImageLoad
3480    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3481    /// [`Emit`]: crate::Statement::Emit
3482    fn put_cache_restricted_level(
3483        &mut self,
3484        load: Handle<crate::Expression>,
3485        image: Handle<crate::Expression>,
3486        mip_level: Option<Handle<crate::Expression>>,
3487        indent: back::Level,
3488        context: &StatementContext,
3489    ) -> BackendResult {
3490        // Does this image access actually require (or even permit) a
3491        // level-of-detail, and does the policy require us to restrict it?
3492        let level_of_detail = match mip_level {
3493            Some(level) => level,
3494            None => return Ok(()),
3495        };
3496
3497        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3498            || !context.expression.image_needs_lod(image)
3499        {
3500            return Ok(());
3501        }
3502
3503        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3504        self.put_restricted_scalar_image_index(
3505            image,
3506            level_of_detail,
3507            "get_num_mip_levels",
3508            &context.expression,
3509        )?;
3510        writeln!(self.out, ";")?;
3511
3512        Ok(())
3513    }
3514
3515    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3516    ///
3517    /// Caches the results in temporary variables (whose names are derived from
3518    /// the original variable names). This caching avoids the need to redo the
3519    /// casting for each vector component when emitting the dot product.
3520    fn put_casting_to_packed_chars(
3521        &mut self,
3522        fun: crate::MathFunction,
3523        arg0: Handle<crate::Expression>,
3524        arg1: Handle<crate::Expression>,
3525        indent: back::Level,
3526        context: &StatementContext<'_>,
3527    ) -> Result<(), Error> {
3528        let packed_type = match fun {
3529            crate::MathFunction::Dot4I8Packed => "packed_char4",
3530            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3531            _ => unreachable!(),
3532        };
3533
3534        for arg in [arg0, arg1] {
3535            write!(
3536                self.out,
3537                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3538                Reinterpreted::new(packed_type, arg)
3539            )?;
3540            self.put_expression(arg, &context.expression, true)?;
3541            writeln!(self.out, ");")?;
3542        }
3543
3544        Ok(())
3545    }
3546
3547    fn put_block(
3548        &mut self,
3549        level: back::Level,
3550        statements: &[crate::Statement],
3551        context: &StatementContext,
3552    ) -> BackendResult {
3553        // Add to the set in order to track the stack size.
3554        #[cfg(test)]
3555        self.put_block_stack_pointers
3556            .insert(ptr::from_ref(&level).cast());
3557
3558        for statement in statements {
3559            log::trace!("statement[{}] {:?}", level.0, statement);
3560            match *statement {
3561                crate::Statement::Emit(ref range) => {
3562                    for handle in range.clone() {
3563                        use crate::MathFunction as Mf;
3564
3565                        match context.expression.function.expressions[handle] {
3566                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3567                            // may need to cache a clamped version of their level-of-detail argument.
3568                            crate::Expression::ImageLoad {
3569                                image,
3570                                level: mip_level,
3571                                ..
3572                            } => {
3573                                self.put_cache_restricted_level(
3574                                    handle, image, mip_level, level, context,
3575                                )?;
3576                            }
3577
3578                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3579                            // 2.1+ then we introduce two intermediate variables that recast the two
3580                            // arguments as packed (signed or unsigned) chars. The actual dot product
3581                            // is implemented in `Self::put_expression`, and it uses both of these
3582                            // intermediate variables multiple times. There's no danger that the
3583                            // original arguments get modified between the definition of these
3584                            // intermediate variables and the implementation of the actual dot
3585                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3586                            crate::Expression::Math {
3587                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3588                                arg,
3589                                arg1,
3590                                ..
3591                            } if context.expression.lang_version >= (2, 1) => {
3592                                self.put_casting_to_packed_chars(
3593                                    fun,
3594                                    arg,
3595                                    arg1.unwrap(),
3596                                    level,
3597                                    context,
3598                                )?;
3599                            }
3600
3601                            _ => (),
3602                        }
3603
3604                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3605                        let expr_name = if ptr_class.is_some() {
3606                            None // don't bake pointer expressions (just yet)
3607                        } else if let Some(name) =
3608                            context.expression.function.named_expressions.get(&handle)
3609                        {
3610                            // The `crate::Function::named_expressions` table holds
3611                            // expressions that should be saved in temporaries once they
3612                            // are `Emit`ted. We only add them to `self.named_expressions`
3613                            // when we reach the `Emit` that covers them, so that we don't
3614                            // try to use their names before we've actually initialized
3615                            // the temporary that holds them.
3616                            //
3617                            // Don't assume the names in `named_expressions` are unique,
3618                            // or even valid. Use the `Namer`.
3619                            Some(self.namer.call(name))
3620                        } else {
3621                            // If this expression is an index that we're going to first compare
3622                            // against a limit, and then actually use as an index, then we may
3623                            // want to cache it in a temporary, to avoid evaluating it twice.
3624                            let bake = if context.expression.guarded_indices.contains(handle) {
3625                                true
3626                            } else {
3627                                self.need_bake_expressions.contains(&handle)
3628                            };
3629
3630                            if bake {
3631                                Some(Baked(handle).to_string())
3632                            } else {
3633                                None
3634                            }
3635                        };
3636
3637                        if let Some(name) = expr_name {
3638                            write!(self.out, "{level}")?;
3639                            self.start_baking_expression(handle, &context.expression, &name)?;
3640                            self.put_expression(handle, &context.expression, true)?;
3641                            self.named_expressions.insert(handle, name);
3642                            writeln!(self.out, ";")?;
3643                        }
3644                    }
3645                }
3646                crate::Statement::Block(ref block) => {
3647                    if !block.is_empty() {
3648                        writeln!(self.out, "{level}{{")?;
3649                        self.put_block(level.next(), block, context)?;
3650                        writeln!(self.out, "{level}}}")?;
3651                    }
3652                }
3653                crate::Statement::If {
3654                    condition,
3655                    ref accept,
3656                    ref reject,
3657                } => {
3658                    write!(self.out, "{level}if (")?;
3659                    self.put_expression(condition, &context.expression, true)?;
3660                    writeln!(self.out, ") {{")?;
3661                    self.put_block(level.next(), accept, context)?;
3662                    if !reject.is_empty() {
3663                        writeln!(self.out, "{level}}} else {{")?;
3664                        self.put_block(level.next(), reject, context)?;
3665                    }
3666                    writeln!(self.out, "{level}}}")?;
3667                }
3668                crate::Statement::Switch {
3669                    selector,
3670                    ref cases,
3671                } => {
3672                    write!(self.out, "{level}switch(")?;
3673                    self.put_expression(selector, &context.expression, true)?;
3674                    writeln!(self.out, ") {{")?;
3675                    let lcase = level.next();
3676                    for case in cases.iter() {
3677                        match case.value {
3678                            crate::SwitchValue::I32(value) => {
3679                                write!(self.out, "{lcase}case {value}:")?;
3680                            }
3681                            crate::SwitchValue::U32(value) => {
3682                                write!(self.out, "{lcase}case {value}u:")?;
3683                            }
3684                            crate::SwitchValue::Default => {
3685                                write!(self.out, "{lcase}default:")?;
3686                            }
3687                        }
3688
3689                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3690                        if write_block_braces {
3691                            writeln!(self.out, " {{")?;
3692                        } else {
3693                            writeln!(self.out)?;
3694                        }
3695
3696                        self.put_block(lcase.next(), &case.body, context)?;
3697                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3698                        {
3699                            writeln!(self.out, "{}break;", lcase.next())?;
3700                        }
3701
3702                        if write_block_braces {
3703                            writeln!(self.out, "{lcase}}}")?;
3704                        }
3705                    }
3706                    writeln!(self.out, "{level}}}")?;
3707                }
3708                crate::Statement::Loop {
3709                    ref body,
3710                    ref continuing,
3711                    break_if,
3712                } => {
3713                    let force_loop_bound_statements =
3714                        self.gen_force_bounded_loop_statements(level, context);
3715                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3716                        .then(|| self.namer.call("loop_init"));
3717
3718                    if let Some((ref decl, _)) = force_loop_bound_statements {
3719                        writeln!(self.out, "{decl}")?;
3720                    }
3721                    if let Some(ref gate_name) = gate_name {
3722                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3723                    }
3724
3725                    writeln!(self.out, "{level}while(true) {{",)?;
3726                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3727                        writeln!(self.out, "{break_and_inc}")?;
3728                    }
3729                    if let Some(ref gate_name) = gate_name {
3730                        let lif = level.next();
3731                        let lcontinuing = lif.next();
3732                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3733                        self.put_block(lcontinuing, continuing, context)?;
3734                        if let Some(condition) = break_if {
3735                            write!(self.out, "{lcontinuing}if (")?;
3736                            self.put_expression(condition, &context.expression, true)?;
3737                            writeln!(self.out, ") {{")?;
3738                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3739                            writeln!(self.out, "{lcontinuing}}}")?;
3740                        }
3741                        writeln!(self.out, "{lif}}}")?;
3742                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3743                    }
3744                    self.put_block(level.next(), body, context)?;
3745
3746                    writeln!(self.out, "{level}}}")?;
3747                }
3748                crate::Statement::Break => {
3749                    writeln!(self.out, "{level}break;")?;
3750                }
3751                crate::Statement::Continue => {
3752                    writeln!(self.out, "{level}continue;")?;
3753                }
3754                crate::Statement::Return {
3755                    value: Some(expr_handle),
3756                } => {
3757                    self.put_return_value(
3758                        level,
3759                        expr_handle,
3760                        context.result_struct,
3761                        &context.expression,
3762                    )?;
3763                }
3764                crate::Statement::Return { value: None } => {
3765                    writeln!(self.out, "{level}return;")?;
3766                }
3767                crate::Statement::Kill => {
3768                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3769                }
3770                crate::Statement::ControlBarrier(flags)
3771                | crate::Statement::MemoryBarrier(flags) => {
3772                    self.write_barrier(flags, level)?;
3773                }
3774                crate::Statement::Store { pointer, value } => {
3775                    self.put_store(pointer, value, level, context)?
3776                }
3777                crate::Statement::ImageStore {
3778                    image,
3779                    coordinate,
3780                    array_index,
3781                    value,
3782                } => {
3783                    let address = TexelAddress {
3784                        coordinate,
3785                        array_index,
3786                        sample: None,
3787                        level: None,
3788                    };
3789                    self.put_image_store(level, image, &address, value, context)?
3790                }
3791                crate::Statement::Call {
3792                    function,
3793                    ref arguments,
3794                    result,
3795                } => {
3796                    write!(self.out, "{level}")?;
3797                    if let Some(expr) = result {
3798                        let name = Baked(expr).to_string();
3799                        self.start_baking_expression(expr, &context.expression, &name)?;
3800                        self.named_expressions.insert(expr, name);
3801                    }
3802                    let fun_name = &self.names[&NameKey::Function(function)];
3803                    write!(self.out, "{fun_name}(")?;
3804                    // first, write down the actual arguments
3805                    for (i, &handle) in arguments.iter().enumerate() {
3806                        if i != 0 {
3807                            write!(self.out, ", ")?;
3808                        }
3809                        self.put_expression(handle, &context.expression, true)?;
3810                    }
3811                    // follow up with any global resources used
3812                    let mut separate = !arguments.is_empty();
3813                    let fun_info = &context.expression.mod_info[function];
3814                    let mut needs_buffer_sizes = false;
3815                    for (handle, var) in context.expression.module.global_variables.iter() {
3816                        if fun_info[handle].is_empty() {
3817                            continue;
3818                        }
3819                        if var.space.needs_pass_through() {
3820                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3821                            if separate {
3822                                write!(self.out, ", ")?;
3823                            } else {
3824                                separate = true;
3825                            }
3826                            write!(self.out, "{name}")?;
3827                        }
3828                        needs_buffer_sizes |=
3829                            needs_array_length(var.ty, &context.expression.module.types);
3830                    }
3831                    if needs_buffer_sizes {
3832                        if separate {
3833                            write!(self.out, ", ")?;
3834                        }
3835                        write!(self.out, "_buffer_sizes")?;
3836                    }
3837
3838                    // done
3839                    writeln!(self.out, ");")?;
3840                }
3841                crate::Statement::Atomic {
3842                    pointer,
3843                    ref fun,
3844                    value,
3845                    result,
3846                } => {
3847                    let context = &context.expression;
3848
3849                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
3850                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
3851                    // `Some`, we are not operating on a 64-bit value, and that if we are
3852                    // operating on a 64-bit value, `result` is `None`.
3853                    write!(self.out, "{level}")?;
3854                    let fun_key = if let Some(result) = result {
3855                        let res_name = Baked(result).to_string();
3856                        self.start_baking_expression(result, context, &res_name)?;
3857                        self.named_expressions.insert(result, res_name);
3858                        fun.to_msl()
3859                    } else if context.resolve_type(value).scalar_width() == Some(8) {
3860                        fun.to_msl_64_bit()?
3861                    } else {
3862                        fun.to_msl()
3863                    };
3864
3865                    // If the pointer we're passing to the atomic operation needs to be conditional
3866                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
3867                    // the pointer operand should be unchecked.
3868                    let policy = context.choose_bounds_check_policy(pointer);
3869                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3870                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3871
3872                    // If bounds checks were requested and successfully written, continue the ternary expression.
3873                    if checked {
3874                        write!(self.out, " ? ")?;
3875                    }
3876
3877                    // Put the atomic function invocation.
3878                    match *fun {
3879                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3880                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3881                            self.put_access_chain(pointer, policy, context)?;
3882                            write!(self.out, ", ")?;
3883                            self.put_expression(cmp, context, true)?;
3884                            write!(self.out, ", ")?;
3885                            self.put_expression(value, context, true)?;
3886                            write!(self.out, ")")?;
3887                        }
3888                        _ => {
3889                            write!(
3890                                self.out,
3891                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3892                            )?;
3893                            self.put_access_chain(pointer, policy, context)?;
3894                            write!(self.out, ", ")?;
3895                            self.put_expression(value, context, true)?;
3896                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3897                        }
3898                    }
3899
3900                    // Finish the ternary expression.
3901                    if checked {
3902                        write!(self.out, " : DefaultConstructible()")?;
3903                    }
3904
3905                    // Done
3906                    writeln!(self.out, ";")?;
3907                }
3908                crate::Statement::ImageAtomic {
3909                    image,
3910                    coordinate,
3911                    array_index,
3912                    fun,
3913                    value,
3914                } => {
3915                    let address = TexelAddress {
3916                        coordinate,
3917                        array_index,
3918                        sample: None,
3919                        level: None,
3920                    };
3921                    self.put_image_atomic(level, image, &address, fun, value, context)?
3922                }
3923                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3924                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3925
3926                    write!(self.out, "{level}")?;
3927                    let name = self.namer.call("");
3928                    self.start_baking_expression(result, &context.expression, &name)?;
3929                    self.put_load(pointer, &context.expression, true)?;
3930                    self.named_expressions.insert(result, name);
3931
3932                    writeln!(self.out, ";")?;
3933                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3934                }
3935                crate::Statement::RayQuery { query, ref fun } => {
3936                    if context.expression.lang_version < (2, 4) {
3937                        return Err(Error::UnsupportedRayTracing);
3938                    }
3939
3940                    match *fun {
3941                        crate::RayQueryFunction::Initialize {
3942                            acceleration_structure,
3943                            descriptor,
3944                        } => {
3945                            //TODO: how to deal with winding?
3946                            write!(self.out, "{level}")?;
3947                            self.put_expression(query, &context.expression, true)?;
3948                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3949                            {
3950                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3951                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3952                                write!(self.out, "{level}")?;
3953                                self.put_expression(query, &context.expression, true)?;
3954                                write!(
3955                                    self.out,
3956                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3957                                )?;
3958                                self.put_expression(descriptor, &context.expression, true)?;
3959                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3960                                self.put_expression(descriptor, &context.expression, true)?;
3961                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3962                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3963                            }
3964                            {
3965                                let f_opaque = back::RayFlag::OPAQUE.bits();
3966                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3967                                write!(self.out, "{level}")?;
3968                                self.put_expression(query, &context.expression, true)?;
3969                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3970                                self.put_expression(descriptor, &context.expression, true)?;
3971                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3972                                self.put_expression(descriptor, &context.expression, true)?;
3973                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3974                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3975                            }
3976                            {
3977                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3978                                write!(self.out, "{level}")?;
3979                                self.put_expression(query, &context.expression, true)?;
3980                                write!(
3981                                    self.out,
3982                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3983                                )?;
3984                                self.put_expression(descriptor, &context.expression, true)?;
3985                                writeln!(self.out, ".flags & {flag}) != 0);")?;
3986                            }
3987
3988                            write!(self.out, "{level}")?;
3989                            self.put_expression(query, &context.expression, true)?;
3990                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3991                            self.put_expression(query, &context.expression, true)?;
3992                            write!(
3993                                self.out,
3994                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
3995                            )?;
3996                            self.put_expression(descriptor, &context.expression, true)?;
3997                            write!(self.out, ".origin, ")?;
3998                            self.put_expression(descriptor, &context.expression, true)?;
3999                            write!(self.out, ".dir, ")?;
4000                            self.put_expression(descriptor, &context.expression, true)?;
4001                            write!(self.out, ".tmin, ")?;
4002                            self.put_expression(descriptor, &context.expression, true)?;
4003                            write!(self.out, ".tmax), ")?;
4004                            self.put_expression(acceleration_structure, &context.expression, true)?;
4005                            write!(self.out, ", ")?;
4006                            self.put_expression(descriptor, &context.expression, true)?;
4007                            write!(self.out, ".cull_mask);")?;
4008
4009                            write!(self.out, "{level}")?;
4010                            self.put_expression(query, &context.expression, true)?;
4011                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4012                        }
4013                        crate::RayQueryFunction::Proceed { result } => {
4014                            write!(self.out, "{level}")?;
4015                            let name = Baked(result).to_string();
4016                            self.start_baking_expression(result, &context.expression, &name)?;
4017                            self.named_expressions.insert(result, name);
4018                            self.put_expression(query, &context.expression, true)?;
4019                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4020                            if RAY_QUERY_MODERN_SUPPORT {
4021                                write!(self.out, "{level}")?;
4022                                self.put_expression(query, &context.expression, true)?;
4023                                writeln!(self.out, ".?.next();")?;
4024                            }
4025                        }
4026                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4027                            if RAY_QUERY_MODERN_SUPPORT {
4028                                write!(self.out, "{level}")?;
4029                                self.put_expression(query, &context.expression, true)?;
4030                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4031                                self.put_expression(hit_t, &context.expression, true)?;
4032                                writeln!(self.out, ");")?;
4033                            } else {
4034                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4035                            }
4036                        }
4037                        crate::RayQueryFunction::ConfirmIntersection => {
4038                            if RAY_QUERY_MODERN_SUPPORT {
4039                                write!(self.out, "{level}")?;
4040                                self.put_expression(query, &context.expression, true)?;
4041                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4042                            } else {
4043                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4044                            }
4045                        }
4046                        crate::RayQueryFunction::Terminate => {
4047                            if RAY_QUERY_MODERN_SUPPORT {
4048                                write!(self.out, "{level}")?;
4049                                self.put_expression(query, &context.expression, true)?;
4050                                writeln!(self.out, ".?.abort();")?;
4051                            }
4052                            write!(self.out, "{level}")?;
4053                            self.put_expression(query, &context.expression, true)?;
4054                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4055                        }
4056                    }
4057                }
4058                crate::Statement::SubgroupBallot { result, predicate } => {
4059                    write!(self.out, "{level}")?;
4060                    let name = self.namer.call("");
4061                    self.start_baking_expression(result, &context.expression, &name)?;
4062                    self.named_expressions.insert(result, name);
4063                    write!(
4064                        self.out,
4065                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4066                    )?;
4067                    if let Some(predicate) = predicate {
4068                        self.put_expression(predicate, &context.expression, true)?;
4069                    } else {
4070                        write!(self.out, "true")?;
4071                    }
4072                    writeln!(self.out, "), 0, 0, 0);")?;
4073                }
4074                crate::Statement::SubgroupCollectiveOperation {
4075                    op,
4076                    collective_op,
4077                    argument,
4078                    result,
4079                } => {
4080                    write!(self.out, "{level}")?;
4081                    let name = self.namer.call("");
4082                    self.start_baking_expression(result, &context.expression, &name)?;
4083                    self.named_expressions.insert(result, name);
4084                    match (collective_op, op) {
4085                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4086                            write!(self.out, "{NAMESPACE}::simd_all(")?
4087                        }
4088                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4089                            write!(self.out, "{NAMESPACE}::simd_any(")?
4090                        }
4091                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4092                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4093                        }
4094                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4095                            write!(self.out, "{NAMESPACE}::simd_product(")?
4096                        }
4097                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4098                            write!(self.out, "{NAMESPACE}::simd_max(")?
4099                        }
4100                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4101                            write!(self.out, "{NAMESPACE}::simd_min(")?
4102                        }
4103                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4104                            write!(self.out, "{NAMESPACE}::simd_and(")?
4105                        }
4106                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4107                            write!(self.out, "{NAMESPACE}::simd_or(")?
4108                        }
4109                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4110                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4111                        }
4112                        (
4113                            crate::CollectiveOperation::ExclusiveScan,
4114                            crate::SubgroupOperation::Add,
4115                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4116                        (
4117                            crate::CollectiveOperation::ExclusiveScan,
4118                            crate::SubgroupOperation::Mul,
4119                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4120                        (
4121                            crate::CollectiveOperation::InclusiveScan,
4122                            crate::SubgroupOperation::Add,
4123                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4124                        (
4125                            crate::CollectiveOperation::InclusiveScan,
4126                            crate::SubgroupOperation::Mul,
4127                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4128                        _ => unimplemented!(),
4129                    }
4130                    self.put_expression(argument, &context.expression, true)?;
4131                    writeln!(self.out, ");")?;
4132                }
4133                crate::Statement::SubgroupGather {
4134                    mode,
4135                    argument,
4136                    result,
4137                } => {
4138                    write!(self.out, "{level}")?;
4139                    let name = self.namer.call("");
4140                    self.start_baking_expression(result, &context.expression, &name)?;
4141                    self.named_expressions.insert(result, name);
4142                    match mode {
4143                        crate::GatherMode::BroadcastFirst => {
4144                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4145                        }
4146                        crate::GatherMode::Broadcast(_) => {
4147                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4148                        }
4149                        crate::GatherMode::Shuffle(_) => {
4150                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4151                        }
4152                        crate::GatherMode::ShuffleDown(_) => {
4153                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4154                        }
4155                        crate::GatherMode::ShuffleUp(_) => {
4156                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4157                        }
4158                        crate::GatherMode::ShuffleXor(_) => {
4159                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4160                        }
4161                        crate::GatherMode::QuadBroadcast(_) => {
4162                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4163                        }
4164                        crate::GatherMode::QuadSwap(_) => {
4165                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4166                        }
4167                    }
4168                    self.put_expression(argument, &context.expression, true)?;
4169                    match mode {
4170                        crate::GatherMode::BroadcastFirst => {}
4171                        crate::GatherMode::Broadcast(index)
4172                        | crate::GatherMode::Shuffle(index)
4173                        | crate::GatherMode::ShuffleDown(index)
4174                        | crate::GatherMode::ShuffleUp(index)
4175                        | crate::GatherMode::ShuffleXor(index)
4176                        | crate::GatherMode::QuadBroadcast(index) => {
4177                            write!(self.out, ", ")?;
4178                            self.put_expression(index, &context.expression, true)?;
4179                        }
4180                        crate::GatherMode::QuadSwap(direction) => {
4181                            write!(self.out, ", ")?;
4182                            match direction {
4183                                crate::Direction::X => {
4184                                    write!(self.out, "1u")?;
4185                                }
4186                                crate::Direction::Y => {
4187                                    write!(self.out, "2u")?;
4188                                }
4189                                crate::Direction::Diagonal => {
4190                                    write!(self.out, "3u")?;
4191                                }
4192                            }
4193                        }
4194                    }
4195                    writeln!(self.out, ");")?;
4196                }
4197            }
4198        }
4199
4200        // un-emit expressions
4201        //TODO: take care of loop/continuing?
4202        for statement in statements {
4203            if let crate::Statement::Emit(ref range) = *statement {
4204                for handle in range.clone() {
4205                    self.named_expressions.shift_remove(&handle);
4206                }
4207            }
4208        }
4209        Ok(())
4210    }
4211
4212    fn put_store(
4213        &mut self,
4214        pointer: Handle<crate::Expression>,
4215        value: Handle<crate::Expression>,
4216        level: back::Level,
4217        context: &StatementContext,
4218    ) -> BackendResult {
4219        let policy = context.expression.choose_bounds_check_policy(pointer);
4220        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4221            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4222        {
4223            writeln!(self.out, ") {{")?;
4224            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4225            writeln!(self.out, "{level}}}")?;
4226        } else {
4227            self.put_unchecked_store(pointer, value, policy, level, context)?;
4228        }
4229
4230        Ok(())
4231    }
4232
4233    fn put_unchecked_store(
4234        &mut self,
4235        pointer: Handle<crate::Expression>,
4236        value: Handle<crate::Expression>,
4237        policy: index::BoundsCheckPolicy,
4238        level: back::Level,
4239        context: &StatementContext,
4240    ) -> BackendResult {
4241        let is_atomic_pointer = context
4242            .expression
4243            .resolve_type(pointer)
4244            .is_atomic_pointer(&context.expression.module.types);
4245
4246        if is_atomic_pointer {
4247            write!(
4248                self.out,
4249                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4250            )?;
4251            self.put_access_chain(pointer, policy, &context.expression)?;
4252            write!(self.out, ", ")?;
4253            self.put_expression(value, &context.expression, true)?;
4254            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4255        } else {
4256            write!(self.out, "{level}")?;
4257            self.put_access_chain(pointer, policy, &context.expression)?;
4258            write!(self.out, " = ")?;
4259            self.put_expression(value, &context.expression, true)?;
4260            writeln!(self.out, ";")?;
4261        }
4262
4263        Ok(())
4264    }
4265
4266    pub fn write(
4267        &mut self,
4268        module: &crate::Module,
4269        info: &valid::ModuleInfo,
4270        options: &Options,
4271        pipeline_options: &PipelineOptions,
4272    ) -> Result<TranslationInfo, Error> {
4273        self.names.clear();
4274        self.namer.reset(
4275            module,
4276            &super::keywords::RESERVED_SET,
4277            proc::CaseInsensitiveKeywordSet::empty(),
4278            &[CLAMPED_LOD_LOAD_PREFIX],
4279            &mut self.names,
4280        );
4281        self.wrapped_functions.clear();
4282        self.struct_member_pads.clear();
4283
4284        writeln!(
4285            self.out,
4286            "// language: metal{}.{}",
4287            options.lang_version.0, options.lang_version.1
4288        )?;
4289        writeln!(self.out, "#include <metal_stdlib>")?;
4290        writeln!(self.out, "#include <simd/simd.h>")?;
4291        writeln!(self.out)?;
4292        // Work around Metal bug where `uint` is not available by default
4293        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4294
4295        let mut uses_ray_query = false;
4296        for (_, ty) in module.types.iter() {
4297            match ty.inner {
4298                crate::TypeInner::AccelerationStructure { .. } => {
4299                    if options.lang_version < (2, 4) {
4300                        return Err(Error::UnsupportedRayTracing);
4301                    }
4302                }
4303                crate::TypeInner::RayQuery { .. } => {
4304                    if options.lang_version < (2, 4) {
4305                        return Err(Error::UnsupportedRayTracing);
4306                    }
4307                    uses_ray_query = true;
4308                }
4309                _ => (),
4310            }
4311        }
4312
4313        if module.special_types.ray_desc.is_some()
4314            || module.special_types.ray_intersection.is_some()
4315        {
4316            if options.lang_version < (2, 4) {
4317                return Err(Error::UnsupportedRayTracing);
4318            }
4319        }
4320
4321        if uses_ray_query {
4322            self.put_ray_query_type()?;
4323        }
4324
4325        if options
4326            .bounds_check_policies
4327            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4328        {
4329            self.put_default_constructible()?;
4330        }
4331        writeln!(self.out)?;
4332
4333        {
4334            // Make a `Vec` of all the `GlobalVariable`s that contain
4335            // runtime-sized arrays.
4336            let globals: Vec<Handle<crate::GlobalVariable>> = module
4337                .global_variables
4338                .iter()
4339                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4340                .map(|(handle, _)| handle)
4341                .collect();
4342
4343            let mut buffer_indices = vec![];
4344            for vbm in &pipeline_options.vertex_buffer_mappings {
4345                buffer_indices.push(vbm.id);
4346            }
4347
4348            if !globals.is_empty() || !buffer_indices.is_empty() {
4349                writeln!(self.out, "struct _mslBufferSizes {{")?;
4350
4351                for global in globals {
4352                    writeln!(
4353                        self.out,
4354                        "{}uint {};",
4355                        back::INDENT,
4356                        ArraySizeMember(global)
4357                    )?;
4358                }
4359
4360                for idx in buffer_indices {
4361                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4362                }
4363
4364                writeln!(self.out, "}};")?;
4365                writeln!(self.out)?;
4366            }
4367        };
4368
4369        self.write_type_defs(module)?;
4370        self.write_global_constants(module, info)?;
4371        self.write_functions(module, info, options, pipeline_options)
4372    }
4373
4374    /// Write the definition for the `DefaultConstructible` class.
4375    ///
4376    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4377    /// produce 'zero' values for any type, including structs, arrays, and so
4378    /// on. We could do this by emitting default constructor applications, but
4379    /// that would entail printing the name of the type, which is more trouble
4380    /// than you'd think. Instead, we just construct this magic C++14 class that
4381    /// can be converted to any type that can be default constructed, using
4382    /// template parameter inference to detect which type is needed, so we don't
4383    /// have to figure out the name.
4384    ///
4385    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4386    fn put_default_constructible(&mut self) -> BackendResult {
4387        let tab = back::INDENT;
4388        writeln!(self.out, "struct DefaultConstructible {{")?;
4389        writeln!(self.out, "{tab}template<typename T>")?;
4390        writeln!(self.out, "{tab}operator T() && {{")?;
4391        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4392        writeln!(self.out, "{tab}}}")?;
4393        writeln!(self.out, "}};")?;
4394        Ok(())
4395    }
4396
4397    fn put_ray_query_type(&mut self) -> BackendResult {
4398        let tab = back::INDENT;
4399        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4400        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4401        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4402        writeln!(
4403            self.out,
4404            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4405        )?;
4406        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4407        writeln!(self.out, "}};")?;
4408        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4409        let v_triangle = back::RayIntersectionType::Triangle as u32;
4410        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4411        writeln!(
4412            self.out,
4413            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4414        )?;
4415        writeln!(
4416            self.out,
4417            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4418        )?;
4419        writeln!(self.out, "}}")?;
4420        Ok(())
4421    }
4422
4423    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4424        let mut generated_argument_buffer_wrapper = false;
4425        let mut generated_external_texture_wrapper = false;
4426        for (handle, ty) in module.types.iter() {
4427            match ty.inner {
4428                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4429                    writeln!(self.out, "template <typename T>")?;
4430                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4431                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4432                    writeln!(self.out, "}};")?;
4433                    generated_argument_buffer_wrapper = true;
4434                }
4435                crate::TypeInner::Image {
4436                    class: crate::ImageClass::External,
4437                    ..
4438                } if !generated_external_texture_wrapper => {
4439                    let params_ty_name = &self.names
4440                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4441                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4442                    writeln!(
4443                        self.out,
4444                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4445                        back::INDENT
4446                    )?;
4447                    writeln!(
4448                        self.out,
4449                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4450                        back::INDENT
4451                    )?;
4452                    writeln!(
4453                        self.out,
4454                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4455                        back::INDENT
4456                    )?;
4457                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4458                    writeln!(self.out, "}};")?;
4459                    generated_external_texture_wrapper = true;
4460                }
4461                _ => {}
4462            }
4463
4464            if !ty.needs_alias() {
4465                continue;
4466            }
4467            let name = &self.names[&NameKey::Type(handle)];
4468            match ty.inner {
4469                // Naga IR can pass around arrays by value, but Metal, following
4470                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4471                // on expressions of array type, so assigning the array by value
4472                // isn't possible. However, Metal *does* assign structs by
4473                // value. So in our Metal output, we wrap all array types in
4474                // synthetic struct types:
4475                //
4476                //     struct type1 {
4477                //         float inner[10]
4478                //     };
4479                //
4480                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4481                // any expression that actually wants access to the array.
4482                crate::TypeInner::Array {
4483                    base,
4484                    size,
4485                    stride: _,
4486                } => {
4487                    let base_name = TypeContext {
4488                        handle: base,
4489                        gctx: module.to_ctx(),
4490                        names: &self.names,
4491                        access: crate::StorageAccess::empty(),
4492                        first_time: false,
4493                    };
4494
4495                    match size.resolve(module.to_ctx())? {
4496                        proc::IndexableLength::Known(size) => {
4497                            writeln!(self.out, "struct {name} {{")?;
4498                            writeln!(
4499                                self.out,
4500                                "{}{} {}[{}];",
4501                                back::INDENT,
4502                                base_name,
4503                                WRAPPED_ARRAY_FIELD,
4504                                size
4505                            )?;
4506                            writeln!(self.out, "}};")?;
4507                        }
4508                        proc::IndexableLength::Dynamic => {
4509                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4510                        }
4511                    }
4512                }
4513                crate::TypeInner::Struct {
4514                    ref members, span, ..
4515                } => {
4516                    writeln!(self.out, "struct {name} {{")?;
4517                    let mut last_offset = 0;
4518                    for (index, member) in members.iter().enumerate() {
4519                        if member.offset > last_offset {
4520                            self.struct_member_pads.insert((handle, index as u32));
4521                            let pad = member.offset - last_offset;
4522                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4523                        }
4524                        let ty_inner = &module.types[member.ty].inner;
4525                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4526
4527                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4528
4529                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4530                        match should_pack_struct_member(members, span, index, module) {
4531                            Some(scalar) => {
4532                                writeln!(
4533                                    self.out,
4534                                    "{}{}::packed_{}3 {};",
4535                                    back::INDENT,
4536                                    NAMESPACE,
4537                                    scalar.to_msl_name(),
4538                                    member_name
4539                                )?;
4540                            }
4541                            None => {
4542                                let base_name = TypeContext {
4543                                    handle: member.ty,
4544                                    gctx: module.to_ctx(),
4545                                    names: &self.names,
4546                                    access: crate::StorageAccess::empty(),
4547                                    first_time: false,
4548                                };
4549                                writeln!(
4550                                    self.out,
4551                                    "{}{} {};",
4552                                    back::INDENT,
4553                                    base_name,
4554                                    member_name
4555                                )?;
4556
4557                                // for 3-component vectors, add one component
4558                                if let crate::TypeInner::Vector {
4559                                    size: crate::VectorSize::Tri,
4560                                    scalar,
4561                                } = *ty_inner
4562                                {
4563                                    last_offset += scalar.width as u32;
4564                                }
4565                            }
4566                        }
4567                    }
4568                    if last_offset < span {
4569                        let pad = span - last_offset;
4570                        writeln!(
4571                            self.out,
4572                            "{}char _pad{}[{}];",
4573                            back::INDENT,
4574                            members.len(),
4575                            pad
4576                        )?;
4577                    }
4578                    writeln!(self.out, "}};")?;
4579                }
4580                _ => {
4581                    let ty_name = TypeContext {
4582                        handle,
4583                        gctx: module.to_ctx(),
4584                        names: &self.names,
4585                        access: crate::StorageAccess::empty(),
4586                        first_time: true,
4587                    };
4588                    writeln!(self.out, "typedef {ty_name} {name};")?;
4589                }
4590            }
4591        }
4592
4593        // Write functions to create special types.
4594        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4595            match type_key {
4596                &crate::PredeclaredType::ModfResult { size, scalar }
4597                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4598                    let arg_type_name_owner;
4599                    let arg_type_name = if let Some(size) = size {
4600                        arg_type_name_owner = format!(
4601                            "{NAMESPACE}::{}{}",
4602                            if scalar.width == 8 { "double" } else { "float" },
4603                            size as u8
4604                        );
4605                        &arg_type_name_owner
4606                    } else if scalar.width == 8 {
4607                        "double"
4608                    } else {
4609                        "float"
4610                    };
4611
4612                    let other_type_name_owner;
4613                    let (defined_func_name, called_func_name, other_type_name) =
4614                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4615                            (MODF_FUNCTION, "modf", arg_type_name)
4616                        } else {
4617                            let other_type_name = if let Some(size) = size {
4618                                other_type_name_owner = format!("int{}", size as u8);
4619                                &other_type_name_owner
4620                            } else {
4621                                "int"
4622                            };
4623                            (FREXP_FUNCTION, "frexp", other_type_name)
4624                        };
4625
4626                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4627
4628                    writeln!(self.out)?;
4629                    writeln!(
4630                        self.out,
4631                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4632    {other_type_name} other;
4633    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4634    return {struct_name}{{ fract, other }};
4635}}"
4636                    )?;
4637                }
4638                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4639                    let arg_type_name = scalar.to_msl_name();
4640                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4641                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4642                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4643
4644                    writeln!(self.out)?;
4645
4646                    for address_space_name in ["device", "threadgroup"] {
4647                        writeln!(
4648                            self.out,
4649                            "\
4650template <typename A>
4651{struct_name} {defined_func_name}(
4652    {address_space_name} A *atomic_ptr,
4653    {arg_type_name} cmp,
4654    {arg_type_name} v
4655) {{
4656    bool swapped = {NAMESPACE}::{called_func_name}(
4657        atomic_ptr, &cmp, v,
4658        metal::memory_order_relaxed, metal::memory_order_relaxed
4659    );
4660    return {struct_name}{{cmp, swapped}};
4661}}"
4662                        )?;
4663                    }
4664                }
4665            }
4666        }
4667
4668        Ok(())
4669    }
4670
4671    /// Writes all named constants
4672    fn write_global_constants(
4673        &mut self,
4674        module: &crate::Module,
4675        mod_info: &valid::ModuleInfo,
4676    ) -> BackendResult {
4677        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4678
4679        for (handle, constant) in constants {
4680            let ty_name = TypeContext {
4681                handle: constant.ty,
4682                gctx: module.to_ctx(),
4683                names: &self.names,
4684                access: crate::StorageAccess::empty(),
4685                first_time: false,
4686            };
4687            let name = &self.names[&NameKey::Constant(handle)];
4688            write!(self.out, "constant {ty_name} {name} = ")?;
4689            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4690            writeln!(self.out, ";")?;
4691        }
4692
4693        Ok(())
4694    }
4695
4696    fn put_inline_sampler_properties(
4697        &mut self,
4698        level: back::Level,
4699        sampler: &sm::InlineSampler,
4700    ) -> BackendResult {
4701        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4702            writeln!(
4703                self.out,
4704                "{}{}::{}_address::{},",
4705                level,
4706                NAMESPACE,
4707                letter,
4708                address.as_str(),
4709            )?;
4710        }
4711        writeln!(
4712            self.out,
4713            "{}{}::mag_filter::{},",
4714            level,
4715            NAMESPACE,
4716            sampler.mag_filter.as_str(),
4717        )?;
4718        writeln!(
4719            self.out,
4720            "{}{}::min_filter::{},",
4721            level,
4722            NAMESPACE,
4723            sampler.min_filter.as_str(),
4724        )?;
4725        if let Some(filter) = sampler.mip_filter {
4726            writeln!(
4727                self.out,
4728                "{}{}::mip_filter::{},",
4729                level,
4730                NAMESPACE,
4731                filter.as_str(),
4732            )?;
4733        }
4734        // avoid setting it on platforms that don't support it
4735        if sampler.border_color != sm::BorderColor::TransparentBlack {
4736            writeln!(
4737                self.out,
4738                "{}{}::border_color::{},",
4739                level,
4740                NAMESPACE,
4741                sampler.border_color.as_str(),
4742            )?;
4743        }
4744        //TODO: I'm not able to feed this in a way that MSL likes:
4745        //>error: use of undeclared identifier 'lod_clamp'
4746        //>error: no member named 'max_anisotropy' in namespace 'metal'
4747        if false {
4748            if let Some(ref lod) = sampler.lod_clamp {
4749                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4750            }
4751            if let Some(aniso) = sampler.max_anisotropy {
4752                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4753            }
4754        }
4755        if sampler.compare_func != sm::CompareFunc::Never {
4756            writeln!(
4757                self.out,
4758                "{}{}::compare_func::{},",
4759                level,
4760                NAMESPACE,
4761                sampler.compare_func.as_str(),
4762            )?;
4763        }
4764        writeln!(
4765            self.out,
4766            "{}{}::coord::{}",
4767            level,
4768            NAMESPACE,
4769            sampler.coord.as_str()
4770        )?;
4771        Ok(())
4772    }
4773
4774    fn write_unpacking_function(
4775        &mut self,
4776        format: back::msl::VertexFormat,
4777    ) -> Result<(String, u32, u32), Error> {
4778        use back::msl::VertexFormat::*;
4779        match format {
4780            Uint8 => {
4781                let name = self.namer.call("unpackUint8");
4782                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4783                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4784                writeln!(self.out, "}}")?;
4785                Ok((name, 1, 1))
4786            }
4787            Uint8x2 => {
4788                let name = self.namer.call("unpackUint8x2");
4789                writeln!(
4790                    self.out,
4791                    "metal::uint2 {name}(metal::uchar b0, \
4792                                         metal::uchar b1) {{"
4793                )?;
4794                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4795                writeln!(self.out, "}}")?;
4796                Ok((name, 2, 2))
4797            }
4798            Uint8x4 => {
4799                let name = self.namer.call("unpackUint8x4");
4800                writeln!(
4801                    self.out,
4802                    "metal::uint4 {name}(metal::uchar b0, \
4803                                         metal::uchar b1, \
4804                                         metal::uchar b2, \
4805                                         metal::uchar b3) {{"
4806                )?;
4807                writeln!(
4808                    self.out,
4809                    "{}return metal::uint4(b0, b1, b2, b3);",
4810                    back::INDENT
4811                )?;
4812                writeln!(self.out, "}}")?;
4813                Ok((name, 4, 4))
4814            }
4815            Sint8 => {
4816                let name = self.namer.call("unpackSint8");
4817                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4818                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4819                writeln!(self.out, "}}")?;
4820                Ok((name, 1, 1))
4821            }
4822            Sint8x2 => {
4823                let name = self.namer.call("unpackSint8x2");
4824                writeln!(
4825                    self.out,
4826                    "metal::int2 {name}(metal::uchar b0, \
4827                                        metal::uchar b1) {{"
4828                )?;
4829                writeln!(
4830                    self.out,
4831                    "{}return metal::int2(as_type<char>(b0), \
4832                                          as_type<char>(b1));",
4833                    back::INDENT
4834                )?;
4835                writeln!(self.out, "}}")?;
4836                Ok((name, 2, 2))
4837            }
4838            Sint8x4 => {
4839                let name = self.namer.call("unpackSint8x4");
4840                writeln!(
4841                    self.out,
4842                    "metal::int4 {name}(metal::uchar b0, \
4843                                        metal::uchar b1, \
4844                                        metal::uchar b2, \
4845                                        metal::uchar b3) {{"
4846                )?;
4847                writeln!(
4848                    self.out,
4849                    "{}return metal::int4(as_type<char>(b0), \
4850                                          as_type<char>(b1), \
4851                                          as_type<char>(b2), \
4852                                          as_type<char>(b3));",
4853                    back::INDENT
4854                )?;
4855                writeln!(self.out, "}}")?;
4856                Ok((name, 4, 4))
4857            }
4858            Unorm8 => {
4859                let name = self.namer.call("unpackUnorm8");
4860                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4861                writeln!(
4862                    self.out,
4863                    "{}return float(float(b0) / 255.0f);",
4864                    back::INDENT
4865                )?;
4866                writeln!(self.out, "}}")?;
4867                Ok((name, 1, 1))
4868            }
4869            Unorm8x2 => {
4870                let name = self.namer.call("unpackUnorm8x2");
4871                writeln!(
4872                    self.out,
4873                    "metal::float2 {name}(metal::uchar b0, \
4874                                          metal::uchar b1) {{"
4875                )?;
4876                writeln!(
4877                    self.out,
4878                    "{}return metal::float2(float(b0) / 255.0f, \
4879                                            float(b1) / 255.0f);",
4880                    back::INDENT
4881                )?;
4882                writeln!(self.out, "}}")?;
4883                Ok((name, 2, 2))
4884            }
4885            Unorm8x4 => {
4886                let name = self.namer.call("unpackUnorm8x4");
4887                writeln!(
4888                    self.out,
4889                    "metal::float4 {name}(metal::uchar b0, \
4890                                          metal::uchar b1, \
4891                                          metal::uchar b2, \
4892                                          metal::uchar b3) {{"
4893                )?;
4894                writeln!(
4895                    self.out,
4896                    "{}return metal::float4(float(b0) / 255.0f, \
4897                                            float(b1) / 255.0f, \
4898                                            float(b2) / 255.0f, \
4899                                            float(b3) / 255.0f);",
4900                    back::INDENT
4901                )?;
4902                writeln!(self.out, "}}")?;
4903                Ok((name, 4, 4))
4904            }
4905            Snorm8 => {
4906                let name = self.namer.call("unpackSnorm8");
4907                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4908                writeln!(
4909                    self.out,
4910                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4911                    back::INDENT
4912                )?;
4913                writeln!(self.out, "}}")?;
4914                Ok((name, 1, 1))
4915            }
4916            Snorm8x2 => {
4917                let name = self.namer.call("unpackSnorm8x2");
4918                writeln!(
4919                    self.out,
4920                    "metal::float2 {name}(metal::uchar b0, \
4921                                          metal::uchar b1) {{"
4922                )?;
4923                writeln!(
4924                    self.out,
4925                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4926                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4927                    back::INDENT
4928                )?;
4929                writeln!(self.out, "}}")?;
4930                Ok((name, 2, 2))
4931            }
4932            Snorm8x4 => {
4933                let name = self.namer.call("unpackSnorm8x4");
4934                writeln!(
4935                    self.out,
4936                    "metal::float4 {name}(metal::uchar b0, \
4937                                          metal::uchar b1, \
4938                                          metal::uchar b2, \
4939                                          metal::uchar b3) {{"
4940                )?;
4941                writeln!(
4942                    self.out,
4943                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4944                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4945                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4946                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4947                    back::INDENT
4948                )?;
4949                writeln!(self.out, "}}")?;
4950                Ok((name, 4, 4))
4951            }
4952            Uint16 => {
4953                let name = self.namer.call("unpackUint16");
4954                writeln!(
4955                    self.out,
4956                    "metal::uint {name}(metal::uint b0, \
4957                                        metal::uint b1) {{"
4958                )?;
4959                writeln!(
4960                    self.out,
4961                    "{}return metal::uint(b1 << 8 | b0);",
4962                    back::INDENT
4963                )?;
4964                writeln!(self.out, "}}")?;
4965                Ok((name, 2, 1))
4966            }
4967            Uint16x2 => {
4968                let name = self.namer.call("unpackUint16x2");
4969                writeln!(
4970                    self.out,
4971                    "metal::uint2 {name}(metal::uint b0, \
4972                                         metal::uint b1, \
4973                                         metal::uint b2, \
4974                                         metal::uint b3) {{"
4975                )?;
4976                writeln!(
4977                    self.out,
4978                    "{}return metal::uint2(b1 << 8 | b0, \
4979                                           b3 << 8 | b2);",
4980                    back::INDENT
4981                )?;
4982                writeln!(self.out, "}}")?;
4983                Ok((name, 4, 2))
4984            }
4985            Uint16x4 => {
4986                let name = self.namer.call("unpackUint16x4");
4987                writeln!(
4988                    self.out,
4989                    "metal::uint4 {name}(metal::uint b0, \
4990                                         metal::uint b1, \
4991                                         metal::uint b2, \
4992                                         metal::uint b3, \
4993                                         metal::uint b4, \
4994                                         metal::uint b5, \
4995                                         metal::uint b6, \
4996                                         metal::uint b7) {{"
4997                )?;
4998                writeln!(
4999                    self.out,
5000                    "{}return metal::uint4(b1 << 8 | b0, \
5001                                           b3 << 8 | b2, \
5002                                           b5 << 8 | b4, \
5003                                           b7 << 8 | b6);",
5004                    back::INDENT
5005                )?;
5006                writeln!(self.out, "}}")?;
5007                Ok((name, 8, 4))
5008            }
5009            Sint16 => {
5010                let name = self.namer.call("unpackSint16");
5011                writeln!(
5012                    self.out,
5013                    "int {name}(metal::ushort b0, \
5014                                metal::ushort b1) {{"
5015                )?;
5016                writeln!(
5017                    self.out,
5018                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5019                    back::INDENT
5020                )?;
5021                writeln!(self.out, "}}")?;
5022                Ok((name, 2, 1))
5023            }
5024            Sint16x2 => {
5025                let name = self.namer.call("unpackSint16x2");
5026                writeln!(
5027                    self.out,
5028                    "metal::int2 {name}(metal::ushort b0, \
5029                                        metal::ushort b1, \
5030                                        metal::ushort b2, \
5031                                        metal::ushort b3) {{"
5032                )?;
5033                writeln!(
5034                    self.out,
5035                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5036                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5037                    back::INDENT
5038                )?;
5039                writeln!(self.out, "}}")?;
5040                Ok((name, 4, 2))
5041            }
5042            Sint16x4 => {
5043                let name = self.namer.call("unpackSint16x4");
5044                writeln!(
5045                    self.out,
5046                    "metal::int4 {name}(metal::ushort b0, \
5047                                        metal::ushort b1, \
5048                                        metal::ushort b2, \
5049                                        metal::ushort b3, \
5050                                        metal::ushort b4, \
5051                                        metal::ushort b5, \
5052                                        metal::ushort b6, \
5053                                        metal::ushort b7) {{"
5054                )?;
5055                writeln!(
5056                    self.out,
5057                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5058                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5059                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5060                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5061                    back::INDENT
5062                )?;
5063                writeln!(self.out, "}}")?;
5064                Ok((name, 8, 4))
5065            }
5066            Unorm16 => {
5067                let name = self.namer.call("unpackUnorm16");
5068                writeln!(
5069                    self.out,
5070                    "float {name}(metal::ushort b0, \
5071                                  metal::ushort b1) {{"
5072                )?;
5073                writeln!(
5074                    self.out,
5075                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5076                    back::INDENT
5077                )?;
5078                writeln!(self.out, "}}")?;
5079                Ok((name, 2, 1))
5080            }
5081            Unorm16x2 => {
5082                let name = self.namer.call("unpackUnorm16x2");
5083                writeln!(
5084                    self.out,
5085                    "metal::float2 {name}(metal::ushort b0, \
5086                                          metal::ushort b1, \
5087                                          metal::ushort b2, \
5088                                          metal::ushort b3) {{"
5089                )?;
5090                writeln!(
5091                    self.out,
5092                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5093                                            float(b3 << 8 | b2) / 65535.0f);",
5094                    back::INDENT
5095                )?;
5096                writeln!(self.out, "}}")?;
5097                Ok((name, 4, 2))
5098            }
5099            Unorm16x4 => {
5100                let name = self.namer.call("unpackUnorm16x4");
5101                writeln!(
5102                    self.out,
5103                    "metal::float4 {name}(metal::ushort b0, \
5104                                          metal::ushort b1, \
5105                                          metal::ushort b2, \
5106                                          metal::ushort b3, \
5107                                          metal::ushort b4, \
5108                                          metal::ushort b5, \
5109                                          metal::ushort b6, \
5110                                          metal::ushort b7) {{"
5111                )?;
5112                writeln!(
5113                    self.out,
5114                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5115                                            float(b3 << 8 | b2) / 65535.0f, \
5116                                            float(b5 << 8 | b4) / 65535.0f, \
5117                                            float(b7 << 8 | b6) / 65535.0f);",
5118                    back::INDENT
5119                )?;
5120                writeln!(self.out, "}}")?;
5121                Ok((name, 8, 4))
5122            }
5123            Snorm16 => {
5124                let name = self.namer.call("unpackSnorm16");
5125                writeln!(
5126                    self.out,
5127                    "float {name}(metal::ushort b0, \
5128                                  metal::ushort b1) {{"
5129                )?;
5130                writeln!(
5131                    self.out,
5132                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5133                    back::INDENT
5134                )?;
5135                writeln!(self.out, "}}")?;
5136                Ok((name, 2, 1))
5137            }
5138            Snorm16x2 => {
5139                let name = self.namer.call("unpackSnorm16x2");
5140                writeln!(
5141                    self.out,
5142                    "metal::float2 {name}(uint b0, \
5143                                          uint b1, \
5144                                          uint b2, \
5145                                          uint b3) {{"
5146                )?;
5147                writeln!(
5148                    self.out,
5149                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5150                    back::INDENT
5151                )?;
5152                writeln!(self.out, "}}")?;
5153                Ok((name, 4, 2))
5154            }
5155            Snorm16x4 => {
5156                let name = self.namer.call("unpackSnorm16x4");
5157                writeln!(
5158                    self.out,
5159                    "metal::float4 {name}(uint b0, \
5160                                          uint b1, \
5161                                          uint b2, \
5162                                          uint b3, \
5163                                          uint b4, \
5164                                          uint b5, \
5165                                          uint b6, \
5166                                          uint b7) {{"
5167                )?;
5168                writeln!(
5169                    self.out,
5170                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5171                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5172                    back::INDENT
5173                )?;
5174                writeln!(self.out, "}}")?;
5175                Ok((name, 8, 4))
5176            }
5177            Float16 => {
5178                let name = self.namer.call("unpackFloat16");
5179                writeln!(
5180                    self.out,
5181                    "float {name}(metal::ushort b0, \
5182                                  metal::ushort b1) {{"
5183                )?;
5184                writeln!(
5185                    self.out,
5186                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5187                    back::INDENT
5188                )?;
5189                writeln!(self.out, "}}")?;
5190                Ok((name, 2, 1))
5191            }
5192            Float16x2 => {
5193                let name = self.namer.call("unpackFloat16x2");
5194                writeln!(
5195                    self.out,
5196                    "metal::float2 {name}(metal::ushort b0, \
5197                                          metal::ushort b1, \
5198                                          metal::ushort b2, \
5199                                          metal::ushort b3) {{"
5200                )?;
5201                writeln!(
5202                    self.out,
5203                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5204                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5205                    back::INDENT
5206                )?;
5207                writeln!(self.out, "}}")?;
5208                Ok((name, 4, 2))
5209            }
5210            Float16x4 => {
5211                let name = self.namer.call("unpackFloat16x4");
5212                writeln!(
5213                    self.out,
5214                    "metal::float4 {name}(metal::ushort b0, \
5215                                        metal::ushort b1, \
5216                                        metal::ushort b2, \
5217                                        metal::ushort b3, \
5218                                        metal::ushort b4, \
5219                                        metal::ushort b5, \
5220                                        metal::ushort b6, \
5221                                        metal::ushort b7) {{"
5222                )?;
5223                writeln!(
5224                    self.out,
5225                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5226                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5227                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5228                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5229                    back::INDENT
5230                )?;
5231                writeln!(self.out, "}}")?;
5232                Ok((name, 8, 4))
5233            }
5234            Float32 => {
5235                let name = self.namer.call("unpackFloat32");
5236                writeln!(
5237                    self.out,
5238                    "float {name}(uint b0, \
5239                                  uint b1, \
5240                                  uint b2, \
5241                                  uint b3) {{"
5242                )?;
5243                writeln!(
5244                    self.out,
5245                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5246                    back::INDENT
5247                )?;
5248                writeln!(self.out, "}}")?;
5249                Ok((name, 4, 1))
5250            }
5251            Float32x2 => {
5252                let name = self.namer.call("unpackFloat32x2");
5253                writeln!(
5254                    self.out,
5255                    "metal::float2 {name}(uint b0, \
5256                                          uint b1, \
5257                                          uint b2, \
5258                                          uint b3, \
5259                                          uint b4, \
5260                                          uint b5, \
5261                                          uint b6, \
5262                                          uint b7) {{"
5263                )?;
5264                writeln!(
5265                    self.out,
5266                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5267                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5268                    back::INDENT
5269                )?;
5270                writeln!(self.out, "}}")?;
5271                Ok((name, 8, 2))
5272            }
5273            Float32x3 => {
5274                let name = self.namer.call("unpackFloat32x3");
5275                writeln!(
5276                    self.out,
5277                    "metal::float3 {name}(uint b0, \
5278                                          uint b1, \
5279                                          uint b2, \
5280                                          uint b3, \
5281                                          uint b4, \
5282                                          uint b5, \
5283                                          uint b6, \
5284                                          uint b7, \
5285                                          uint b8, \
5286                                          uint b9, \
5287                                          uint b10, \
5288                                          uint b11) {{"
5289                )?;
5290                writeln!(
5291                    self.out,
5292                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5293                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5294                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5295                    back::INDENT
5296                )?;
5297                writeln!(self.out, "}}")?;
5298                Ok((name, 12, 3))
5299            }
5300            Float32x4 => {
5301                let name = self.namer.call("unpackFloat32x4");
5302                writeln!(
5303                    self.out,
5304                    "metal::float4 {name}(uint b0, \
5305                                          uint b1, \
5306                                          uint b2, \
5307                                          uint b3, \
5308                                          uint b4, \
5309                                          uint b5, \
5310                                          uint b6, \
5311                                          uint b7, \
5312                                          uint b8, \
5313                                          uint b9, \
5314                                          uint b10, \
5315                                          uint b11, \
5316                                          uint b12, \
5317                                          uint b13, \
5318                                          uint b14, \
5319                                          uint b15) {{"
5320                )?;
5321                writeln!(
5322                    self.out,
5323                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5324                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5325                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5326                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5327                    back::INDENT
5328                )?;
5329                writeln!(self.out, "}}")?;
5330                Ok((name, 16, 4))
5331            }
5332            Uint32 => {
5333                let name = self.namer.call("unpackUint32");
5334                writeln!(
5335                    self.out,
5336                    "uint {name}(uint b0, \
5337                                 uint b1, \
5338                                 uint b2, \
5339                                 uint b3) {{"
5340                )?;
5341                writeln!(
5342                    self.out,
5343                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5344                    back::INDENT
5345                )?;
5346                writeln!(self.out, "}}")?;
5347                Ok((name, 4, 1))
5348            }
5349            Uint32x2 => {
5350                let name = self.namer.call("unpackUint32x2");
5351                writeln!(
5352                    self.out,
5353                    "uint2 {name}(uint b0, \
5354                                  uint b1, \
5355                                  uint b2, \
5356                                  uint b3, \
5357                                  uint b4, \
5358                                  uint b5, \
5359                                  uint b6, \
5360                                  uint b7) {{"
5361                )?;
5362                writeln!(
5363                    self.out,
5364                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5365                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5366                    back::INDENT
5367                )?;
5368                writeln!(self.out, "}}")?;
5369                Ok((name, 8, 2))
5370            }
5371            Uint32x3 => {
5372                let name = self.namer.call("unpackUint32x3");
5373                writeln!(
5374                    self.out,
5375                    "uint3 {name}(uint b0, \
5376                                  uint b1, \
5377                                  uint b2, \
5378                                  uint b3, \
5379                                  uint b4, \
5380                                  uint b5, \
5381                                  uint b6, \
5382                                  uint b7, \
5383                                  uint b8, \
5384                                  uint b9, \
5385                                  uint b10, \
5386                                  uint b11) {{"
5387                )?;
5388                writeln!(
5389                    self.out,
5390                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5391                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5392                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5393                    back::INDENT
5394                )?;
5395                writeln!(self.out, "}}")?;
5396                Ok((name, 12, 3))
5397            }
5398            Uint32x4 => {
5399                let name = self.namer.call("unpackUint32x4");
5400                writeln!(
5401                    self.out,
5402                    "{NAMESPACE}::uint4 {name}(uint b0, \
5403                                  uint b1, \
5404                                  uint b2, \
5405                                  uint b3, \
5406                                  uint b4, \
5407                                  uint b5, \
5408                                  uint b6, \
5409                                  uint b7, \
5410                                  uint b8, \
5411                                  uint b9, \
5412                                  uint b10, \
5413                                  uint b11, \
5414                                  uint b12, \
5415                                  uint b13, \
5416                                  uint b14, \
5417                                  uint b15) {{"
5418                )?;
5419                writeln!(
5420                    self.out,
5421                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5422                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5423                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5424                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5425                    back::INDENT
5426                )?;
5427                writeln!(self.out, "}}")?;
5428                Ok((name, 16, 4))
5429            }
5430            Sint32 => {
5431                let name = self.namer.call("unpackSint32");
5432                writeln!(
5433                    self.out,
5434                    "int {name}(uint b0, \
5435                                uint b1, \
5436                                uint b2, \
5437                                uint b3) {{"
5438                )?;
5439                writeln!(
5440                    self.out,
5441                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5442                    back::INDENT
5443                )?;
5444                writeln!(self.out, "}}")?;
5445                Ok((name, 4, 1))
5446            }
5447            Sint32x2 => {
5448                let name = self.namer.call("unpackSint32x2");
5449                writeln!(
5450                    self.out,
5451                    "metal::int2 {name}(uint b0, \
5452                                        uint b1, \
5453                                        uint b2, \
5454                                        uint b3, \
5455                                        uint b4, \
5456                                        uint b5, \
5457                                        uint b6, \
5458                                        uint b7) {{"
5459                )?;
5460                writeln!(
5461                    self.out,
5462                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5463                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5464                    back::INDENT
5465                )?;
5466                writeln!(self.out, "}}")?;
5467                Ok((name, 8, 2))
5468            }
5469            Sint32x3 => {
5470                let name = self.namer.call("unpackSint32x3");
5471                writeln!(
5472                    self.out,
5473                    "metal::int3 {name}(uint b0, \
5474                                        uint b1, \
5475                                        uint b2, \
5476                                        uint b3, \
5477                                        uint b4, \
5478                                        uint b5, \
5479                                        uint b6, \
5480                                        uint b7, \
5481                                        uint b8, \
5482                                        uint b9, \
5483                                        uint b10, \
5484                                        uint b11) {{"
5485                )?;
5486                writeln!(
5487                    self.out,
5488                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5489                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5490                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5491                    back::INDENT
5492                )?;
5493                writeln!(self.out, "}}")?;
5494                Ok((name, 12, 3))
5495            }
5496            Sint32x4 => {
5497                let name = self.namer.call("unpackSint32x4");
5498                writeln!(
5499                    self.out,
5500                    "metal::int4 {name}(uint b0, \
5501                                        uint b1, \
5502                                        uint b2, \
5503                                        uint b3, \
5504                                        uint b4, \
5505                                        uint b5, \
5506                                        uint b6, \
5507                                        uint b7, \
5508                                        uint b8, \
5509                                        uint b9, \
5510                                        uint b10, \
5511                                        uint b11, \
5512                                        uint b12, \
5513                                        uint b13, \
5514                                        uint b14, \
5515                                        uint b15) {{"
5516                )?;
5517                writeln!(
5518                    self.out,
5519                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5520                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5521                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5522                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5523                    back::INDENT
5524                )?;
5525                writeln!(self.out, "}}")?;
5526                Ok((name, 16, 4))
5527            }
5528            Unorm10_10_10_2 => {
5529                let name = self.namer.call("unpackUnorm10_10_10_2");
5530                writeln!(
5531                    self.out,
5532                    "metal::float4 {name}(uint b0, \
5533                                          uint b1, \
5534                                          uint b2, \
5535                                          uint b3) {{"
5536                )?;
5537                writeln!(
5538                    self.out,
5539                    // The following is correct for RGBA packing, but our format seems to
5540                    // match ABGR, which can be fed into the Metal builtin function
5541                    // unpack_unorm10a2_to_float.
5542                    /*
5543                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5544                       uint r = (v & 0xFFC00000) >> 22; \
5545                       uint g = (v & 0x003FF000) >> 12; \
5546                       uint b = (v & 0x00000FFC) >> 2; \
5547                       uint a = (v & 0x00000003); \
5548                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5549                    */
5550                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5551                    back::INDENT
5552                )?;
5553                writeln!(self.out, "}}")?;
5554                Ok((name, 4, 4))
5555            }
5556            Unorm8x4Bgra => {
5557                let name = self.namer.call("unpackUnorm8x4Bgra");
5558                writeln!(
5559                    self.out,
5560                    "metal::float4 {name}(metal::uchar b0, \
5561                                          metal::uchar b1, \
5562                                          metal::uchar b2, \
5563                                          metal::uchar b3) {{"
5564                )?;
5565                writeln!(
5566                    self.out,
5567                    "{}return metal::float4(float(b2) / 255.0f, \
5568                                            float(b1) / 255.0f, \
5569                                            float(b0) / 255.0f, \
5570                                            float(b3) / 255.0f);",
5571                    back::INDENT
5572                )?;
5573                writeln!(self.out, "}}")?;
5574                Ok((name, 4, 4))
5575            }
5576        }
5577    }
5578
5579    fn write_wrapped_unary_op(
5580        &mut self,
5581        module: &crate::Module,
5582        func_ctx: &back::FunctionCtx,
5583        op: crate::UnaryOperator,
5584        operand: Handle<crate::Expression>,
5585    ) -> BackendResult {
5586        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5587        match op {
5588            // Negating the TYPE_MIN of a two's complement signed integer
5589            // type causes overflow, which is undefined behaviour in MSL. To
5590            // avoid this we bitcast the value to unsigned and negate it,
5591            // then bitcast back to signed.
5592            // This adheres to the WGSL spec in that the negative of the
5593            // type's minimum value should equal to the minimum value.
5594            crate::UnaryOperator::Negate
5595                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5596            {
5597                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5598                    return Ok(());
5599                };
5600                let wrapped = WrappedFunction::UnaryOp {
5601                    op,
5602                    ty: (vector_size, scalar),
5603                };
5604                if !self.wrapped_functions.insert(wrapped) {
5605                    return Ok(());
5606                }
5607
5608                let unsigned_scalar = crate::Scalar {
5609                    kind: crate::ScalarKind::Uint,
5610                    ..scalar
5611                };
5612                let mut type_name = String::new();
5613                let mut unsigned_type_name = String::new();
5614                match vector_size {
5615                    None => {
5616                        put_numeric_type(&mut type_name, scalar, &[])?;
5617                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5618                    }
5619                    Some(size) => {
5620                        put_numeric_type(&mut type_name, scalar, &[size])?;
5621                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5622                    }
5623                };
5624
5625                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5626                let level = back::Level(1);
5627                writeln!(
5628                    self.out,
5629                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5630                )?;
5631                writeln!(self.out, "}}")?;
5632                writeln!(self.out)?;
5633            }
5634            _ => {}
5635        }
5636        Ok(())
5637    }
5638
5639    fn write_wrapped_binary_op(
5640        &mut self,
5641        module: &crate::Module,
5642        func_ctx: &back::FunctionCtx,
5643        expr: Handle<crate::Expression>,
5644        op: crate::BinaryOperator,
5645        left: Handle<crate::Expression>,
5646        right: Handle<crate::Expression>,
5647    ) -> BackendResult {
5648        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5649        let left_ty = func_ctx.resolve_type(left, &module.types);
5650        let right_ty = func_ctx.resolve_type(right, &module.types);
5651        match (op, expr_ty.scalar_kind()) {
5652            // Signed integer division of TYPE_MIN / -1, or signed or
5653            // unsigned division by zero, gives an unspecified value in MSL.
5654            // We override the divisor to 1 in these cases.
5655            // This adheres to the WGSL spec in that:
5656            // * TYPE_MIN / -1 == TYPE_MIN
5657            // * x / 0 == x
5658            (
5659                crate::BinaryOperator::Divide,
5660                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5661            ) => {
5662                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5663                    return Ok(());
5664                };
5665                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5666                    return Ok(());
5667                };
5668                let wrapped = WrappedFunction::BinaryOp {
5669                    op,
5670                    left_ty: left_wrapped_ty,
5671                    right_ty: right_wrapped_ty,
5672                };
5673                if !self.wrapped_functions.insert(wrapped) {
5674                    return Ok(());
5675                }
5676
5677                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5678                    return Ok(());
5679                };
5680                let mut type_name = String::new();
5681                match vector_size {
5682                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5683                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5684                };
5685                writeln!(
5686                    self.out,
5687                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5688                )?;
5689                let level = back::Level(1);
5690                match scalar.kind {
5691                    crate::ScalarKind::Sint => {
5692                        let min_val = match scalar.width {
5693                            4 => crate::Literal::I32(i32::MIN),
5694                            8 => crate::Literal::I64(i64::MIN),
5695                            _ => {
5696                                return Err(Error::GenericValidation(format!(
5697                                    "Unexpected width for scalar {scalar:?}"
5698                                )));
5699                            }
5700                        };
5701                        write!(
5702                            self.out,
5703                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5704                        )?;
5705                        self.put_literal(min_val)?;
5706                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5707                    }
5708                    crate::ScalarKind::Uint => writeln!(
5709                        self.out,
5710                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5711                    )?,
5712                    _ => unreachable!(),
5713                }
5714                writeln!(self.out, "}}")?;
5715                writeln!(self.out)?;
5716            }
5717            // Integer modulo where one or both operands are negative, or the
5718            // divisor is zero, is undefined behaviour in MSL. To avoid this
5719            // we use the following equation:
5720            //
5721            // dividend - (dividend / divisor) * divisor
5722            //
5723            // overriding the divisor to 1 if either it is 0, or it is -1
5724            // and the dividend is TYPE_MIN.
5725            //
5726            // This adheres to the WGSL spec in that:
5727            // * TYPE_MIN % -1 == 0
5728            // * x % 0 == 0
5729            (
5730                crate::BinaryOperator::Modulo,
5731                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5732            ) => {
5733                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5734                    return Ok(());
5735                };
5736                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5737                else {
5738                    return Ok(());
5739                };
5740                let wrapped = WrappedFunction::BinaryOp {
5741                    op,
5742                    left_ty: left_wrapped_ty,
5743                    right_ty: (right_vector_size, right_scalar),
5744                };
5745                if !self.wrapped_functions.insert(wrapped) {
5746                    return Ok(());
5747                }
5748
5749                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5750                    return Ok(());
5751                };
5752                let mut type_name = String::new();
5753                match vector_size {
5754                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5755                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5756                };
5757                let mut rhs_type_name = String::new();
5758                match right_vector_size {
5759                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5760                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5761                };
5762
5763                writeln!(
5764                    self.out,
5765                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5766                )?;
5767                let level = back::Level(1);
5768                match scalar.kind {
5769                    crate::ScalarKind::Sint => {
5770                        let min_val = match scalar.width {
5771                            4 => crate::Literal::I32(i32::MIN),
5772                            8 => crate::Literal::I64(i64::MIN),
5773                            _ => {
5774                                return Err(Error::GenericValidation(format!(
5775                                    "Unexpected width for scalar {scalar:?}"
5776                                )));
5777                            }
5778                        };
5779                        write!(
5780                            self.out,
5781                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5782                        )?;
5783                        self.put_literal(min_val)?;
5784                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5785                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5786                    }
5787                    crate::ScalarKind::Uint => writeln!(
5788                        self.out,
5789                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5790                    )?,
5791                    _ => unreachable!(),
5792                }
5793                writeln!(self.out, "}}")?;
5794                writeln!(self.out)?;
5795            }
5796            _ => {}
5797        }
5798        Ok(())
5799    }
5800
5801    /// Build the mangled helper name for integer vector dot products.
5802    ///
5803    /// `scalar` must be a concrete integer scalar type.
5804    ///
5805    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5806    fn get_dot_wrapper_function_helper_name(
5807        &self,
5808        scalar: crate::Scalar,
5809        size: crate::VectorSize,
5810    ) -> String {
5811        // Check for consistency with [`super::keywords::RESERVED_SET`]
5812        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5813
5814        let type_name = scalar.to_msl_name();
5815        let size_suffix = common::vector_size_str(size);
5816        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5817    }
5818
5819    #[allow(clippy::too_many_arguments)]
5820    fn write_wrapped_math_function(
5821        &mut self,
5822        module: &crate::Module,
5823        func_ctx: &back::FunctionCtx,
5824        fun: crate::MathFunction,
5825        arg: Handle<crate::Expression>,
5826        _arg1: Option<Handle<crate::Expression>>,
5827        _arg2: Option<Handle<crate::Expression>>,
5828        _arg3: Option<Handle<crate::Expression>>,
5829    ) -> BackendResult {
5830        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5831        match fun {
5832            // Taking the absolute value of the TYPE_MIN of a two's
5833            // complement signed integer type causes overflow, which is
5834            // undefined behaviour in MSL. To avoid this, when the value is
5835            // negative we bitcast the value to unsigned and negate it, then
5836            // bitcast back to signed.
5837            // This adheres to the WGSL spec in that the absolute of the
5838            // type's minimum value should equal to the minimum value.
5839            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5840                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5841                    return Ok(());
5842                };
5843                let wrapped = WrappedFunction::Math {
5844                    fun,
5845                    arg_ty: (vector_size, scalar),
5846                };
5847                if !self.wrapped_functions.insert(wrapped) {
5848                    return Ok(());
5849                }
5850
5851                let unsigned_scalar = crate::Scalar {
5852                    kind: crate::ScalarKind::Uint,
5853                    ..scalar
5854                };
5855                let mut type_name = String::new();
5856                let mut unsigned_type_name = String::new();
5857                match vector_size {
5858                    None => {
5859                        put_numeric_type(&mut type_name, scalar, &[])?;
5860                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5861                    }
5862                    Some(size) => {
5863                        put_numeric_type(&mut type_name, scalar, &[size])?;
5864                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5865                    }
5866                };
5867
5868                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5869                let level = back::Level(1);
5870                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5871                writeln!(self.out, "}}")?;
5872                writeln!(self.out)?;
5873            }
5874
5875            crate::MathFunction::Dot => match *arg_ty {
5876                crate::TypeInner::Vector { size, scalar }
5877                    if matches!(
5878                        scalar.kind,
5879                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5880                    ) =>
5881                {
5882                    // De-duplicate per (fun, arg type) like other wrapped math functions
5883                    let wrapped = WrappedFunction::Math {
5884                        fun,
5885                        arg_ty: (Some(size), scalar),
5886                    };
5887                    if !self.wrapped_functions.insert(wrapped) {
5888                        return Ok(());
5889                    }
5890
5891                    let mut vec_ty = String::new();
5892                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
5893                    let mut ret_ty = String::new();
5894                    put_numeric_type(&mut ret_ty, scalar, &[])?;
5895
5896                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5897
5898                    // Emit function signature and body using put_dot_product for the expression
5899                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5900                    let level = back::Level(1);
5901                    write!(self.out, "{level}return ")?;
5902                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5903                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5904                        Ok(())
5905                    })?;
5906                    writeln!(self.out, ";")?;
5907                    writeln!(self.out, "}}")?;
5908                    writeln!(self.out)?;
5909                }
5910                _ => {}
5911            },
5912
5913            _ => {}
5914        }
5915        Ok(())
5916    }
5917
5918    fn write_wrapped_cast(
5919        &mut self,
5920        module: &crate::Module,
5921        func_ctx: &back::FunctionCtx,
5922        expr: Handle<crate::Expression>,
5923        kind: crate::ScalarKind,
5924        convert: Option<crate::Bytes>,
5925    ) -> BackendResult {
5926        // Avoid undefined behaviour when casting from a float to integer
5927        // when the value is out of range for the target type. Additionally
5928        // ensure we clamp to the correct value as per the WGSL spec.
5929        //
5930        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
5931        // * If X is exactly representable in the target type T, then the
5932        //   result is that value.
5933        // * Otherwise, the result is the value in T closest to
5934        //   truncate(X) and also exactly representable in the original
5935        //   floating point type.
5936        let src_ty = func_ctx.resolve_type(expr, &module.types);
5937        let Some(width) = convert else {
5938            return Ok(());
5939        };
5940        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5941            return Ok(());
5942        };
5943        let dst_scalar = crate::Scalar { kind, width };
5944        if src_scalar.kind != crate::ScalarKind::Float
5945            || (dst_scalar.kind != crate::ScalarKind::Sint
5946                && dst_scalar.kind != crate::ScalarKind::Uint)
5947        {
5948            return Ok(());
5949        }
5950        let wrapped = WrappedFunction::Cast {
5951            src_scalar,
5952            vector_size,
5953            dst_scalar,
5954        };
5955        if !self.wrapped_functions.insert(wrapped) {
5956            return Ok(());
5957        }
5958        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5959
5960        let mut src_type_name = String::new();
5961        match vector_size {
5962            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5963            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5964        };
5965        let mut dst_type_name = String::new();
5966        match vector_size {
5967            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5968            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5969        };
5970        let fun_name = match dst_scalar {
5971            crate::Scalar::I32 => F2I32_FUNCTION,
5972            crate::Scalar::U32 => F2U32_FUNCTION,
5973            crate::Scalar::I64 => F2I64_FUNCTION,
5974            crate::Scalar::U64 => F2U64_FUNCTION,
5975            _ => unreachable!(),
5976        };
5977
5978        writeln!(
5979            self.out,
5980            "{dst_type_name} {fun_name}({src_type_name} value) {{"
5981        )?;
5982        let level = back::Level(1);
5983        write!(
5984            self.out,
5985            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5986        )?;
5987        self.put_literal(min)?;
5988        write!(self.out, ", ")?;
5989        self.put_literal(max)?;
5990        writeln!(self.out, "));")?;
5991        writeln!(self.out, "}}")?;
5992        writeln!(self.out)?;
5993        Ok(())
5994    }
5995
5996    /// Helper function used by [`Self::write_wrapped_image_load`] and
5997    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
5998    /// conversion code for external textures. Expects the preceding code to
5999    /// declare the Y component as a `float` variable of name `y`, the UV
6000    /// components as a `float2` variable of name `uv`, and the external
6001    /// texture params as a variable of name `params`. The emitted code will
6002    /// return the result.
6003    fn write_convert_yuv_to_rgb_and_return(
6004        &mut self,
6005        level: back::Level,
6006        y: &str,
6007        uv: &str,
6008        params: &str,
6009    ) -> BackendResult {
6010        let l1 = level;
6011        let l2 = l1.next();
6012
6013        // Convert from YUV to non-linear RGB in the source color space.
6014        writeln!(
6015            self.out,
6016            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6017        )?;
6018
6019        // Apply the inverse of the source transfer function to convert to
6020        // linear RGB in the source color space.
6021        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6022        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6023        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6024        writeln!(
6025            self.out,
6026            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6027        )?;
6028
6029        // Multiply by the gamut conversion matrix to convert to linear RGB in
6030        // the destination color space.
6031        writeln!(
6032            self.out,
6033            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6034        )?;
6035
6036        // Finally, apply the dest transfer function to convert to non-linear
6037        // RGB in the destination color space, and return the result.
6038        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6039        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6040        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6041        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6042
6043        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6044        Ok(())
6045    }
6046
6047    #[allow(clippy::too_many_arguments)]
6048    fn write_wrapped_image_load(
6049        &mut self,
6050        module: &crate::Module,
6051        func_ctx: &back::FunctionCtx,
6052        image: Handle<crate::Expression>,
6053        _coordinate: Handle<crate::Expression>,
6054        _array_index: Option<Handle<crate::Expression>>,
6055        _sample: Option<Handle<crate::Expression>>,
6056        _level: Option<Handle<crate::Expression>>,
6057    ) -> BackendResult {
6058        // We currently only need to wrap image loads for external textures
6059        let class = match *func_ctx.resolve_type(image, &module.types) {
6060            crate::TypeInner::Image { class, .. } => class,
6061            _ => unreachable!(),
6062        };
6063        if class != crate::ImageClass::External {
6064            return Ok(());
6065        }
6066        let wrapped = WrappedFunction::ImageLoad { class };
6067        if !self.wrapped_functions.insert(wrapped) {
6068            return Ok(());
6069        }
6070
6071        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6072        let l1 = back::Level(1);
6073        let l2 = l1.next();
6074        let l3 = l2.next();
6075        writeln!(
6076            self.out,
6077            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6078        )?;
6079        // Clamp coords to provided size of external texture to prevent OOB
6080        // read. If params.size is zero then clamp to the actual size of the
6081        // texture.
6082        writeln!(
6083            self.out,
6084            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6085        )?;
6086        writeln!(
6087            self.out,
6088            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6089        )?;
6090
6091        // Apply load transformation
6092        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6093        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6094        // For single plane, simply read from plane0
6095        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6096        writeln!(self.out, "{l1}}} else {{")?;
6097
6098        // Chroma planes may be subsampled so we must scale the coords accordingly.
6099        writeln!(
6100            self.out,
6101            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6102        )?;
6103        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6104
6105        // For multi-plane, read the Y value from plane 0
6106        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6107
6108        writeln!(self.out, "{l2}float2 uv;")?;
6109        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6110        // For 2 planes, read UV from interleaved plane 1
6111        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6112        writeln!(self.out, "{l2}}} else {{")?;
6113        // For 3 planes, read U and V from planes 1 and 2 respectively
6114        writeln!(
6115            self.out,
6116            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6117        )?;
6118        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6119        writeln!(
6120            self.out,
6121            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6122        )?;
6123        writeln!(self.out, "{l2}}}")?;
6124
6125        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6126
6127        writeln!(self.out, "{l1}}}")?;
6128        writeln!(self.out, "}}")?;
6129        writeln!(self.out)?;
6130        Ok(())
6131    }
6132
6133    #[allow(clippy::too_many_arguments)]
6134    fn write_wrapped_image_sample(
6135        &mut self,
6136        module: &crate::Module,
6137        func_ctx: &back::FunctionCtx,
6138        image: Handle<crate::Expression>,
6139        _sampler: Handle<crate::Expression>,
6140        _gather: Option<crate::SwizzleComponent>,
6141        _coordinate: Handle<crate::Expression>,
6142        _array_index: Option<Handle<crate::Expression>>,
6143        _offset: Option<Handle<crate::Expression>>,
6144        _level: crate::SampleLevel,
6145        _depth_ref: Option<Handle<crate::Expression>>,
6146        clamp_to_edge: bool,
6147    ) -> BackendResult {
6148        // We currently only need to wrap textureSampleBaseClampToEdge, for
6149        // both sampled and external textures.
6150        if !clamp_to_edge {
6151            return Ok(());
6152        }
6153        let class = match *func_ctx.resolve_type(image, &module.types) {
6154            crate::TypeInner::Image { class, .. } => class,
6155            _ => unreachable!(),
6156        };
6157        let wrapped = WrappedFunction::ImageSample {
6158            class,
6159            clamp_to_edge: true,
6160        };
6161        if !self.wrapped_functions.insert(wrapped) {
6162            return Ok(());
6163        }
6164        match class {
6165            crate::ImageClass::External => {
6166                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6167                let l1 = back::Level(1);
6168                let l2 = l1.next();
6169                let l3 = l2.next();
6170                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6171                writeln!(
6172                    self.out,
6173                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6174                )?;
6175
6176                // Calculate the sample bounds. The purported size of the texture
6177                // (params.size) is irrelevant here as we are dealing with normalized
6178                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6179                // apply the sample transformation to that, also bearing in mind that it
6180                // may contain a flip on either axis. We calculate and adjust for the
6181                // half-texel separately for each plane as it depends on the actual
6182                // texture size which may vary between planes.
6183                writeln!(
6184                    self.out,
6185                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6186                )?;
6187                writeln!(
6188                    self.out,
6189                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6190                )?;
6191                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6192                writeln!(
6193                    self.out,
6194                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6195                )?;
6196                writeln!(
6197                    self.out,
6198                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6199                )?;
6200                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6201                // For single plane, simply sample from plane0
6202                writeln!(
6203                    self.out,
6204                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6205                )?;
6206                writeln!(self.out, "{l1}}} else {{")?;
6207                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6208                writeln!(
6209                    self.out,
6210                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6211                )?;
6212                writeln!(
6213                    self.out,
6214                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6215                )?;
6216
6217                // For multi-plane, sample the Y value from plane 0
6218                writeln!(
6219                    self.out,
6220                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6221                )?;
6222                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6223                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6224                // For 2 planes, sample UV from interleaved plane 1
6225                writeln!(
6226                    self.out,
6227                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6228                )?;
6229                writeln!(self.out, "{l2}}} else {{")?;
6230                // For 3 planes, sample U and V from planes 1 and 2 respectively
6231                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6232                writeln!(
6233                    self.out,
6234                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6235                )?;
6236                writeln!(
6237                    self.out,
6238                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6239                )?;
6240                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6241                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6242                writeln!(self.out, "{l2}}}")?;
6243
6244                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6245
6246                writeln!(self.out, "{l1}}}")?;
6247                writeln!(self.out, "}}")?;
6248                writeln!(self.out)?;
6249            }
6250            _ => {
6251                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6252                let l1 = back::Level(1);
6253                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6254                writeln!(
6255                    self.out,
6256                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6257                )?;
6258                writeln!(self.out, "}}")?;
6259                writeln!(self.out)?;
6260            }
6261        }
6262        Ok(())
6263    }
6264
6265    fn write_wrapped_image_query(
6266        &mut self,
6267        module: &crate::Module,
6268        func_ctx: &back::FunctionCtx,
6269        image: Handle<crate::Expression>,
6270        query: crate::ImageQuery,
6271    ) -> BackendResult {
6272        // We currently only need to wrap size image queries for external textures
6273        if !matches!(query, crate::ImageQuery::Size { .. }) {
6274            return Ok(());
6275        }
6276        let class = match *func_ctx.resolve_type(image, &module.types) {
6277            crate::TypeInner::Image { class, .. } => class,
6278            _ => unreachable!(),
6279        };
6280        if class != crate::ImageClass::External {
6281            return Ok(());
6282        }
6283        let wrapped = WrappedFunction::ImageQuerySize { class };
6284        if !self.wrapped_functions.insert(wrapped) {
6285            return Ok(());
6286        }
6287        writeln!(
6288            self.out,
6289            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6290        )?;
6291        let l1 = back::Level(1);
6292        let l2 = l1.next();
6293        writeln!(
6294            self.out,
6295            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6296        )?;
6297        writeln!(self.out, "{l2}return tex.params.size;")?;
6298        writeln!(self.out, "{l1}}} else {{")?;
6299        // params.size == (0, 0) indicates to query and return plane 0's actual size
6300        writeln!(
6301            self.out,
6302            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6303        )?;
6304        writeln!(self.out, "{l1}}}")?;
6305        writeln!(self.out, "}}")?;
6306        writeln!(self.out)?;
6307        Ok(())
6308    }
6309
6310    pub(super) fn write_wrapped_functions(
6311        &mut self,
6312        module: &crate::Module,
6313        func_ctx: &back::FunctionCtx,
6314    ) -> BackendResult {
6315        for (expr_handle, expr) in func_ctx.expressions.iter() {
6316            match *expr {
6317                crate::Expression::Unary { op, expr: operand } => {
6318                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6319                }
6320                crate::Expression::Binary { op, left, right } => {
6321                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6322                }
6323                crate::Expression::Math {
6324                    fun,
6325                    arg,
6326                    arg1,
6327                    arg2,
6328                    arg3,
6329                } => {
6330                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6331                }
6332                crate::Expression::As {
6333                    expr,
6334                    kind,
6335                    convert,
6336                } => {
6337                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6338                }
6339                crate::Expression::ImageLoad {
6340                    image,
6341                    coordinate,
6342                    array_index,
6343                    sample,
6344                    level,
6345                } => {
6346                    self.write_wrapped_image_load(
6347                        module,
6348                        func_ctx,
6349                        image,
6350                        coordinate,
6351                        array_index,
6352                        sample,
6353                        level,
6354                    )?;
6355                }
6356                crate::Expression::ImageSample {
6357                    image,
6358                    sampler,
6359                    gather,
6360                    coordinate,
6361                    array_index,
6362                    offset,
6363                    level,
6364                    depth_ref,
6365                    clamp_to_edge,
6366                } => {
6367                    self.write_wrapped_image_sample(
6368                        module,
6369                        func_ctx,
6370                        image,
6371                        sampler,
6372                        gather,
6373                        coordinate,
6374                        array_index,
6375                        offset,
6376                        level,
6377                        depth_ref,
6378                        clamp_to_edge,
6379                    )?;
6380                }
6381                crate::Expression::ImageQuery { image, query } => {
6382                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6383                }
6384                _ => {}
6385            }
6386        }
6387
6388        Ok(())
6389    }
6390
6391    // Returns the array of mapped entry point names.
6392    fn write_functions(
6393        &mut self,
6394        module: &crate::Module,
6395        mod_info: &valid::ModuleInfo,
6396        options: &Options,
6397        pipeline_options: &PipelineOptions,
6398    ) -> Result<TranslationInfo, Error> {
6399        use back::msl::VertexFormat;
6400
6401        // Define structs to hold resolved/generated data for vertex buffers and
6402        // their attributes.
6403        struct AttributeMappingResolved {
6404            ty_name: String,
6405            dimension: u32,
6406            ty_is_int: bool,
6407            name: String,
6408        }
6409        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6410
6411        struct VertexBufferMappingResolved<'a> {
6412            id: u32,
6413            stride: u32,
6414            step_mode: back::msl::VertexBufferStepMode,
6415            ty_name: String,
6416            param_name: String,
6417            elem_name: String,
6418            attributes: &'a Vec<back::msl::AttributeMapping>,
6419        }
6420        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6421
6422        // Define a struct to hold a named reference to a byte-unpacking function.
6423        struct UnpackingFunction {
6424            name: String,
6425            byte_count: u32,
6426            dimension: u32,
6427        }
6428        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6429
6430        // Check if we are attempting vertex pulling. If we are, generate some
6431        // names we'll need, and iterate the vertex buffer mappings to output
6432        // all the conversion functions we'll need to unpack the attribute data.
6433        // We can re-use these names for all entry points that need them, since
6434        // those entry points also use self.namer.
6435        let mut needs_vertex_id = false;
6436        let v_id = self.namer.call("v_id");
6437
6438        let mut needs_instance_id = false;
6439        let i_id = self.namer.call("i_id");
6440        if pipeline_options.vertex_pulling_transform {
6441            for vbm in &pipeline_options.vertex_buffer_mappings {
6442                let buffer_id = vbm.id;
6443                let buffer_stride = vbm.stride;
6444
6445                assert!(
6446                    buffer_stride > 0,
6447                    "Vertex pulling requires a non-zero buffer stride."
6448                );
6449
6450                match vbm.step_mode {
6451                    back::msl::VertexBufferStepMode::Constant => {}
6452                    back::msl::VertexBufferStepMode::ByVertex => {
6453                        needs_vertex_id = true;
6454                    }
6455                    back::msl::VertexBufferStepMode::ByInstance => {
6456                        needs_instance_id = true;
6457                    }
6458                }
6459
6460                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6461                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6462                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6463
6464                vbm_resolved.push(VertexBufferMappingResolved {
6465                    id: buffer_id,
6466                    stride: buffer_stride,
6467                    step_mode: vbm.step_mode,
6468                    ty_name: buffer_ty,
6469                    param_name: buffer_param,
6470                    elem_name: buffer_elem,
6471                    attributes: &vbm.attributes,
6472                });
6473
6474                // Iterate the attributes and generate needed unpacking functions.
6475                for attribute in &vbm.attributes {
6476                    if unpacking_functions.contains_key(&attribute.format) {
6477                        continue;
6478                    }
6479                    let (name, byte_count, dimension) =
6480                        match self.write_unpacking_function(attribute.format) {
6481                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6482                            _ => {
6483                                continue;
6484                            }
6485                        };
6486                    unpacking_functions.insert(
6487                        attribute.format,
6488                        UnpackingFunction {
6489                            name,
6490                            byte_count,
6491                            dimension,
6492                        },
6493                    );
6494                }
6495            }
6496        }
6497
6498        let mut pass_through_globals = Vec::new();
6499        for (fun_handle, fun) in module.functions.iter() {
6500            log::trace!(
6501                "function {:?}, handle {:?}",
6502                fun.name.as_deref().unwrap_or("(anonymous)"),
6503                fun_handle
6504            );
6505
6506            let ctx = back::FunctionCtx {
6507                ty: back::FunctionType::Function(fun_handle),
6508                info: &mod_info[fun_handle],
6509                expressions: &fun.expressions,
6510                named_expressions: &fun.named_expressions,
6511            };
6512
6513            writeln!(self.out)?;
6514            self.write_wrapped_functions(module, &ctx)?;
6515
6516            let fun_info = &mod_info[fun_handle];
6517            pass_through_globals.clear();
6518            let mut needs_buffer_sizes = false;
6519            for (handle, var) in module.global_variables.iter() {
6520                if !fun_info[handle].is_empty() {
6521                    if var.space.needs_pass_through() {
6522                        pass_through_globals.push(handle);
6523                    }
6524                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6525                }
6526            }
6527
6528            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6529            match fun.result {
6530                Some(ref result) => {
6531                    let ty_name = TypeContext {
6532                        handle: result.ty,
6533                        gctx: module.to_ctx(),
6534                        names: &self.names,
6535                        access: crate::StorageAccess::empty(),
6536                        first_time: false,
6537                    };
6538                    write!(self.out, "{ty_name}")?;
6539                }
6540                None => {
6541                    write!(self.out, "void")?;
6542                }
6543            }
6544            writeln!(self.out, " {fun_name}(")?;
6545
6546            for (index, arg) in fun.arguments.iter().enumerate() {
6547                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6548                let param_type_name = TypeContext {
6549                    handle: arg.ty,
6550                    gctx: module.to_ctx(),
6551                    names: &self.names,
6552                    access: crate::StorageAccess::empty(),
6553                    first_time: false,
6554                };
6555                let separator = separate(
6556                    !pass_through_globals.is_empty()
6557                        || index + 1 != fun.arguments.len()
6558                        || needs_buffer_sizes,
6559                );
6560                writeln!(
6561                    self.out,
6562                    "{}{} {}{}",
6563                    back::INDENT,
6564                    param_type_name,
6565                    name,
6566                    separator
6567                )?;
6568            }
6569            for (index, &handle) in pass_through_globals.iter().enumerate() {
6570                let tyvar = TypedGlobalVariable {
6571                    module,
6572                    names: &self.names,
6573                    handle,
6574                    usage: fun_info[handle],
6575
6576                    reference: true,
6577                };
6578                let separator =
6579                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6580                write!(self.out, "{}", back::INDENT)?;
6581                tyvar.try_fmt(&mut self.out)?;
6582                writeln!(self.out, "{separator}")?;
6583            }
6584
6585            if needs_buffer_sizes {
6586                writeln!(
6587                    self.out,
6588                    "{}constant _mslBufferSizes& _buffer_sizes",
6589                    back::INDENT
6590                )?;
6591            }
6592
6593            writeln!(self.out, ") {{")?;
6594
6595            let guarded_indices =
6596                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6597
6598            let context = StatementContext {
6599                expression: ExpressionContext {
6600                    function: fun,
6601                    origin: FunctionOrigin::Handle(fun_handle),
6602                    info: fun_info,
6603                    lang_version: options.lang_version,
6604                    policies: options.bounds_check_policies,
6605                    guarded_indices,
6606                    module,
6607                    mod_info,
6608                    pipeline_options,
6609                    force_loop_bounding: options.force_loop_bounding,
6610                },
6611                result_struct: None,
6612            };
6613
6614            self.put_locals(&context.expression)?;
6615            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6616            self.put_block(back::Level(1), &fun.body, &context)?;
6617            writeln!(self.out, "}}")?;
6618            self.named_expressions.clear();
6619        }
6620
6621        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6622            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6623
6624        let mut info = TranslationInfo {
6625            entry_point_names: Vec::with_capacity(ep_range.len()),
6626        };
6627
6628        for ep_index in ep_range {
6629            let ep = &module.entry_points[ep_index];
6630            let fun = &ep.function;
6631            let fun_info = mod_info.get_entry_point(ep_index);
6632            let mut ep_error = None;
6633
6634            // For vertex_id and instance_id arguments, presume that we'll
6635            // use our generated names, but switch to the name of an
6636            // existing @builtin param, if we find one.
6637            let mut v_existing_id = None;
6638            let mut i_existing_id = None;
6639
6640            log::trace!(
6641                "entry point {:?}, index {:?}",
6642                fun.name.as_deref().unwrap_or("(anonymous)"),
6643                ep_index
6644            );
6645
6646            let ctx = back::FunctionCtx {
6647                ty: back::FunctionType::EntryPoint(ep_index as u16),
6648                info: fun_info,
6649                expressions: &fun.expressions,
6650                named_expressions: &fun.named_expressions,
6651            };
6652
6653            self.write_wrapped_functions(module, &ctx)?;
6654
6655            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6656                crate::ShaderStage::Vertex => (
6657                    "vertex",
6658                    LocationMode::VertexInput,
6659                    LocationMode::VertexOutput,
6660                    true,
6661                ),
6662                crate::ShaderStage::Fragment => (
6663                    "fragment",
6664                    LocationMode::FragmentInput,
6665                    LocationMode::FragmentOutput,
6666                    false,
6667                ),
6668                crate::ShaderStage::Compute => (
6669                    "kernel",
6670                    LocationMode::Uniform,
6671                    LocationMode::Uniform,
6672                    false,
6673                ),
6674                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6675            };
6676
6677            // Should this entry point be modified to do vertex pulling?
6678            let do_vertex_pulling = can_vertex_pull
6679                && pipeline_options.vertex_pulling_transform
6680                && !pipeline_options.vertex_buffer_mappings.is_empty();
6681
6682            // Is any global variable used by this entry point dynamically sized?
6683            let needs_buffer_sizes = do_vertex_pulling
6684                || module
6685                    .global_variables
6686                    .iter()
6687                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6688                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6689
6690            // skip this entry point if any global bindings are missing,
6691            // or their types are incompatible.
6692            if !options.fake_missing_bindings {
6693                for (var_handle, var) in module.global_variables.iter() {
6694                    if fun_info[var_handle].is_empty() {
6695                        continue;
6696                    }
6697                    match var.space {
6698                        crate::AddressSpace::Uniform
6699                        | crate::AddressSpace::Storage { .. }
6700                        | crate::AddressSpace::Handle => {
6701                            let br = match var.binding {
6702                                Some(ref br) => br,
6703                                None => {
6704                                    let var_name = var.name.clone().unwrap_or_default();
6705                                    ep_error =
6706                                        Some(super::EntryPointError::MissingBinding(var_name));
6707                                    break;
6708                                }
6709                            };
6710                            let target = options.get_resource_binding_target(ep, br);
6711                            let good = match target {
6712                                Some(target) => {
6713                                    // We intentionally don't dereference binding_arrays here,
6714                                    // so that binding arrays fall to the buffer location.
6715
6716                                    match module.types[var.ty].inner {
6717                                        crate::TypeInner::Image {
6718                                            class: crate::ImageClass::External,
6719                                            ..
6720                                        } => target.external_texture.is_some(),
6721                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6722                                        crate::TypeInner::Sampler { .. } => {
6723                                            target.sampler.is_some()
6724                                        }
6725                                        _ => target.buffer.is_some(),
6726                                    }
6727                                }
6728                                None => false,
6729                            };
6730                            if !good {
6731                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6732                                break;
6733                            }
6734                        }
6735                        crate::AddressSpace::PushConstant => {
6736                            if let Err(e) = options.resolve_push_constants(ep) {
6737                                ep_error = Some(e);
6738                                break;
6739                            }
6740                        }
6741                        crate::AddressSpace::TaskPayload => {
6742                            unimplemented!()
6743                        }
6744                        crate::AddressSpace::Function
6745                        | crate::AddressSpace::Private
6746                        | crate::AddressSpace::WorkGroup => {}
6747                    }
6748                }
6749                if needs_buffer_sizes {
6750                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6751                        ep_error = Some(err);
6752                    }
6753                }
6754            }
6755
6756            if let Some(err) = ep_error {
6757                info.entry_point_names.push(Err(err));
6758                continue;
6759            }
6760            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6761            info.entry_point_names.push(Ok(fun_name.clone()));
6762
6763            writeln!(self.out)?;
6764
6765            // Since `Namer.reset` wasn't expecting struct members to be
6766            // suddenly injected into another namespace like this,
6767            // `self.names` doesn't keep them distinct from other variables.
6768            // Generate fresh names for these arguments, and remember the
6769            // mapping.
6770            let mut flattened_member_names = FastHashMap::default();
6771            // Varyings' members get their own namespace
6772            let mut varyings_namer = proc::Namer::default();
6773
6774            // List all the Naga `EntryPoint`'s `Function`'s arguments,
6775            // flattening structs into their members. In Metal, we will pass
6776            // each of these values to the entry point as a separate argument—
6777            // except for the varyings, handled next.
6778            let mut flattened_arguments = Vec::new();
6779            for (arg_index, arg) in fun.arguments.iter().enumerate() {
6780                match module.types[arg.ty].inner {
6781                    crate::TypeInner::Struct { ref members, .. } => {
6782                        for (member_index, member) in members.iter().enumerate() {
6783                            let member_index = member_index as u32;
6784                            flattened_arguments.push((
6785                                NameKey::StructMember(arg.ty, member_index),
6786                                member.ty,
6787                                member.binding.as_ref(),
6788                            ));
6789                            let name_key = NameKey::StructMember(arg.ty, member_index);
6790                            let name = match member.binding {
6791                                Some(crate::Binding::Location { .. }) => {
6792                                    if do_vertex_pulling {
6793                                        self.namer.call(&self.names[&name_key])
6794                                    } else {
6795                                        varyings_namer.call(&self.names[&name_key])
6796                                    }
6797                                }
6798                                _ => self.namer.call(&self.names[&name_key]),
6799                            };
6800                            flattened_member_names.insert(name_key, name);
6801                        }
6802                    }
6803                    _ => flattened_arguments.push((
6804                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
6805                        arg.ty,
6806                        arg.binding.as_ref(),
6807                    )),
6808                }
6809            }
6810
6811            // Identify the varyings among the argument values, and maybe emit
6812            // a struct type named `<fun>Input` to hold them. If we are doing
6813            // vertex pulling, we instead update our attribute mapping to
6814            // note the types, names, and zero values of the attributes.
6815            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
6816            let varyings_member_name = self.namer.call("varyings");
6817            let mut has_varyings = false;
6818            if !flattened_arguments.is_empty() {
6819                if !do_vertex_pulling {
6820                    writeln!(self.out, "struct {stage_in_name} {{")?;
6821                }
6822                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6823                    let (binding, location) = match binding {
6824                        Some(ref binding @ &crate::Binding::Location { location, .. }) => {
6825                            (binding, location)
6826                        }
6827                        _ => continue,
6828                    };
6829                    let name = match *name_key {
6830                        NameKey::StructMember(..) => &flattened_member_names[name_key],
6831                        _ => &self.names[name_key],
6832                    };
6833                    let ty_name = TypeContext {
6834                        handle: ty,
6835                        gctx: module.to_ctx(),
6836                        names: &self.names,
6837                        access: crate::StorageAccess::empty(),
6838                        first_time: false,
6839                    };
6840                    let resolved = options.resolve_local_binding(binding, in_mode)?;
6841                    if do_vertex_pulling {
6842                        // Update our attribute mapping.
6843                        am_resolved.insert(
6844                            location,
6845                            AttributeMappingResolved {
6846                                ty_name: ty_name.to_string(),
6847                                dimension: ty_name.vertex_input_dimension(),
6848                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
6849                                name: name.to_string(),
6850                            },
6851                        );
6852                    } else {
6853                        has_varyings = true;
6854                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6855                        resolved.try_fmt(&mut self.out)?;
6856                        writeln!(self.out, ";")?;
6857                    }
6858                }
6859                if !do_vertex_pulling {
6860                    writeln!(self.out, "}};")?;
6861                }
6862            }
6863
6864            // Define a struct type named for the return value, if any, named
6865            // `<fun>Output`.
6866            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
6867            let result_member_name = self.namer.call("member");
6868            let result_type_name = match fun.result {
6869                Some(ref result) => {
6870                    let mut result_members = Vec::new();
6871                    if let crate::TypeInner::Struct { ref members, .. } =
6872                        module.types[result.ty].inner
6873                    {
6874                        for (member_index, member) in members.iter().enumerate() {
6875                            result_members.push((
6876                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
6877                                member.ty,
6878                                member.binding.as_ref(),
6879                            ));
6880                        }
6881                    } else {
6882                        result_members.push((
6883                            &result_member_name,
6884                            result.ty,
6885                            result.binding.as_ref(),
6886                        ));
6887                    }
6888
6889                    writeln!(self.out, "struct {stage_out_name} {{")?;
6890                    let mut has_point_size = false;
6891                    for (name, ty, binding) in result_members {
6892                        let ty_name = TypeContext {
6893                            handle: ty,
6894                            gctx: module.to_ctx(),
6895                            names: &self.names,
6896                            access: crate::StorageAccess::empty(),
6897                            first_time: true,
6898                        };
6899                        let binding = binding.ok_or_else(|| {
6900                            Error::GenericValidation("Expected binding, got None".into())
6901                        })?;
6902
6903                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
6904                            has_point_size = true;
6905                            if !pipeline_options.allow_and_force_point_size {
6906                                continue;
6907                            }
6908                        }
6909
6910                        let array_len = match module.types[ty].inner {
6911                            crate::TypeInner::Array {
6912                                size: crate::ArraySize::Constant(size),
6913                                ..
6914                            } => Some(size),
6915                            _ => None,
6916                        };
6917                        let resolved = options.resolve_local_binding(binding, out_mode)?;
6918                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6919                        if let Some(array_len) = array_len {
6920                            write!(self.out, " [{array_len}]")?;
6921                        }
6922                        resolved.try_fmt(&mut self.out)?;
6923                        writeln!(self.out, ";")?;
6924                    }
6925
6926                    if pipeline_options.allow_and_force_point_size
6927                        && ep.stage == crate::ShaderStage::Vertex
6928                        && !has_point_size
6929                    {
6930                        // inject the point size output last
6931                        writeln!(
6932                            self.out,
6933                            "{}float _point_size [[point_size]];",
6934                            back::INDENT
6935                        )?;
6936                    }
6937                    writeln!(self.out, "}};")?;
6938                    &stage_out_name
6939                }
6940                None => "void",
6941            };
6942
6943            // If we're doing a vertex pulling transform, define the buffer
6944            // structure types.
6945            if do_vertex_pulling {
6946                for vbm in &vbm_resolved {
6947                    let buffer_stride = vbm.stride;
6948                    let buffer_ty = &vbm.ty_name;
6949
6950                    // Define a structure of bytes of the appropriate size.
6951                    // When we access the attributes, we'll be unpacking these
6952                    // bytes at some offset.
6953                    writeln!(
6954                        self.out,
6955                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
6956                    )?;
6957                }
6958            }
6959
6960            // Write the entry point function's name, and begin its argument list.
6961            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
6962
6963            let mut is_first_argument = true;
6964            let mut separator = || {
6965                if is_first_argument {
6966                    is_first_argument = false;
6967                    ' '
6968                } else {
6969                    ','
6970                }
6971            };
6972
6973            // If we have produced a struct holding the `EntryPoint`'s
6974            // `Function`'s arguments' varyings, pass that struct first.
6975            if has_varyings {
6976                writeln!(
6977                    self.out,
6978                    "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
6979                    separator()
6980                )?;
6981            }
6982
6983            let mut local_invocation_id = None;
6984
6985            // Then pass the remaining arguments not included in the varyings
6986            // struct.
6987            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6988                let binding = match binding {
6989                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
6990                    _ => continue,
6991                };
6992                let name = match *name_key {
6993                    NameKey::StructMember(..) => &flattened_member_names[name_key],
6994                    _ => &self.names[name_key],
6995                };
6996
6997                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
6998                    local_invocation_id = Some(name_key);
6999                }
7000
7001                let ty_name = TypeContext {
7002                    handle: ty,
7003                    gctx: module.to_ctx(),
7004                    names: &self.names,
7005                    access: crate::StorageAccess::empty(),
7006                    first_time: false,
7007                };
7008
7009                match *binding {
7010                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7011                        v_existing_id = Some(name.clone());
7012                    }
7013                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7014                        i_existing_id = Some(name.clone());
7015                    }
7016                    _ => {}
7017                };
7018
7019                let resolved = options.resolve_local_binding(binding, in_mode)?;
7020                write!(self.out, "{} {ty_name} {name}", separator())?;
7021                resolved.try_fmt(&mut self.out)?;
7022                writeln!(self.out)?;
7023            }
7024
7025            let need_workgroup_variables_initialization =
7026                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7027
7028            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7029                writeln!(
7030                    self.out,
7031                    "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7032                )?;
7033            }
7034
7035            // Those global variables used by this entry point and its callees
7036            // get passed as arguments. `Private` globals are an exception, they
7037            // don't outlive this invocation, so we declare them below as locals
7038            // within the entry point.
7039            for (handle, var) in module.global_variables.iter() {
7040                let usage = fun_info[handle];
7041                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7042                    continue;
7043                }
7044
7045                if options.lang_version < (1, 2) {
7046                    match var.space {
7047                        // This restriction is not documented in the MSL spec
7048                        // but validation will fail if it is not upheld.
7049                        //
7050                        // We infer the required version from the "Function
7051                        // Buffer Read-Writes" section of [what's new], where
7052                        // the feature sets listed correspond with the ones
7053                        // supporting MSL 1.2.
7054                        //
7055                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7056                        crate::AddressSpace::Storage { access }
7057                            if access.contains(crate::StorageAccess::STORE)
7058                                && ep.stage == crate::ShaderStage::Fragment =>
7059                        {
7060                            return Err(Error::UnsupportedWriteableStorageBuffer)
7061                        }
7062                        crate::AddressSpace::Handle => {
7063                            match module.types[var.ty].inner {
7064                                crate::TypeInner::Image {
7065                                    class: crate::ImageClass::Storage { access, .. },
7066                                    ..
7067                                } => {
7068                                    // This restriction is not documented in the MSL spec
7069                                    // but validation will fail if it is not upheld.
7070                                    //
7071                                    // We infer the required version from the "Function
7072                                    // Texture Read-Writes" section of [what's new], where
7073                                    // the feature sets listed correspond with the ones
7074                                    // supporting MSL 1.2.
7075                                    //
7076                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7077                                    if access.contains(crate::StorageAccess::STORE)
7078                                        && (ep.stage == crate::ShaderStage::Vertex
7079                                            || ep.stage == crate::ShaderStage::Fragment)
7080                                    {
7081                                        return Err(Error::UnsupportedWriteableStorageTexture(
7082                                            ep.stage,
7083                                        ));
7084                                    }
7085
7086                                    if access.contains(
7087                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7088                                    ) {
7089                                        return Err(Error::UnsupportedRWStorageTexture);
7090                                    }
7091                                }
7092                                _ => {}
7093                            }
7094                        }
7095                        _ => {}
7096                    }
7097                }
7098
7099                // Check min MSL version for binding arrays
7100                match var.space {
7101                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7102                        crate::TypeInner::BindingArray { base, .. } => {
7103                            match module.types[base].inner {
7104                                crate::TypeInner::Sampler { .. } => {
7105                                    if options.lang_version < (2, 0) {
7106                                        return Err(Error::UnsupportedArrayOf(
7107                                            "samplers".to_string(),
7108                                        ));
7109                                    }
7110                                }
7111                                crate::TypeInner::Image { class, .. } => match class {
7112                                    crate::ImageClass::Sampled { .. }
7113                                    | crate::ImageClass::Depth { .. }
7114                                    | crate::ImageClass::Storage {
7115                                        access: crate::StorageAccess::LOAD,
7116                                        ..
7117                                    } => {
7118                                        // Array of textures since:
7119                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7120                                        // - macOS: Metal 2
7121
7122                                        if options.lang_version < (2, 0) {
7123                                            return Err(Error::UnsupportedArrayOf(
7124                                                "textures".to_string(),
7125                                            ));
7126                                        }
7127                                    }
7128                                    crate::ImageClass::Storage {
7129                                        access: crate::StorageAccess::STORE,
7130                                        ..
7131                                    } => {
7132                                        // Array of write-only textures since:
7133                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7134                                        // - macOS: Metal 2
7135
7136                                        if options.lang_version < (2, 0) {
7137                                            return Err(Error::UnsupportedArrayOf(
7138                                                "write-only textures".to_string(),
7139                                            ));
7140                                        }
7141                                    }
7142                                    crate::ImageClass::Storage { .. } => {
7143                                        return Err(Error::UnsupportedArrayOf(
7144                                            "read-write textures".to_string(),
7145                                        ));
7146                                    }
7147                                    crate::ImageClass::External => {
7148                                        return Err(Error::UnsupportedArrayOf(
7149                                            "external textures".to_string(),
7150                                        ));
7151                                    }
7152                                },
7153                                _ => {
7154                                    return Err(Error::UnsupportedArrayOfType(base));
7155                                }
7156                            }
7157                        }
7158                        _ => {}
7159                    },
7160                    _ => {}
7161                }
7162
7163                // the resolves have already been checked for `!fake_missing_bindings` case
7164                let resolved = match var.space {
7165                    crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
7166                    crate::AddressSpace::WorkGroup => None,
7167                    _ => options
7168                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7169                        .ok(),
7170                };
7171                if let Some(ref resolved) = resolved {
7172                    // Inline samplers are defined in the EP body
7173                    if resolved.as_inline_sampler(options).is_some() {
7174                        continue;
7175                    }
7176                }
7177
7178                match module.types[var.ty].inner {
7179                    crate::TypeInner::Image {
7180                        class: crate::ImageClass::External,
7181                        ..
7182                    } => {
7183                        // External texture global variables get lowered to 3 textures
7184                        // and a constant buffer. We must emit a separate argument for
7185                        // each of these.
7186                        let target = match resolved {
7187                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7188                                target.external_texture
7189                            }
7190                            _ => None,
7191                        };
7192
7193                        for i in 0..3 {
7194                            write!(self.out, "{} ", separator())?;
7195
7196                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7197                                handle,
7198                                ExternalTextureNameKey::Plane(i),
7199                            )];
7200                            write!(
7201                              self.out,
7202                              "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7203                            )?;
7204                            if let Some(ref target) = target {
7205                                write!(self.out, " [[texture({})]]", target.planes[i])?;
7206                            }
7207                            writeln!(self.out)?;
7208                        }
7209                        let params_ty_name = &self.names
7210                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7211                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7212                            handle,
7213                            ExternalTextureNameKey::Params,
7214                        )];
7215                        write!(self.out, "{} ", separator())?;
7216                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
7217                        if let Some(ref target) = target {
7218                            write!(self.out, " [[buffer({})]]", target.params)?;
7219                        }
7220                    }
7221                    _ => {
7222                        let tyvar = TypedGlobalVariable {
7223                            module,
7224                            names: &self.names,
7225                            handle,
7226                            usage,
7227                            reference: true,
7228                        };
7229                        write!(self.out, "{} ", separator())?;
7230                        tyvar.try_fmt(&mut self.out)?;
7231                        if let Some(resolved) = resolved {
7232                            resolved.try_fmt(&mut self.out)?;
7233                        }
7234                        if let Some(value) = var.init {
7235                            write!(self.out, " = ")?;
7236                            self.put_const_expression(
7237                                value,
7238                                module,
7239                                mod_info,
7240                                &module.global_expressions,
7241                            )?;
7242                        }
7243                    }
7244                }
7245                writeln!(self.out)?;
7246            }
7247
7248            if do_vertex_pulling {
7249                if needs_vertex_id && v_existing_id.is_none() {
7250                    // Write the [[vertex_id]] argument.
7251                    writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7252                }
7253
7254                if needs_instance_id && i_existing_id.is_none() {
7255                    writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7256                }
7257
7258                // Iterate vbm_resolved, output one argument for every vertex buffer,
7259                // using the names we generated earlier.
7260                for vbm in &vbm_resolved {
7261                    let id = &vbm.id;
7262                    let ty_name = &vbm.ty_name;
7263                    let param_name = &vbm.param_name;
7264                    writeln!(
7265                        self.out,
7266                        "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7267                        separator()
7268                    )?;
7269                }
7270            }
7271
7272            // If this entry uses any variable-length arrays, their sizes are
7273            // passed as a final struct-typed argument.
7274            if needs_buffer_sizes {
7275                // this is checked earlier
7276                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7277                write!(
7278                    self.out,
7279                    "{} constant _mslBufferSizes& _buffer_sizes",
7280                    separator()
7281                )?;
7282                resolved.try_fmt(&mut self.out)?;
7283                writeln!(self.out)?;
7284            }
7285
7286            // end of the entry point argument list
7287            writeln!(self.out, ") {{")?;
7288
7289            // Starting the function body.
7290            if do_vertex_pulling {
7291                // Provide zero values for all the attributes, which we will overwrite with
7292                // real data from the vertex attribute buffers, if the indices are in-bounds.
7293                for vbm in &vbm_resolved {
7294                    for attribute in vbm.attributes {
7295                        let location = attribute.shader_location;
7296                        let am_option = am_resolved.get(&location);
7297                        if am_option.is_none() {
7298                            // This bound attribute isn't used in this entry point, so
7299                            // don't bother zero-initializing it.
7300                            continue;
7301                        }
7302                        let am = am_option.unwrap();
7303                        let attribute_ty_name = &am.ty_name;
7304                        let attribute_name = &am.name;
7305
7306                        writeln!(
7307                            self.out,
7308                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7309                            back::Level(1)
7310                        )?;
7311                    }
7312
7313                    // Output a bounds check block that will set real values for the
7314                    // attributes, if the bounds are satisfied.
7315                    write!(self.out, "{}if (", back::Level(1))?;
7316
7317                    let idx = &vbm.id;
7318                    let stride = &vbm.stride;
7319                    let index_name = match vbm.step_mode {
7320                        back::msl::VertexBufferStepMode::Constant => "0",
7321                        back::msl::VertexBufferStepMode::ByVertex => {
7322                            if let Some(ref name) = v_existing_id {
7323                                name
7324                            } else {
7325                                &v_id
7326                            }
7327                        }
7328                        back::msl::VertexBufferStepMode::ByInstance => {
7329                            if let Some(ref name) = i_existing_id {
7330                                name
7331                            } else {
7332                                &i_id
7333                            }
7334                        }
7335                    };
7336                    write!(
7337                        self.out,
7338                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7339                    )?;
7340
7341                    writeln!(self.out, ") {{")?;
7342
7343                    // Pull the bytes out of the vertex buffer.
7344                    let ty_name = &vbm.ty_name;
7345                    let elem_name = &vbm.elem_name;
7346                    let param_name = &vbm.param_name;
7347
7348                    writeln!(
7349                        self.out,
7350                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7351                        back::Level(2),
7352                    )?;
7353
7354                    // Now set real values for each of the attributes, by unpacking the data
7355                    // from the buffer elements.
7356                    for attribute in vbm.attributes {
7357                        let location = attribute.shader_location;
7358                        let Some(am) = am_resolved.get(&location) else {
7359                            // This bound attribute isn't used in this entry point, so
7360                            // don't bother extracting the data. Too bad we emitted the
7361                            // unpacking function earlier -- it might not get used.
7362                            continue;
7363                        };
7364                        let attribute_name = &am.name;
7365                        let attribute_ty_name = &am.ty_name;
7366
7367                        let offset = attribute.offset;
7368                        let func = unpacking_functions
7369                            .get(&attribute.format)
7370                            .expect("Should have generated this unpacking function earlier.");
7371                        let func_name = &func.name;
7372
7373                        // Check dimensionality of the attribute compared to the unpacking
7374                        // function. If attribute dimension > unpack dimension, we have to
7375                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7376                        // scalar type. Otherwise, if attribute dimension is < unpack
7377                        // dimension, then we need to explicitly truncate the result.
7378
7379                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7380
7381                        if needs_padding_or_truncation != Ordering::Equal {
7382                            // Emit a comment flagging that a conversion is happening,
7383                            // since the actual logic can be at the end of a long line.
7384                            writeln!(
7385                                self.out,
7386                                "{}// {attribute_ty_name} <- {:?}",
7387                                back::Level(2),
7388                                attribute.format
7389                            )?;
7390                        }
7391
7392                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7393
7394                        if needs_padding_or_truncation == Ordering::Greater {
7395                            // Needs padding: emit constructor call for wider type
7396                            write!(self.out, "{attribute_ty_name}(")?;
7397                        }
7398
7399                        // Emit call to unpacking function
7400                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7401                        for i in (offset + 1)..(offset + func.byte_count) {
7402                            write!(self.out, ", {elem_name}.data[{i}]")?;
7403                        }
7404                        write!(self.out, ")")?;
7405
7406                        match needs_padding_or_truncation {
7407                            Ordering::Greater => {
7408                                // Padding
7409                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7410                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
7411                                for i in func.dimension..am.dimension {
7412                                    write!(
7413                                        self.out,
7414                                        ", {}",
7415                                        if i == 3 { one_value } else { zero_value }
7416                                    )?;
7417                                }
7418                                write!(self.out, ")")?;
7419                            }
7420                            Ordering::Less => {
7421                                // Truncate to the first `am.dimension` components
7422                                write!(
7423                                    self.out,
7424                                    ".{}",
7425                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7426                                )?;
7427                            }
7428                            Ordering::Equal => {}
7429                        }
7430
7431                        writeln!(self.out, ";")?;
7432                    }
7433
7434                    // End the bounds check / attribute setting block.
7435                    writeln!(self.out, "{}}}", back::Level(1))?;
7436                }
7437            }
7438
7439            if need_workgroup_variables_initialization {
7440                self.write_workgroup_variables_initialization(
7441                    module,
7442                    mod_info,
7443                    fun_info,
7444                    local_invocation_id,
7445                )?;
7446            }
7447
7448            // Metal doesn't support private mutable variables outside of functions,
7449            // so we put them here, just like the locals.
7450            for (handle, var) in module.global_variables.iter() {
7451                let usage = fun_info[handle];
7452                if usage.is_empty() {
7453                    continue;
7454                }
7455                if var.space == crate::AddressSpace::Private {
7456                    let tyvar = TypedGlobalVariable {
7457                        module,
7458                        names: &self.names,
7459                        handle,
7460                        usage,
7461
7462                        reference: false,
7463                    };
7464                    write!(self.out, "{}", back::INDENT)?;
7465                    tyvar.try_fmt(&mut self.out)?;
7466                    match var.init {
7467                        Some(value) => {
7468                            write!(self.out, " = ")?;
7469                            self.put_const_expression(
7470                                value,
7471                                module,
7472                                mod_info,
7473                                &module.global_expressions,
7474                            )?;
7475                            writeln!(self.out, ";")?;
7476                        }
7477                        None => {
7478                            writeln!(self.out, " = {{}};")?;
7479                        }
7480                    };
7481                } else if let Some(ref binding) = var.binding {
7482                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7483                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7484                        // write an inline sampler
7485                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7486                        writeln!(
7487                            self.out,
7488                            "{}constexpr {}::sampler {}(",
7489                            back::INDENT,
7490                            NAMESPACE,
7491                            name
7492                        )?;
7493                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7494                        writeln!(self.out, "{});", back::INDENT)?;
7495                    } else if let crate::TypeInner::Image {
7496                        class: crate::ImageClass::External,
7497                        ..
7498                    } = module.types[var.ty].inner
7499                    {
7500                        // Wrap the individual arguments for each external texture global
7501                        // in a struct which can be easily passed around.
7502                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7503                        let l1 = back::Level(1);
7504                        let l2 = l1.next();
7505                        writeln!(
7506                            self.out,
7507                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7508                        )?;
7509                        for i in 0..3 {
7510                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7511                                handle,
7512                                ExternalTextureNameKey::Plane(i),
7513                            )];
7514                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7515                        }
7516                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7517                            handle,
7518                            ExternalTextureNameKey::Params,
7519                        )];
7520                        writeln!(self.out, "{l2}.params = {params_name},")?;
7521                        writeln!(self.out, "{l1}}};")?;
7522                    }
7523                }
7524            }
7525
7526            // Now take the arguments that we gathered into structs, and the
7527            // structs that we flattened into arguments, and emit local
7528            // variables with initializers that put everything back the way the
7529            // body code expects.
7530            //
7531            // If we had to generate fresh names for struct members passed as
7532            // arguments, be sure to use those names when rebuilding the struct.
7533            //
7534            // "Each day, I change some zeros to ones, and some ones to zeros.
7535            // The rest, I leave alone."
7536            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7537                let arg_name =
7538                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7539                match module.types[arg.ty].inner {
7540                    crate::TypeInner::Struct { ref members, .. } => {
7541                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7542                        write!(
7543                            self.out,
7544                            "{}const {} {} = {{ ",
7545                            back::INDENT,
7546                            struct_name,
7547                            arg_name
7548                        )?;
7549                        for (member_index, member) in members.iter().enumerate() {
7550                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7551                            let name = &flattened_member_names[&key];
7552                            if member_index != 0 {
7553                                write!(self.out, ", ")?;
7554                            }
7555                            // insert padding initialization, if needed
7556                            if self
7557                                .struct_member_pads
7558                                .contains(&(arg.ty, member_index as u32))
7559                            {
7560                                write!(self.out, "{{}}, ")?;
7561                            }
7562                            if let Some(crate::Binding::Location { .. }) = member.binding {
7563                                if has_varyings {
7564                                    write!(self.out, "{varyings_member_name}.")?;
7565                                }
7566                            }
7567                            write!(self.out, "{name}")?;
7568                        }
7569                        writeln!(self.out, " }};")?;
7570                    }
7571                    _ => {
7572                        if let Some(crate::Binding::Location { .. }) = arg.binding {
7573                            if has_varyings {
7574                                writeln!(
7575                                    self.out,
7576                                    "{}const auto {} = {}.{};",
7577                                    back::INDENT,
7578                                    arg_name,
7579                                    varyings_member_name,
7580                                    arg_name
7581                                )?;
7582                            }
7583                        }
7584                    }
7585                }
7586            }
7587
7588            let guarded_indices =
7589                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7590
7591            let context = StatementContext {
7592                expression: ExpressionContext {
7593                    function: fun,
7594                    origin: FunctionOrigin::EntryPoint(ep_index as _),
7595                    info: fun_info,
7596                    lang_version: options.lang_version,
7597                    policies: options.bounds_check_policies,
7598                    guarded_indices,
7599                    module,
7600                    mod_info,
7601                    pipeline_options,
7602                    force_loop_bounding: options.force_loop_bounding,
7603                },
7604                result_struct: Some(&stage_out_name),
7605            };
7606
7607            // Finally, declare all the local variables that we need
7608            //TODO: we can postpone this till the relevant expressions are emitted
7609            self.put_locals(&context.expression)?;
7610            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7611            self.put_block(back::Level(1), &fun.body, &context)?;
7612            writeln!(self.out, "}}")?;
7613            if ep_index + 1 != module.entry_points.len() {
7614                writeln!(self.out)?;
7615            }
7616            self.named_expressions.clear();
7617        }
7618
7619        Ok(info)
7620    }
7621
7622    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7623        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
7624        // so we try to avoid it here.
7625        if flags.is_empty() {
7626            writeln!(
7627                self.out,
7628                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7629            )?;
7630        }
7631        if flags.contains(crate::Barrier::STORAGE) {
7632            writeln!(
7633                self.out,
7634                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7635            )?;
7636        }
7637        if flags.contains(crate::Barrier::WORK_GROUP) {
7638            writeln!(
7639                self.out,
7640                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7641            )?;
7642        }
7643        if flags.contains(crate::Barrier::SUB_GROUP) {
7644            writeln!(
7645                self.out,
7646                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7647            )?;
7648        }
7649        if flags.contains(crate::Barrier::TEXTURE) {
7650            writeln!(
7651                self.out,
7652                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7653            )?;
7654        }
7655        Ok(())
7656    }
7657}
7658
7659/// Initializing workgroup variables is more tricky for Metal because we have to deal
7660/// with atomics at the type-level (which don't have a copy constructor).
7661mod workgroup_mem_init {
7662    use crate::EntryPoint;
7663
7664    use super::*;
7665
7666    enum Access {
7667        GlobalVariable(Handle<crate::GlobalVariable>),
7668        StructMember(Handle<crate::Type>, u32),
7669        Array(usize),
7670    }
7671
7672    impl Access {
7673        fn write<W: Write>(
7674            &self,
7675            writer: &mut W,
7676            names: &FastHashMap<NameKey, String>,
7677        ) -> Result<(), core::fmt::Error> {
7678            match *self {
7679                Access::GlobalVariable(handle) => {
7680                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7681                }
7682                Access::StructMember(handle, index) => {
7683                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7684                }
7685                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7686            }
7687        }
7688    }
7689
7690    struct AccessStack {
7691        stack: Vec<Access>,
7692        array_depth: usize,
7693    }
7694
7695    impl AccessStack {
7696        const fn new() -> Self {
7697            Self {
7698                stack: Vec::new(),
7699                array_depth: 0,
7700            }
7701        }
7702
7703        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7704            let array_depth = self.array_depth;
7705            self.stack.push(Access::Array(array_depth));
7706            self.array_depth += 1;
7707            let res = cb(self, array_depth);
7708            self.stack.pop();
7709            self.array_depth -= 1;
7710            res
7711        }
7712
7713        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7714            self.stack.push(new);
7715            let res = cb(self);
7716            self.stack.pop();
7717            res
7718        }
7719
7720        fn write<W: Write>(
7721            &self,
7722            writer: &mut W,
7723            names: &FastHashMap<NameKey, String>,
7724        ) -> Result<(), core::fmt::Error> {
7725            for next in self.stack.iter() {
7726                next.write(writer, names)?;
7727            }
7728            Ok(())
7729        }
7730    }
7731
7732    impl<W: Write> Writer<W> {
7733        pub(super) fn need_workgroup_variables_initialization(
7734            &mut self,
7735            options: &Options,
7736            ep: &EntryPoint,
7737            module: &crate::Module,
7738            fun_info: &valid::FunctionInfo,
7739        ) -> bool {
7740            options.zero_initialize_workgroup_memory
7741                && ep.stage.compute_like()
7742                && module.global_variables.iter().any(|(handle, var)| {
7743                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7744                })
7745        }
7746
7747        pub(super) fn write_workgroup_variables_initialization(
7748            &mut self,
7749            module: &crate::Module,
7750            module_info: &valid::ModuleInfo,
7751            fun_info: &valid::FunctionInfo,
7752            local_invocation_id: Option<&NameKey>,
7753        ) -> BackendResult {
7754            let level = back::Level(1);
7755
7756            writeln!(
7757                self.out,
7758                "{}if ({}::all({} == {}::uint3(0u))) {{",
7759                level,
7760                NAMESPACE,
7761                local_invocation_id
7762                    .map(|name_key| self.names[name_key].as_str())
7763                    .unwrap_or("__local_invocation_id"),
7764                NAMESPACE,
7765            )?;
7766
7767            let mut access_stack = AccessStack::new();
7768
7769            let vars = module.global_variables.iter().filter(|&(handle, var)| {
7770                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7771            });
7772
7773            for (handle, var) in vars {
7774                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7775                    self.write_workgroup_variable_initialization(
7776                        module,
7777                        module_info,
7778                        var.ty,
7779                        access_stack,
7780                        level.next(),
7781                    )
7782                })?;
7783            }
7784
7785            writeln!(self.out, "{level}}}")?;
7786            self.write_barrier(crate::Barrier::WORK_GROUP, level)
7787        }
7788
7789        fn write_workgroup_variable_initialization(
7790            &mut self,
7791            module: &crate::Module,
7792            module_info: &valid::ModuleInfo,
7793            ty: Handle<crate::Type>,
7794            access_stack: &mut AccessStack,
7795            level: back::Level,
7796        ) -> BackendResult {
7797            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7798                write!(self.out, "{level}")?;
7799                access_stack.write(&mut self.out, &self.names)?;
7800                writeln!(self.out, " = {{}};")?;
7801            } else {
7802                match module.types[ty].inner {
7803                    crate::TypeInner::Atomic { .. } => {
7804                        write!(
7805                            self.out,
7806                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7807                        )?;
7808                        access_stack.write(&mut self.out, &self.names)?;
7809                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7810                    }
7811                    crate::TypeInner::Array { base, size, .. } => {
7812                        let count = match size.resolve(module.to_ctx())? {
7813                            proc::IndexableLength::Known(count) => count,
7814                            proc::IndexableLength::Dynamic => unreachable!(),
7815                        };
7816
7817                        access_stack.enter_array(|access_stack, array_depth| {
7818                            writeln!(
7819                                self.out,
7820                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7821                            )?;
7822                            self.write_workgroup_variable_initialization(
7823                                module,
7824                                module_info,
7825                                base,
7826                                access_stack,
7827                                level.next(),
7828                            )?;
7829                            writeln!(self.out, "{level}}}")?;
7830                            BackendResult::Ok(())
7831                        })?;
7832                    }
7833                    crate::TypeInner::Struct { ref members, .. } => {
7834                        for (index, member) in members.iter().enumerate() {
7835                            access_stack.enter(
7836                                Access::StructMember(ty, index as u32),
7837                                |access_stack| {
7838                                    self.write_workgroup_variable_initialization(
7839                                        module,
7840                                        module_info,
7841                                        member.ty,
7842                                        access_stack,
7843                                        level,
7844                                    )
7845                                },
7846                            )?;
7847                        }
7848                    }
7849                    _ => unreachable!(),
7850                }
7851            }
7852
7853            Ok(())
7854        }
7855    }
7856}
7857
7858impl crate::AtomicFunction {
7859    const fn to_msl(self) -> &'static str {
7860        match self {
7861            Self::Add => "fetch_add",
7862            Self::Subtract => "fetch_sub",
7863            Self::And => "fetch_and",
7864            Self::InclusiveOr => "fetch_or",
7865            Self::ExclusiveOr => "fetch_xor",
7866            Self::Min => "fetch_min",
7867            Self::Max => "fetch_max",
7868            Self::Exchange { compare: None } => "exchange",
7869            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7870        }
7871    }
7872
7873    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7874        Ok(match self {
7875            Self::Min => "min",
7876            Self::Max => "max",
7877            _ => Err(Error::FeatureNotImplemented(
7878                "64-bit atomic operation other than min/max".to_string(),
7879            ))?,
7880        })
7881    }
7882}