naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18    arena::{Handle, HandleSet},
19    back::{self, get_entry_points, Baked},
20    common,
21    proc::{
22        self,
23        index::{self, BoundsCheck},
24        ExternalTextureNameKey, NameKey, TypeResolution,
25    },
26    valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Shorthand result used internally by the backend.
///
/// Returned by fallible writer helpers such as `TypedGlobalVariable::try_fmt`:
/// `Ok(())` on success, or the backend's `Error` on failure.
type BackendResult = Result<(), Error>;
34
/// Namespace used to qualify Metal standard-library names in generated code.
const NAMESPACE: &str = "metal";
/// Name of the array member of the Metal struct types generated to represent
/// Naga `Array` types. See the comments in `Writer::write_type_defs` for
/// details.
const WRAPPED_ARRAY_FIELD: &str = "inner";
/// Token written to pass a pointer to an atomic.
///
/// This is a hack: the backend does not generally put `&` in front of every
/// pointer, so atomics are special-cased. More general pointer handling still
/// needs to be implemented.
const ATOMIC_REFERENCE: &str = "&";

/// Namespace of Metal's ray-tracing types.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Name written for Naga `RayQuery` types.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// Field and helper names associated with `RAY_QUERY_TYPE`.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
/// Currently hard-coded to `false`.
// TODO: document what enabling this requires before turning it on.
const RAY_QUERY_MODERN_SUPPORT: bool = false;
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Names of helper functions this backend may define in the generated MSL.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generic single-field struct that wraps every argument buffer.
///
/// Metal rejects `metal::texture<..>*` as a buffer argument, yet it accepts
/// the same texture once it is placed inside a struct. Wrapping each argument
/// buffer in a struct with a single generic `<T>` field therefore makes
/// `NagaArgumentBufferWrapper<metal::texture<..>>*` legal, even though the two
/// forms should be equivalent as far as the compiler is concerned.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct declared to wrap the 3 textures and the parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to, so
/// they can be conveniently passed to user-defined or wrapper functions. The
/// struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81
82/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
83///
84/// The `sizes` slice determines whether this function writes a
85/// scalar, vector, or matrix type:
86///
87/// - An empty slice produces a scalar type.
88/// - A one-element slice produces a vector type.
89/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
90fn put_numeric_type(
91    out: &mut impl Write,
92    scalar: crate::Scalar,
93    sizes: &[crate::VectorSize],
94) -> Result<(), FmtError> {
95    match (scalar, sizes) {
96        (scalar, &[]) => {
97            write!(out, "{}", scalar.to_msl_name())
98        }
99        (scalar, &[rows]) => {
100            write!(
101                out,
102                "{}::{}{}",
103                NAMESPACE,
104                scalar.to_msl_name(),
105                common::vector_size_str(rows)
106            )
107        }
108        (scalar, &[rows, columns]) => {
109            write!(
110                out,
111                "{}::{}{}x{}",
112                NAMESPACE,
113                scalar.to_msl_name(),
114                common::vector_size_str(columns),
115                common::vector_size_str(rows)
116            )
117        }
118        (_, _) => Ok(()), // not meaningful
119    }
120}
121
122const fn scalar_is_int(scalar: crate::Scalar) -> bool {
123    use crate::ScalarKind::*;
124    match scalar.kind {
125        Sint | Uint | AbstractInt | Bool => true,
126        Float | AbstractFloat => false,
127    }
128}
129
130/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
131const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
132
133/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
134const REINTERPRET_PREFIX: &str = "reinterpreted_";
135
136/// Wrapper for identifier names for clamped level-of-detail values
137///
138/// Values of this type implement [`core::fmt::Display`], formatting as
139/// the name of the variable used to hold the cached clamped
140/// level-of-detail value for an `ImageLoad` expression.
141struct ClampedLod(Handle<crate::Expression>);
142
143impl Display for ClampedLod {
144    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
145        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
146    }
147}
148
149/// Wrapper for generating `struct _mslBufferSizes` member names for
150/// runtime-sized array lengths.
151///
152/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
153/// as an argument to the entry point. This argument's type in the MSL is
154/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
155/// each global variable containing a runtime-sized array.
156///
157/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
158/// runtime-sized array, then the value `ArraySize(global)` implements
159/// [`core::fmt::Display`], formatting as the name of the struct member carrying
160/// the number of elements in that runtime-sized array.
161///
162/// [`GlobalVariable`]: crate::GlobalVariable
163struct ArraySizeMember(Handle<crate::GlobalVariable>);
164
165impl Display for ArraySizeMember {
166    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
167        self.0.write_prefixed(f, "size")
168    }
169}
170
171/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
172///
173/// Implements [`core::fmt::Display`], formatting as a name derived from
174/// `target_type` and the variable name of `orig`.
175#[derive(Clone, Copy)]
176struct Reinterpreted<'a> {
177    target_type: &'a str,
178    orig: Handle<crate::Expression>,
179}
180
181impl<'a> Reinterpreted<'a> {
182    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
183        Self { target_type, orig }
184    }
185}
186
187impl Display for Reinterpreted<'_> {
188    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
189        f.write_str(REINTERPRET_PREFIX)?;
190        f.write_str(self.target_type)?;
191        self.orig.write_prefixed(f, "_e")
192    }
193}
194
195struct TypeContext<'a> {
196    handle: Handle<crate::Type>,
197    gctx: proc::GlobalCtx<'a>,
198    names: &'a FastHashMap<NameKey, String>,
199    access: crate::StorageAccess,
200    first_time: bool,
201}
202
203impl TypeContext<'_> {
204    fn scalar(&self) -> Option<crate::Scalar> {
205        let ty = &self.gctx.types[self.handle];
206        ty.inner.scalar()
207    }
208
209    fn vertex_input_dimension(&self) -> u32 {
210        let ty = &self.gctx.types[self.handle];
211        match ty.inner {
212            crate::TypeInner::Scalar(_) => 1,
213            crate::TypeInner::Vector { size, .. } => size as u32,
214            _ => unreachable!(),
215        }
216    }
217}
218
219impl Display for TypeContext<'_> {
220    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
221        let ty = &self.gctx.types[self.handle];
222        if ty.needs_alias() && !self.first_time {
223            let name = &self.names[&NameKey::Type(self.handle)];
224            return write!(out, "{name}");
225        }
226
227        match ty.inner {
228            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
229            crate::TypeInner::Atomic(scalar) => {
230                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
231            }
232            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
233            crate::TypeInner::Matrix {
234                columns,
235                rows,
236                scalar,
237            } => put_numeric_type(out, scalar, &[rows, columns]),
238            crate::TypeInner::Pointer { base, space } => {
239                let sub = Self {
240                    handle: base,
241                    first_time: false,
242                    ..*self
243                };
244                let space_name = match space.to_msl_name() {
245                    Some(name) => name,
246                    None => return Ok(()),
247                };
248                write!(out, "{space_name} {sub}&")
249            }
250            crate::TypeInner::ValuePointer {
251                size,
252                scalar,
253                space,
254            } => {
255                match space.to_msl_name() {
256                    Some(name) => write!(out, "{name} ")?,
257                    None => return Ok(()),
258                };
259                match size {
260                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
261                    None => put_numeric_type(out, scalar, &[])?,
262                };
263
264                write!(out, "&")
265            }
266            crate::TypeInner::Array { base, .. } => {
267                let sub = Self {
268                    handle: base,
269                    first_time: false,
270                    ..*self
271                };
272                // Array lengths go at the end of the type definition,
273                // so just print the element type here.
274                write!(out, "{sub}")
275            }
276            crate::TypeInner::Struct { .. } => unreachable!(),
277            crate::TypeInner::Image {
278                dim,
279                arrayed,
280                class,
281            } => {
282                let dim_str = match dim {
283                    crate::ImageDimension::D1 => "1d",
284                    crate::ImageDimension::D2 => "2d",
285                    crate::ImageDimension::D3 => "3d",
286                    crate::ImageDimension::Cube => "cube",
287                };
288                let (texture_str, msaa_str, scalar, access) = match class {
289                    crate::ImageClass::Sampled { kind, multi } => {
290                        let (msaa_str, access) = if multi {
291                            ("_ms", "read")
292                        } else {
293                            ("", "sample")
294                        };
295                        let scalar = crate::Scalar { kind, width: 4 };
296                        ("texture", msaa_str, scalar, access)
297                    }
298                    crate::ImageClass::Depth { multi } => {
299                        let (msaa_str, access) = if multi {
300                            ("_ms", "read")
301                        } else {
302                            ("", "sample")
303                        };
304                        let scalar = crate::Scalar {
305                            kind: crate::ScalarKind::Float,
306                            width: 4,
307                        };
308                        ("depth", msaa_str, scalar, access)
309                    }
310                    crate::ImageClass::Storage { format, .. } => {
311                        let access = if self
312                            .access
313                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
314                        {
315                            "read_write"
316                        } else if self.access.contains(crate::StorageAccess::STORE) {
317                            "write"
318                        } else if self.access.contains(crate::StorageAccess::LOAD) {
319                            "read"
320                        } else {
321                            log::warn!(
322                                "Storage access for {:?} (name '{}'): {:?}",
323                                self.handle,
324                                ty.name.as_deref().unwrap_or_default(),
325                                self.access
326                            );
327                            unreachable!("module is not valid");
328                        };
329                        ("texture", "", format.into(), access)
330                    }
331                    crate::ImageClass::External => {
332                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
333                    }
334                };
335                let base_name = scalar.to_msl_name();
336                let array_str = if arrayed { "_array" } else { "" };
337                write!(
338                    out,
339                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
340                )
341            }
342            crate::TypeInner::Sampler { comparison: _ } => {
343                write!(out, "{NAMESPACE}::sampler")
344            }
345            crate::TypeInner::AccelerationStructure { vertex_return } => {
346                if vertex_return {
347                    unimplemented!("metal does not support vertex ray hit return")
348                }
349                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
350            }
351            crate::TypeInner::RayQuery { vertex_return } => {
352                if vertex_return {
353                    unimplemented!("metal does not support vertex ray hit return")
354                }
355                write!(out, "{RAY_QUERY_TYPE}")
356            }
357            crate::TypeInner::BindingArray { base, .. } => {
358                let base_tyname = Self {
359                    handle: base,
360                    first_time: false,
361                    ..*self
362                };
363
364                write!(
365                    out,
366                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
367                )
368            }
369        }
370    }
371}
372
373struct TypedGlobalVariable<'a> {
374    module: &'a crate::Module,
375    names: &'a FastHashMap<NameKey, String>,
376    handle: Handle<crate::GlobalVariable>,
377    usage: valid::GlobalUse,
378    reference: bool,
379}
380
381impl TypedGlobalVariable<'_> {
382    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
383        let var = &self.module.global_variables[self.handle];
384        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
385
386        let storage_access = match var.space {
387            crate::AddressSpace::Storage { access } => access,
388            _ => match self.module.types[var.ty].inner {
389                crate::TypeInner::Image {
390                    class: crate::ImageClass::Storage { access, .. },
391                    ..
392                } => access,
393                crate::TypeInner::BindingArray { base, .. } => {
394                    match self.module.types[base].inner {
395                        crate::TypeInner::Image {
396                            class: crate::ImageClass::Storage { access, .. },
397                            ..
398                        } => access,
399                        _ => crate::StorageAccess::default(),
400                    }
401                }
402                _ => crate::StorageAccess::default(),
403            },
404        };
405        let ty_name = TypeContext {
406            handle: var.ty,
407            gctx: self.module.to_ctx(),
408            names: self.names,
409            access: storage_access,
410            first_time: false,
411        };
412
413        let (space, access, reference) = match var.space.to_msl_name() {
414            Some(space) if self.reference => {
415                let access = if var.space.needs_access_qualifier()
416                    && !self.usage.intersects(valid::GlobalUse::WRITE)
417                {
418                    "const"
419                } else {
420                    ""
421                };
422                (space, access, "&")
423            }
424            _ => ("", "", ""),
425        };
426
427        Ok(write!(
428            out,
429            "{}{}{}{}{}{} {}",
430            space,
431            if space.is_empty() { "" } else { " " },
432            ty_name,
433            if access.is_empty() { "" } else { " " },
434            access,
435            reference,
436            name,
437        )?)
438    }
439}
440
/// Key describing one helper function generated by the writer.
///
/// Values are stored in `Writer::wrapped_functions`. NOTE(review): presumably
/// the writer checks this set so that each distinct helper is emitted only
/// once. Confirm against the code that inserts into it.
///
/// Operand types are `(Option<VectorSize>, Scalar)` pairs. Presumably `None`
/// means a scalar operand and `Some(size)` a vector of that size (TODO
/// confirm).
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for unary operator `op` applied to an operand of type `ty`.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for binary operator `op` with operand types `left_ty` and `right_ty`.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for math function `fun` applied to an argument of type `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` values (optionally as vectors of
    /// `vector_size`) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Helper for loading from images of class `class`.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Helper for sampling images of class `class`. NOTE(review): presumably
    /// `clamp_to_edge` selects the clamp-to-edge sampling variant.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Helper for querying the size of images of class `class`.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
472
/// Metal Shading Language writer; the generated text accumulates in `out`,
/// which `finish` hands back.
pub struct Writer<W> {
    /// Destination of the generated text.
    out: W,
    /// Names assigned to module items, keyed by `NameKey`.
    names: FastHashMap<NameKey, String>,
    /// Expressions that carry a name.
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    /// Name generator used to produce identifiers.
    namer: proc::Namer,
    /// Helper functions recorded as generated; see `WrappedFunction`.
    wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only bookkeeping. NOTE(review): presumably records stack addresses
    /// seen while writing expressions; confirm against its users.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only bookkeeping. NOTE(review): presumably records stack addresses
    /// seen while writing blocks; confirm against its users.
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
489
490impl crate::Scalar {
491    fn to_msl_name(self) -> &'static str {
492        use crate::ScalarKind as Sk;
493        match self {
494            Self {
495                kind: Sk::Float,
496                width: 4,
497            } => "float",
498            Self {
499                kind: Sk::Float,
500                width: 2,
501            } => "half",
502            Self {
503                kind: Sk::Sint,
504                width: 4,
505            } => "int",
506            Self {
507                kind: Sk::Uint,
508                width: 4,
509            } => "uint",
510            Self {
511                kind: Sk::Sint,
512                width: 8,
513            } => "long",
514            Self {
515                kind: Sk::Uint,
516                width: 8,
517            } => "ulong",
518            Self {
519                kind: Sk::Bool,
520                width: _,
521            } => "bool",
522            Self {
523                kind: Sk::AbstractInt | Sk::AbstractFloat,
524                width: _,
525            } => unreachable!("Found Abstract scalar kind"),
526            _ => unreachable!("Unsupported scalar kind: {:?}", self),
527        }
528    }
529}
530
/// Returns `","` when `need_separator` is set and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
538
539fn should_pack_struct_member(
540    members: &[crate::StructMember],
541    span: u32,
542    index: usize,
543    module: &crate::Module,
544) -> Option<crate::Scalar> {
545    let member = &members[index];
546
547    let ty_inner = &module.types[member.ty].inner;
548    let last_offset = member.offset + ty_inner.size(module.to_ctx());
549    let next_offset = match members.get(index + 1) {
550        Some(next) => next.offset,
551        None => span,
552    };
553    let is_tight = next_offset == last_offset;
554
555    match *ty_inner {
556        crate::TypeInner::Vector {
557            size: crate::VectorSize::Tri,
558            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
559        } if is_tight => Some(scalar),
560        _ => None,
561    }
562}
563
564fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
565    match arena[ty].inner {
566        crate::TypeInner::Struct { ref members, .. } => {
567            if let Some(member) = members.last() {
568                if let crate::TypeInner::Array {
569                    size: crate::ArraySize::Dynamic,
570                    ..
571                } = arena[member.ty].inner
572                {
573                    return true;
574                }
575            }
576            false
577        }
578        crate::TypeInner::Array {
579            size: crate::ArraySize::Dynamic,
580            ..
581        } => true,
582        _ => false,
583    }
584}
585
586impl crate::AddressSpace {
587    /// Returns true if global variables in this address space are
588    /// passed in function arguments. These arguments need to be
589    /// passed through any functions called from the entry point.
590    const fn needs_pass_through(&self) -> bool {
591        match *self {
592            Self::Uniform
593            | Self::Storage { .. }
594            | Self::Private
595            | Self::WorkGroup
596            | Self::PushConstant
597            | Self::Handle => true,
598            Self::Function => false,
599        }
600    }
601
602    /// Returns true if the address space may need a "const" qualifier.
603    const fn needs_access_qualifier(&self) -> bool {
604        match *self {
605            //Note: we are ignoring the storage access here, and instead
606            // rely on the actual use of a global by functions. This means we
607            // may end up with "const" even if the binding is read-write,
608            // and that should be OK.
609            Self::Storage { .. } => true,
610            // These should always be read-write.
611            Self::Private | Self::WorkGroup => false,
612            // These translate to `constant` address space, no need for qualifiers.
613            Self::Uniform | Self::PushConstant => false,
614            // Not applicable.
615            Self::Handle | Self::Function => false,
616        }
617    }
618
619    const fn to_msl_name(self) -> Option<&'static str> {
620        match self {
621            Self::Handle => None,
622            Self::Uniform | Self::PushConstant => Some("constant"),
623            Self::Storage { .. } => Some("device"),
624            Self::Private | Self::Function => Some("thread"),
625            Self::WorkGroup => Some("threadgroup"),
626        }
627    }
628}
629
630impl crate::Type {
631    // Returns `true` if we need to emit an alias for this type.
632    const fn needs_alias(&self) -> bool {
633        use crate::TypeInner as Ti;
634
635        match self.inner {
636            // value types are concise enough, we only alias them if they are named
637            Ti::Scalar(_)
638            | Ti::Vector { .. }
639            | Ti::Matrix { .. }
640            | Ti::Atomic(_)
641            | Ti::Pointer { .. }
642            | Ti::ValuePointer { .. } => self.name.is_some(),
643            // composite types are better to be aliased, regardless of the name
644            Ti::Struct { .. } | Ti::Array { .. } => true,
645            // handle types may be different, depending on the global var access, so we always inline them
646            Ti::Image { .. }
647            | Ti::Sampler { .. }
648            | Ti::AccelerationStructure { .. }
649            | Ti::RayQuery { .. }
650            | Ti::BindingArray { .. } => false,
651        }
652    }
653}
654
/// Identifies which function's code is being written: a regular function or
/// an entry point. `NameKeyExt` uses it to pick the matching `NameKey` variant.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular (non-entry-point) function.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
660
661trait NameKeyExt {
662    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
663        match origin {
664            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
665            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
666        }
667    }
668
669    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
670    /// policy when it needs to produce a pointer-typed result for an OOB access. These
671    /// are unique per accessed type, so the second argument is a type handle. See docs
672    /// for [`crate::back::msl`].
673    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
674        match origin {
675            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
676            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
677        }
678    }
679}
680
681impl NameKeyExt for NameKey {}
682
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use this level-of-detail expression as written.
    Direct(Handle<crate::Expression>),
    /// Use the cached, clamped level of detail. NOTE(review): presumably this
    /// holds the `ImageLoad` expression's handle, from which the temporary's
    /// name (see `ClampedLod`) is derived. Confirm at the construction site.
    Restricted(Handle<crate::Expression>),
}
697
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Texel coordinate expression.
    coordinate: Handle<crate::Expression>,
    /// Array index expression, if one is supplied.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index expression, if one is supplied.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if one is supplied.
    level: Option<LevelOfDetail>,
}
713
/// Everything needed while writing the expressions of one function or entry
/// point.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`, used to resolve expression types.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as (major, minor);
    /// confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies used to choose how accesses are guarded.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
731
732impl<'a> ExpressionContext<'a> {
733    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
734        self.info[handle].ty.inner_with(&self.module.types)
735    }
736
737    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
738    ///
739    /// Only mipmapped images need to specify a level of detail. Since 1D
740    /// textures cannot have mipmaps, MSL requires that the level argument to
741    /// texture1d queries and accesses must be a constexpr 0. It's easiest
742    /// just to omit the level entirely for 1D textures.
743    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
744        let image_ty = self.resolve_type(image);
745        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
746            class.is_mipmapped() && dim != crate::ImageDimension::D1
747        } else {
748            false
749        }
750    }
751
752    fn choose_bounds_check_policy(
753        &self,
754        pointer: Handle<crate::Expression>,
755    ) -> index::BoundsCheckPolicy {
756        self.policies
757            .choose_policy(pointer, &self.module.types, self.info)
758    }
759
760    /// See docs for [`proc::index::access_needs_check`].
761    fn access_needs_check(
762        &self,
763        base: Handle<crate::Expression>,
764        index: index::GuardedIndex,
765    ) -> Option<index::IndexableLength> {
766        index::access_needs_check(
767            base,
768            index,
769            self.module,
770            &self.function.expressions,
771            self.info,
772        )
773    }
774
775    /// See docs for [`proc::index::bounds_check_iter`].
776    fn bounds_check_iter(
777        &self,
778        chain: Handle<crate::Expression>,
779    ) -> impl Iterator<Item = BoundsCheck> + '_ {
780        index::bounds_check_iter(chain, self.module, self.function, self.info)
781    }
782
783    /// See docs for [`proc::index::oob_local_types`].
784    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
785        index::oob_local_types(self.module, self.function, self.info, self.policies)
786    }
787
788    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
789        match self.function.expressions[expr_handle] {
790            crate::Expression::AccessIndex { base, index } => {
791                let ty = match *self.resolve_type(base) {
792                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
793                    ref ty => ty,
794                };
795                match *ty {
796                    crate::TypeInner::Struct {
797                        ref members, span, ..
798                    } => should_pack_struct_member(members, span, index as usize, self.module),
799                    _ => None,
800                }
801            }
802            _ => None,
803        }
804    }
805}
806
/// Context for writing statements: the expression context plus
/// statement-level data.
struct StatementContext<'a> {
    /// Context shared with expression writing.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the entry point's result struct,
    /// when there is one. Confirm against the construction site.
    result_struct: Option<&'a str>,
}
811
812impl<W: Write> Writer<W> {
813    /// Creates a new `Writer` instance.
814    pub fn new(out: W) -> Self {
815        Writer {
816            out,
817            names: FastHashMap::default(),
818            named_expressions: Default::default(),
819            need_bake_expressions: Default::default(),
820            namer: proc::Namer::default(),
821            wrapped_functions: FastHashSet::default(),
822            #[cfg(test)]
823            put_expression_stack_pointers: Default::default(),
824            #[cfg(test)]
825            put_block_stack_pointers: Default::default(),
826            struct_member_pads: FastHashSet::default(),
827        }
828    }
829
830    /// Finishes writing and returns the output.
831    // See https://github.com/rust-lang/rust-clippy/issues/4979.
832    #[allow(clippy::missing_const_for_fn)]
833    pub fn finish(self) -> W {
834        self.out
835    }
836
837    /// Generates statements to be inserted immediately before and at the very
838    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
839    /// The 0th item of the returned tuple should be inserted immediately prior
840    /// to the loop and the 1st item should be inserted at the very start of
841    /// the loop body.
842    ///
843    /// # What is this trying to solve?
844    ///
845    /// In Metal Shading Language, an infinite loop has undefined behavior.
846    /// (This rule is inherited from C++14.) This means that, if the MSL
847    /// compiler determines that a given loop will never exit, it may assume
848    /// that it is never reached. It may thus assume that any conditions
849    /// sufficient to cause the loop to be reached must be false. Like many
850    /// optimizing compilers, MSL uses this kind of analysis to establish limits
851    /// on the range of values variables involved in those conditions might
852    /// hold.
853    ///
854    /// For example, suppose the MSL compiler sees the code:
855    ///
856    /// ```ignore
857    /// if (i >= 10) {
858    ///     while (true) { }
859    /// }
860    /// ```
861    ///
862    /// It will recognize that the `while` loop will never terminate, conclude
863    /// that it must be unreachable, and thus infer that, if this code is
864    /// reached, then `i < 10` at that point.
865    ///
866    /// Now suppose that, at some point where `i` has the same value as above,
867    /// the compiler sees the code:
868    ///
869    /// ```ignore
870    /// if (i < 10) {
871    ///     a[i] = 1;
872    /// }
873    /// ```
874    ///
875    /// Because the compiler is confident that `i < 10`, it will make the
876    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
877    ///
878    /// ```ignore
879    /// a[i] = 1;
880    /// ```
881    ///
882    /// If that `if` condition was injected by Naga to implement a bounds check,
883    /// the MSL compiler's optimizations could allow out-of-bounds array
884    /// accesses to occur.
885    ///
886    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
887    /// that a loop is infinite, so an attacker could craft a Naga module
888    /// containing an infinite loop protected by conditions that cause the Metal
889    /// compiler to remove bounds checks that Naga injected elsewhere in the
890    /// function.
891    ///
892    /// This rewrite could occur even if the conditional assignment appears
893    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
894    /// reached. This would allow the attacker to save the results of
895    /// unauthorized reads somewhere accessible before entering the infinite
896    /// loop. But even worse, the MSL compiler has been observed to simply
897    /// delete the infinite loop entirely, so that even code dominated by the
898    /// loop becomes reachable. This would make the attack even more flexible,
899    /// since shaders that would appear to never terminate would actually exit
900    /// nicely, after having stolen data from elsewhere in the GPU address
901    /// space.
902    ///
903    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
904    /// generates is infinite. One approach would be to add inline assembly to
905    /// each loop that is annotated as potentially branching out of the loop,
906    /// but which in fact generates no instructions. Unfortunately, inline
907    /// assembly is not handled correctly by some Metal device drivers.
908    ///
909    /// A previously used approach was to add the following code to the bottom
910    /// of every loop:
911    ///
912    /// ```ignore
913    /// if (volatile bool unpredictable = false; unpredictable)
914    ///     break;
915    /// ```
916    ///
917    /// Although the `if` condition will always be false in any real execution,
918    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
919    /// it must assume that the `break` might be reached, and hence that the
920    /// loop is not unbounded. This prevents the range analysis impact described
921    /// above. Unfortunately this prevented the compiler from making important,
922    /// and safe, optimizations such as loop unrolling and was observed to
923    /// significantly hurt performance.
924    ///
925    /// Our current approach declares a counter before every loop and
926    /// increments it every iteration, breaking after 2^64 iterations:
927    ///
928    /// ```ignore
929    /// uint2 loop_bound = uint2(0);
930    /// while (true) {
931    ///   if (metal::all(loop_bound == uint2(4294967295))) { break; }
932    ///   loop_bound += uint2(loop_bound.y == 4294967295, 1);
933    /// }
934    /// ```
935    ///
936    /// This convinces the compiler that the loop is finite and therefore may
937    /// execute, whilst at the same time allowing optimizations such as loop
938    /// unrolling. Furthermore the 64-bit counter is large enough it seems
939    /// implausible that it would affect the execution of any shader.
940    ///
941    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
942    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
943    fn gen_force_bounded_loop_statements(
944        &mut self,
945        level: back::Level,
946        context: &StatementContext,
947    ) -> Option<(String, String)> {
948        if !context.expression.force_loop_bounding {
949            return None;
950        }
951
952        let loop_bound_name = self.namer.call("loop_bound");
953        // Count down from u32::MAX rather than up from 0 to avoid hang on
954        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
955        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
956        let level = level.next();
957        let break_and_inc = format!(
958            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
959{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
960        );
961
962        Some((decl, break_and_inc))
963    }
964
965    fn put_call_parameters(
966        &mut self,
967        parameters: impl Iterator<Item = Handle<crate::Expression>>,
968        context: &ExpressionContext,
969    ) -> BackendResult {
970        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
971            writer.put_expression(expr, context, true)
972        })
973    }
974
975    fn put_call_parameters_impl<C, E>(
976        &mut self,
977        parameters: impl Iterator<Item = Handle<crate::Expression>>,
978        ctx: &C,
979        put_expression: E,
980    ) -> BackendResult
981    where
982        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
983    {
984        write!(self.out, "(")?;
985        for (i, handle) in parameters.enumerate() {
986            if i != 0 {
987                write!(self.out, ", ")?;
988            }
989            put_expression(self, ctx, handle)?;
990        }
991        write!(self.out, ")")?;
992        Ok(())
993    }
994
995    /// Writes the local variables of the given function, as well as any extra
996    /// out-of-bounds locals that are needed.
997    ///
998    /// The names of the OOB locals are also added to `self.names` at the same
999    /// time.
1000    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1001        let oob_local_types = context.oob_local_types();
1002        for &ty in oob_local_types.iter() {
1003            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1004            self.names.insert(name_key, self.namer.call("oob"));
1005        }
1006
1007        for (name_key, ty, init) in context
1008            .function
1009            .local_variables
1010            .iter()
1011            .map(|(local_handle, local)| {
1012                let name_key = NameKey::local(context.origin, local_handle);
1013                (name_key, local.ty, local.init)
1014            })
1015            .chain(oob_local_types.iter().map(|&ty| {
1016                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1017                (name_key, ty, None)
1018            }))
1019        {
1020            let ty_name = TypeContext {
1021                handle: ty,
1022                gctx: context.module.to_ctx(),
1023                names: &self.names,
1024                access: crate::StorageAccess::empty(),
1025                first_time: false,
1026            };
1027            write!(
1028                self.out,
1029                "{}{} {}",
1030                back::INDENT,
1031                ty_name,
1032                self.names[&name_key]
1033            )?;
1034            match init {
1035                Some(value) => {
1036                    write!(self.out, " = ")?;
1037                    self.put_expression(value, context, true)?;
1038                }
1039                None => {
1040                    write!(self.out, " = {{}}")?;
1041                }
1042            };
1043            writeln!(self.out, ";")?;
1044        }
1045        Ok(())
1046    }
1047
1048    fn put_level_of_detail(
1049        &mut self,
1050        level: LevelOfDetail,
1051        context: &ExpressionContext,
1052    ) -> BackendResult {
1053        match level {
1054            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1055            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1056        }
1057        Ok(())
1058    }
1059
1060    fn put_image_query(
1061        &mut self,
1062        image: Handle<crate::Expression>,
1063        query: &str,
1064        level: Option<LevelOfDetail>,
1065        context: &ExpressionContext,
1066    ) -> BackendResult {
1067        self.put_expression(image, context, false)?;
1068        write!(self.out, ".get_{query}(")?;
1069        if let Some(level) = level {
1070            self.put_level_of_detail(level, context)?;
1071        }
1072        write!(self.out, ")")?;
1073        Ok(())
1074    }
1075
1076    fn put_image_size_query(
1077        &mut self,
1078        image: Handle<crate::Expression>,
1079        level: Option<LevelOfDetail>,
1080        kind: crate::ScalarKind,
1081        context: &ExpressionContext,
1082    ) -> BackendResult {
1083        if let crate::TypeInner::Image {
1084            class: crate::ImageClass::External,
1085            ..
1086        } = *context.resolve_type(image)
1087        {
1088            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1089            self.put_expression(image, context, true)?;
1090            write!(self.out, ")")?;
1091            return Ok(());
1092        }
1093
1094        //Note: MSL only has separate width/height/depth queries,
1095        // so compose the result of them.
1096        let dim = match *context.resolve_type(image) {
1097            crate::TypeInner::Image { dim, .. } => dim,
1098            ref other => unreachable!("Unexpected type {:?}", other),
1099        };
1100        let scalar = crate::Scalar { kind, width: 4 };
1101        let coordinate_type = scalar.to_msl_name();
1102        match dim {
1103            crate::ImageDimension::D1 => {
1104                // Since 1D textures never have mipmaps, MSL requires that the
1105                // `level` argument be a constexpr 0. It's simplest for us just
1106                // to pass `None` and omit the level entirely.
1107                if kind == crate::ScalarKind::Uint {
1108                    // No need to construct a vector. No cast needed.
1109                    self.put_image_query(image, "width", None, context)?;
1110                } else {
1111                    // There's no definition for `int` in the `metal` namespace.
1112                    write!(self.out, "int(")?;
1113                    self.put_image_query(image, "width", None, context)?;
1114                    write!(self.out, ")")?;
1115                }
1116            }
1117            crate::ImageDimension::D2 => {
1118                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1119                self.put_image_query(image, "width", level, context)?;
1120                write!(self.out, ", ")?;
1121                self.put_image_query(image, "height", level, context)?;
1122                write!(self.out, ")")?;
1123            }
1124            crate::ImageDimension::D3 => {
1125                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1126                self.put_image_query(image, "width", level, context)?;
1127                write!(self.out, ", ")?;
1128                self.put_image_query(image, "height", level, context)?;
1129                write!(self.out, ", ")?;
1130                self.put_image_query(image, "depth", level, context)?;
1131                write!(self.out, ")")?;
1132            }
1133            crate::ImageDimension::Cube => {
1134                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1135                self.put_image_query(image, "width", level, context)?;
1136                write!(self.out, ")")?;
1137            }
1138        }
1139        Ok(())
1140    }
1141
1142    fn put_cast_to_uint_scalar_or_vector(
1143        &mut self,
1144        expr: Handle<crate::Expression>,
1145        context: &ExpressionContext,
1146    ) -> BackendResult {
1147        // coordinates in IR are int, but Metal expects uint
1148        match *context.resolve_type(expr) {
1149            crate::TypeInner::Scalar(_) => {
1150                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1151            }
1152            crate::TypeInner::Vector { size, .. } => {
1153                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1154            }
1155            _ => {
1156                return Err(Error::GenericValidation(
1157                    "Invalid type for image coordinate".into(),
1158                ))
1159            }
1160        };
1161
1162        write!(self.out, "(")?;
1163        self.put_expression(expr, context, true)?;
1164        write!(self.out, ")")?;
1165        Ok(())
1166    }
1167
1168    fn put_image_sample_level(
1169        &mut self,
1170        image: Handle<crate::Expression>,
1171        level: crate::SampleLevel,
1172        context: &ExpressionContext,
1173    ) -> BackendResult {
1174        let has_levels = context.image_needs_lod(image);
1175        match level {
1176            crate::SampleLevel::Auto => {}
1177            crate::SampleLevel::Zero => {
1178                //TODO: do we support Zero on `Sampled` image classes?
1179            }
1180            _ if !has_levels => {
1181                log::warn!("1D image can't be sampled with level {level:?}");
1182            }
1183            crate::SampleLevel::Exact(h) => {
1184                write!(self.out, ", {NAMESPACE}::level(")?;
1185                self.put_expression(h, context, true)?;
1186                write!(self.out, ")")?;
1187            }
1188            crate::SampleLevel::Bias(h) => {
1189                write!(self.out, ", {NAMESPACE}::bias(")?;
1190                self.put_expression(h, context, true)?;
1191                write!(self.out, ")")?;
1192            }
1193            crate::SampleLevel::Gradient { x, y } => {
1194                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1195                self.put_expression(x, context, true)?;
1196                write!(self.out, ", ")?;
1197                self.put_expression(y, context, true)?;
1198                write!(self.out, ")")?;
1199            }
1200        }
1201        Ok(())
1202    }
1203
1204    fn put_image_coordinate_limits(
1205        &mut self,
1206        image: Handle<crate::Expression>,
1207        level: Option<LevelOfDetail>,
1208        context: &ExpressionContext,
1209    ) -> BackendResult {
1210        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1211        write!(self.out, " - 1")?;
1212        Ok(())
1213    }
1214
1215    /// General function for writing restricted image indexes.
1216    ///
1217    /// This is used to produce restricted mip levels, array indices, and sample
1218    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1219    /// [`Restrict`] bounds check policy.
1220    ///
1221    /// This function writes an expression of the form:
1222    ///
1223    /// ```ignore
1224    ///
1225    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1226    ///
1227    /// ```
1228    ///
1229    /// [`ImageLoad`]: crate::Expression::ImageLoad
1230    /// [`ImageStore`]: crate::Statement::ImageStore
1231    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1232    fn put_restricted_scalar_image_index(
1233        &mut self,
1234        image: Handle<crate::Expression>,
1235        index: Handle<crate::Expression>,
1236        limit_method: &str,
1237        context: &ExpressionContext,
1238    ) -> BackendResult {
1239        write!(self.out, "{NAMESPACE}::min(uint(")?;
1240        self.put_expression(index, context, true)?;
1241        write!(self.out, "), ")?;
1242        self.put_expression(image, context, false)?;
1243        write!(self.out, ".{limit_method}() - 1)")?;
1244        Ok(())
1245    }
1246
1247    fn put_restricted_texel_address(
1248        &mut self,
1249        image: Handle<crate::Expression>,
1250        address: &TexelAddress,
1251        context: &ExpressionContext,
1252    ) -> BackendResult {
1253        // Write the coordinate.
1254        write!(self.out, "{NAMESPACE}::min(")?;
1255        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1256        write!(self.out, ", ")?;
1257        self.put_image_coordinate_limits(image, address.level, context)?;
1258        write!(self.out, ")")?;
1259
1260        // Write the array index, if present.
1261        if let Some(array_index) = address.array_index {
1262            write!(self.out, ", ")?;
1263            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1264        }
1265
1266        // Write the sample index, if present.
1267        if let Some(sample) = address.sample {
1268            write!(self.out, ", ")?;
1269            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1270        }
1271
1272        // The level of detail should be clamped and cached by
1273        // `put_cache_restricted_level`, so we don't need to clamp it here.
1274        if let Some(level) = address.level {
1275            write!(self.out, ", ")?;
1276            self.put_level_of_detail(level, context)?;
1277        }
1278
1279        Ok(())
1280    }
1281
1282    /// Write an expression that is true if the given image access is in bounds.
1283    fn put_image_access_bounds_check(
1284        &mut self,
1285        image: Handle<crate::Expression>,
1286        address: &TexelAddress,
1287        context: &ExpressionContext,
1288    ) -> BackendResult {
1289        let mut conjunction = "";
1290
1291        // First, check the level of detail. Only if that is in bounds can we
1292        // use it to find the appropriate bounds for the coordinates.
1293        let level = if let Some(level) = address.level {
1294            write!(self.out, "uint(")?;
1295            self.put_level_of_detail(level, context)?;
1296            write!(self.out, ") < ")?;
1297            self.put_expression(image, context, true)?;
1298            write!(self.out, ".get_num_mip_levels()")?;
1299            conjunction = " && ";
1300            Some(level)
1301        } else {
1302            None
1303        };
1304
1305        // Check sample index, if present.
1306        if let Some(sample) = address.sample {
1307            write!(self.out, "uint(")?;
1308            self.put_expression(sample, context, true)?;
1309            write!(self.out, ") < ")?;
1310            self.put_expression(image, context, true)?;
1311            write!(self.out, ".get_num_samples()")?;
1312            conjunction = " && ";
1313        }
1314
1315        // Check array index, if present.
1316        if let Some(array_index) = address.array_index {
1317            write!(self.out, "{conjunction}uint(")?;
1318            self.put_expression(array_index, context, true)?;
1319            write!(self.out, ") < ")?;
1320            self.put_expression(image, context, true)?;
1321            write!(self.out, ".get_array_size()")?;
1322            conjunction = " && ";
1323        }
1324
1325        // Finally, check if the coordinates are within bounds.
1326        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1327            crate::TypeInner::Vector { .. } => true,
1328            _ => false,
1329        };
1330        write!(self.out, "{conjunction}")?;
1331        if coord_is_vector {
1332            write!(self.out, "{NAMESPACE}::all(")?;
1333        }
1334        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1335        write!(self.out, " < ")?;
1336        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1337        if coord_is_vector {
1338            write!(self.out, ")")?;
1339        }
1340
1341        Ok(())
1342    }
1343
1344    fn put_image_load(
1345        &mut self,
1346        load: Handle<crate::Expression>,
1347        image: Handle<crate::Expression>,
1348        mut address: TexelAddress,
1349        context: &ExpressionContext,
1350    ) -> BackendResult {
1351        if let crate::TypeInner::Image {
1352            class: crate::ImageClass::External,
1353            ..
1354        } = *context.resolve_type(image)
1355        {
1356            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1357            self.put_expression(image, context, true)?;
1358            write!(self.out, ", ")?;
1359            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1360            write!(self.out, ")")?;
1361            return Ok(());
1362        }
1363
1364        match context.policies.image_load {
1365            proc::BoundsCheckPolicy::Restrict => {
1366                // Use the cached restricted level of detail, if any. Omit the
1367                // level altogether for 1D textures.
1368                if address.level.is_some() {
1369                    address.level = if context.image_needs_lod(image) {
1370                        Some(LevelOfDetail::Restricted(load))
1371                    } else {
1372                        None
1373                    }
1374                }
1375
1376                self.put_expression(image, context, false)?;
1377                write!(self.out, ".read(")?;
1378                self.put_restricted_texel_address(image, &address, context)?;
1379                write!(self.out, ")")?;
1380            }
1381            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1382                write!(self.out, "(")?;
1383                self.put_image_access_bounds_check(image, &address, context)?;
1384                write!(self.out, " ? ")?;
1385                self.put_unchecked_image_load(image, &address, context)?;
1386                write!(self.out, ": DefaultConstructible())")?;
1387            }
1388            proc::BoundsCheckPolicy::Unchecked => {
1389                self.put_unchecked_image_load(image, &address, context)?;
1390            }
1391        }
1392
1393        Ok(())
1394    }
1395
1396    fn put_unchecked_image_load(
1397        &mut self,
1398        image: Handle<crate::Expression>,
1399        address: &TexelAddress,
1400        context: &ExpressionContext,
1401    ) -> BackendResult {
1402        self.put_expression(image, context, false)?;
1403        write!(self.out, ".read(")?;
1404        // coordinates in IR are int, but Metal expects uint
1405        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1406        if let Some(expr) = address.array_index {
1407            write!(self.out, ", ")?;
1408            self.put_expression(expr, context, true)?;
1409        }
1410        if let Some(sample) = address.sample {
1411            write!(self.out, ", ")?;
1412            self.put_expression(sample, context, true)?;
1413        }
1414        if let Some(level) = address.level {
1415            if context.image_needs_lod(image) {
1416                write!(self.out, ", ")?;
1417                self.put_level_of_detail(level, context)?;
1418            }
1419        }
1420        write!(self.out, ")")?;
1421
1422        Ok(())
1423    }
1424
1425    fn put_image_atomic(
1426        &mut self,
1427        level: back::Level,
1428        image: Handle<crate::Expression>,
1429        address: &TexelAddress,
1430        fun: crate::AtomicFunction,
1431        value: Handle<crate::Expression>,
1432        context: &StatementContext,
1433    ) -> BackendResult {
1434        write!(self.out, "{level}")?;
1435        self.put_expression(image, &context.expression, false)?;
1436        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1437            fun.to_msl_64_bit()?
1438        } else {
1439            fun.to_msl()
1440        };
1441        write!(self.out, ".atomic_{op}(")?;
1442        // coordinates in IR are int, but Metal expects uint
1443        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1444        write!(self.out, ", ")?;
1445        self.put_expression(value, &context.expression, true)?;
1446        writeln!(self.out, ");")?;
1447
1448        Ok(())
1449    }
1450
1451    fn put_image_store(
1452        &mut self,
1453        level: back::Level,
1454        image: Handle<crate::Expression>,
1455        address: &TexelAddress,
1456        value: Handle<crate::Expression>,
1457        context: &StatementContext,
1458    ) -> BackendResult {
1459        write!(self.out, "{level}")?;
1460        self.put_expression(image, &context.expression, false)?;
1461        write!(self.out, ".write(")?;
1462        self.put_expression(value, &context.expression, true)?;
1463        write!(self.out, ", ")?;
1464        // coordinates in IR are int, but Metal expects uint
1465        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1466        if let Some(expr) = address.array_index {
1467            write!(self.out, ", ")?;
1468            self.put_expression(expr, &context.expression, true)?;
1469        }
1470        writeln!(self.out, ");")?;
1471
1472        Ok(())
1473    }
1474
1475    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1476    ///
1477    /// The 'maximum valid index' is simply one less than the array's length.
1478    ///
1479    /// This emits an expression of the form `a / b`, so the caller must
1480    /// parenthesize its output if it will be applying operators of higher
1481    /// precedence.
1482    ///
1483    /// `handle` must be the handle of a global variable whose final member is a
1484    /// dynamically sized array.
1485    fn put_dynamic_array_max_index(
1486        &mut self,
1487        handle: Handle<crate::GlobalVariable>,
1488        context: &ExpressionContext,
1489    ) -> BackendResult {
1490        let global = &context.module.global_variables[handle];
1491        let (offset, array_ty) = match context.module.types[global.ty].inner {
1492            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1493                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1494                None => return Err(Error::GenericValidation("Struct has no members".into())),
1495            },
1496            crate::TypeInner::Array {
1497                size: crate::ArraySize::Dynamic,
1498                ..
1499            } => (0, global.ty),
1500            ref ty => {
1501                return Err(Error::GenericValidation(format!(
1502                    "Expected type with dynamic array, got {ty:?}"
1503                )))
1504            }
1505        };
1506
1507        let (size, stride) = match context.module.types[array_ty].inner {
1508            crate::TypeInner::Array { base, stride, .. } => (
1509                context.module.types[base]
1510                    .inner
1511                    .size(context.module.to_ctx()),
1512                stride,
1513            ),
1514            ref ty => {
1515                return Err(Error::GenericValidation(format!(
1516                    "Expected array type, got {ty:?}"
1517                )))
1518            }
1519        };
1520
1521        // When the stride length is larger than the size, the final element's stride of
1522        // bytes would have padding following the value. But the buffer size in
1523        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1524        // enough to hold the actual values' bytes.
1525        //
1526        // So subtract off the size to get a byte size that falls at the start or within
1527        // the final element. Then divide by the stride size, to get one less than the
1528        // length, and then add one. This works even if the buffer size does include the
1529        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1530        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1531        // rules, together with draw-time validation when `minBindingSize` is zero,
1532        // prevent that.
1533        write!(
1534            self.out,
1535            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1536            member = ArraySizeMember(handle),
1537            offset = offset,
1538            size = size,
1539            stride = stride,
1540        )?;
1541        Ok(())
1542    }
1543
1544    /// Emit code for the arithmetic expression of the dot product.
1545    ///
1546    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1547    /// an index. It writes out the expression for the vector component at that index.
1548    fn put_dot_product<T: Copy>(
1549        &mut self,
1550        arg: T,
1551        arg1: T,
1552        size: usize,
1553        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1554    ) -> BackendResult {
1555        // Write parentheses around the dot product expression to prevent operators
1556        // with different precedences from applying earlier.
1557        write!(self.out, "(")?;
1558
1559        // Cycle through all the components of the vector
1560        for index in 0..size {
1561            // Write the addition to the previous product
1562            // This will print an extra '+' at the beginning but that is fine in msl
1563            write!(self.out, " + ")?;
1564            extractor(self, arg, index)?;
1565            write!(self.out, " * ")?;
1566            extractor(self, arg1, index)?;
1567        }
1568
1569        write!(self.out, ")")?;
1570        Ok(())
1571    }
1572
1573    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1574    fn put_pack4x8(
1575        &mut self,
1576        arg: Handle<crate::Expression>,
1577        context: &ExpressionContext<'_>,
1578        was_signed: bool,
1579        clamp_bounds: Option<(&str, &str)>,
1580    ) -> Result<(), Error> {
1581        let write_arg = |this: &mut Self| -> BackendResult {
1582            if let Some((min, max)) = clamp_bounds {
1583                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1584                write!(this.out, "{NAMESPACE}::clamp(")?;
1585                this.put_expression(arg, context, true)?;
1586                write!(this.out, ", {min}, {max})")?;
1587            } else {
1588                this.put_expression(arg, context, true)?;
1589            }
1590            Ok(())
1591        };
1592
1593        if context.lang_version >= (2, 1) {
1594            let packed_type = if was_signed {
1595                "packed_char4"
1596            } else {
1597                "packed_uchar4"
1598            };
1599            // Metal uses little endian byte order, which matches what WGSL expects here.
1600            write!(self.out, "as_type<uint>({packed_type}(")?;
1601            write_arg(self)?;
1602            write!(self.out, "))")?;
1603        } else {
1604            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1605            if was_signed {
1606                write!(self.out, "uint(")?;
1607            }
1608            write!(self.out, "(")?;
1609            write_arg(self)?;
1610            write!(self.out, "[0] & 0xFF) | ((")?;
1611            write_arg(self)?;
1612            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1613            write_arg(self)?;
1614            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1615            write_arg(self)?;
1616            write!(self.out, "[3] & 0xFF) << 24)")?;
1617            if was_signed {
1618                write!(self.out, ")")?;
1619            }
1620        }
1621
1622        Ok(())
1623    }
1624
1625    /// Emit code for the isign expression.
1626    ///
1627    fn put_isign(
1628        &mut self,
1629        arg: Handle<crate::Expression>,
1630        context: &ExpressionContext,
1631    ) -> BackendResult {
1632        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1633        let scalar = context
1634            .resolve_type(arg)
1635            .scalar()
1636            .expect("put_isign should only be called for args which have an integer scalar type")
1637            .to_msl_name();
1638        match context.resolve_type(arg) {
1639            &crate::TypeInner::Vector { size, .. } => {
1640                let size = common::vector_size_str(size);
1641                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1642            }
1643            _ => {
1644                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1645            }
1646        }
1647        write!(self.out, ", (")?;
1648        self.put_expression(arg, context, true)?;
1649        write!(self.out, " > 0)), {scalar}(0), (")?;
1650        self.put_expression(arg, context, true)?;
1651        write!(self.out, " == 0))")?;
1652        Ok(())
1653    }
1654
1655    fn put_const_expression(
1656        &mut self,
1657        expr_handle: Handle<crate::Expression>,
1658        module: &crate::Module,
1659        mod_info: &valid::ModuleInfo,
1660        arena: &crate::Arena<crate::Expression>,
1661    ) -> BackendResult {
1662        self.put_possibly_const_expression(
1663            expr_handle,
1664            arena,
1665            module,
1666            mod_info,
1667            &(module, mod_info),
1668            |&(_, mod_info), expr| &mod_info[expr],
1669            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1670        )
1671    }
1672
1673    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1674        match literal {
1675            crate::Literal::F64(_) => {
1676                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1677            }
1678            crate::Literal::F16(value) => {
1679                if value.is_infinite() {
1680                    let sign = if value.is_sign_negative() { "-" } else { "" };
1681                    write!(self.out, "{sign}INFINITY")?;
1682                } else if value.is_nan() {
1683                    write!(self.out, "NAN")?;
1684                } else {
1685                    let suffix = if value.fract() == f16::from_f32(0.0) {
1686                        ".0h"
1687                    } else {
1688                        "h"
1689                    };
1690                    write!(self.out, "{value}{suffix}")?;
1691                }
1692            }
1693            crate::Literal::F32(value) => {
1694                if value.is_infinite() {
1695                    let sign = if value.is_sign_negative() { "-" } else { "" };
1696                    write!(self.out, "{sign}INFINITY")?;
1697                } else if value.is_nan() {
1698                    write!(self.out, "NAN")?;
1699                } else {
1700                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1701                    write!(self.out, "{value}{suffix}")?;
1702                }
1703            }
1704            crate::Literal::U32(value) => {
1705                write!(self.out, "{value}u")?;
1706            }
1707            crate::Literal::I32(value) => {
1708                // `-2147483648` is parsed as unary negation of positive 2147483648.
1709                // 2147483648 is too large for int32_t meaning the expression gets
1710                // promoted to a int64_t which is not our intention. Avoid this by instead
1711                // using `-2147483647 - 1`.
1712                if value == i32::MIN {
1713                    write!(self.out, "({} - 1)", value + 1)?;
1714                } else {
1715                    write!(self.out, "{value}")?;
1716                }
1717            }
1718            crate::Literal::U64(value) => {
1719                write!(self.out, "{value}uL")?;
1720            }
1721            crate::Literal::I64(value) => {
1722                // `-9223372036854775808` is parsed as unary negation of positive
1723                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1724                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1725                // value to `-9223372036854775808`. Which would then be negated, possibly
1726                // causing undefined behaviour. Avoid this by instead using
1727                // `-9223372036854775808L - 1L`.
1728                if value == i64::MIN {
1729                    write!(self.out, "({}L - 1L)", value + 1)?;
1730                } else {
1731                    write!(self.out, "{value}L")?;
1732                }
1733            }
1734            crate::Literal::Bool(value) => {
1735                write!(self.out, "{value}")?;
1736            }
1737            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1738                return Err(Error::GenericValidation(
1739                    "Unsupported abstract literal".into(),
1740                ));
1741            }
1742        }
1743        Ok(())
1744    }
1745
1746    #[allow(clippy::too_many_arguments)]
1747    fn put_possibly_const_expression<C, I, E>(
1748        &mut self,
1749        expr_handle: Handle<crate::Expression>,
1750        expressions: &crate::Arena<crate::Expression>,
1751        module: &crate::Module,
1752        mod_info: &valid::ModuleInfo,
1753        ctx: &C,
1754        get_expr_ty: I,
1755        put_expression: E,
1756    ) -> BackendResult
1757    where
1758        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1759        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1760    {
1761        match expressions[expr_handle] {
1762            crate::Expression::Literal(literal) => {
1763                self.put_literal(literal)?;
1764            }
1765            crate::Expression::Constant(handle) => {
1766                let constant = &module.constants[handle];
1767                if constant.name.is_some() {
1768                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1769                } else {
1770                    self.put_const_expression(
1771                        constant.init,
1772                        module,
1773                        mod_info,
1774                        &module.global_expressions,
1775                    )?;
1776                }
1777            }
1778            crate::Expression::ZeroValue(ty) => {
1779                let ty_name = TypeContext {
1780                    handle: ty,
1781                    gctx: module.to_ctx(),
1782                    names: &self.names,
1783                    access: crate::StorageAccess::empty(),
1784                    first_time: false,
1785                };
1786                write!(self.out, "{ty_name} {{}}")?;
1787            }
1788            crate::Expression::Compose { ty, ref components } => {
1789                let ty_name = TypeContext {
1790                    handle: ty,
1791                    gctx: module.to_ctx(),
1792                    names: &self.names,
1793                    access: crate::StorageAccess::empty(),
1794                    first_time: false,
1795                };
1796                write!(self.out, "{ty_name}")?;
1797                match module.types[ty].inner {
1798                    crate::TypeInner::Scalar(_)
1799                    | crate::TypeInner::Vector { .. }
1800                    | crate::TypeInner::Matrix { .. } => {
1801                        self.put_call_parameters_impl(
1802                            components.iter().copied(),
1803                            ctx,
1804                            put_expression,
1805                        )?;
1806                    }
1807                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1808                        write!(self.out, " {{")?;
1809                        for (index, &component) in components.iter().enumerate() {
1810                            if index != 0 {
1811                                write!(self.out, ", ")?;
1812                            }
1813                            // insert padding initialization, if needed
1814                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1815                                write!(self.out, "{{}}, ")?;
1816                            }
1817                            put_expression(self, ctx, component)?;
1818                        }
1819                        write!(self.out, "}}")?;
1820                    }
1821                    _ => return Err(Error::UnsupportedCompose(ty)),
1822                }
1823            }
1824            crate::Expression::Splat { size, value } => {
1825                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1826                    crate::TypeInner::Scalar(scalar) => scalar,
1827                    ref ty => {
1828                        return Err(Error::GenericValidation(format!(
1829                            "Expected splat value type must be a scalar, got {ty:?}",
1830                        )))
1831                    }
1832                };
1833                put_numeric_type(&mut self.out, scalar, &[size])?;
1834                write!(self.out, "(")?;
1835                put_expression(self, ctx, value)?;
1836                write!(self.out, ")")?;
1837            }
1838            _ => {
1839                return Err(Error::Override);
1840            }
1841        }
1842
1843        Ok(())
1844    }
1845
1846    /// Emit code for the expression `expr_handle`.
1847    ///
1848    /// The `is_scoped` argument is true if the surrounding operators have the
1849    /// precedence of the comma operator, or lower. So, for example:
1850    ///
1851    /// - Pass `true` for `is_scoped` when writing function arguments, an
1852    ///   expression statement, an initializer expression, or anything already
1853    ///   wrapped in parenthesis.
1854    ///
1855    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1856    ///   almost anything else.
1857    fn put_expression(
1858        &mut self,
1859        expr_handle: Handle<crate::Expression>,
1860        context: &ExpressionContext,
1861        is_scoped: bool,
1862    ) -> BackendResult {
1863        // Add to the set in order to track the stack size.
1864        #[cfg(test)]
1865        self.put_expression_stack_pointers
1866            .insert(ptr::from_ref(&expr_handle).cast());
1867
1868        if let Some(name) = self.named_expressions.get(&expr_handle) {
1869            write!(self.out, "{name}")?;
1870            return Ok(());
1871        }
1872
1873        let expression = &context.function.expressions[expr_handle];
1874        match *expression {
1875            crate::Expression::Literal(_)
1876            | crate::Expression::Constant(_)
1877            | crate::Expression::ZeroValue(_)
1878            | crate::Expression::Compose { .. }
1879            | crate::Expression::Splat { .. } => {
1880                self.put_possibly_const_expression(
1881                    expr_handle,
1882                    &context.function.expressions,
1883                    context.module,
1884                    context.mod_info,
1885                    context,
1886                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1887                    |writer, context, expr| writer.put_expression(expr, context, true),
1888                )?;
1889            }
1890            crate::Expression::Override(_) => return Err(Error::Override),
1891            crate::Expression::Access { base, .. }
1892            | crate::Expression::AccessIndex { base, .. } => {
1893                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1894                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1895                // access chain at a time, recursing back through `put_expression` only
1896                // for index expressions and the base object, we will never see intermediate
1897                // `Access` or `AccessIndex` expressions here.
1898                let policy = context.choose_bounds_check_policy(base);
1899                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1900                    && self.put_bounds_checks(
1901                        expr_handle,
1902                        context,
1903                        back::Level(0),
1904                        if is_scoped { "" } else { "(" },
1905                    )?
1906                {
1907                    write!(self.out, " ? ")?;
1908                    self.put_access_chain(expr_handle, policy, context)?;
1909                    write!(self.out, " : ")?;
1910
1911                    if context.resolve_type(base).pointer_space().is_some() {
1912                        // We can't just use `DefaultConstructible` if this is a pointer.
1913                        // Instead, we create a dummy local variable to serve as pointer
1914                        // target if the access is out of bounds.
1915                        let result_ty = context.info[expr_handle]
1916                            .ty
1917                            .inner_with(&context.module.types)
1918                            .pointer_base_type();
1919                        let result_ty_handle = match result_ty {
1920                            Some(TypeResolution::Handle(handle)) => handle,
1921                            Some(TypeResolution::Value(_)) => {
1922                                // As long as the result of a pointer access expression is
1923                                // passed to a function or stored in a let binding, the
1924                                // type will be in the arena. If additional uses of
1925                                // pointers become valid, this assumption might no longer
1926                                // hold. Note that the LHS of a load or store doesn't
1927                                // take this path -- there is dedicated code in `put_load`
1928                                // and `put_store`.
1929                                unreachable!(
1930                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1931                                );
1932                            }
1933                            None => {
1934                                unreachable!(
1935                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1936                                )
1937                            }
1938                        };
1939                        let name_key =
1940                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
1941                        self.out.write_str(&self.names[&name_key])?;
1942                    } else {
1943                        write!(self.out, "DefaultConstructible()")?;
1944                    }
1945
1946                    if !is_scoped {
1947                        write!(self.out, ")")?;
1948                    }
1949                } else {
1950                    self.put_access_chain(expr_handle, policy, context)?;
1951                }
1952            }
1953            crate::Expression::Swizzle {
1954                size,
1955                vector,
1956                pattern,
1957            } => {
1958                self.put_wrapped_expression_for_packed_vec3_access(
1959                    vector,
1960                    context,
1961                    false,
1962                    &Self::put_expression,
1963                )?;
1964                write!(self.out, ".")?;
1965                for &sc in pattern[..size as usize].iter() {
1966                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1967                }
1968            }
1969            crate::Expression::FunctionArgument(index) => {
1970                let name_key = match context.origin {
1971                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1972                    FunctionOrigin::EntryPoint(ep_index) => {
1973                        NameKey::EntryPointArgument(ep_index, index)
1974                    }
1975                };
1976                let name = &self.names[&name_key];
1977                write!(self.out, "{name}")?;
1978            }
1979            crate::Expression::GlobalVariable(handle) => {
1980                let name = &self.names[&NameKey::GlobalVariable(handle)];
1981                write!(self.out, "{name}")?;
1982            }
1983            crate::Expression::LocalVariable(handle) => {
1984                let name_key = NameKey::local(context.origin, handle);
1985                let name = &self.names[&name_key];
1986                write!(self.out, "{name}")?;
1987            }
1988            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1989            crate::Expression::ImageSample {
1990                coordinate,
1991                image,
1992                sampler,
1993                clamp_to_edge: true,
1994                gather: None,
1995                array_index: None,
1996                offset: None,
1997                level: crate::SampleLevel::Zero,
1998                depth_ref: None,
1999            } => {
2000                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2001                self.put_expression(image, context, true)?;
2002                write!(self.out, ", ")?;
2003                self.put_expression(sampler, context, true)?;
2004                write!(self.out, ", ")?;
2005                self.put_expression(coordinate, context, true)?;
2006                write!(self.out, ")")?;
2007            }
2008            crate::Expression::ImageSample {
2009                image,
2010                sampler,
2011                gather,
2012                coordinate,
2013                array_index,
2014                offset,
2015                level,
2016                depth_ref,
2017                clamp_to_edge,
2018            } => {
2019                if clamp_to_edge {
2020                    return Err(Error::GenericValidation(
2021                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2022                    ));
2023                }
2024
2025                let main_op = match gather {
2026                    Some(_) => "gather",
2027                    None => "sample",
2028                };
2029                let comparison_op = match depth_ref {
2030                    Some(_) => "_compare",
2031                    None => "",
2032                };
2033                self.put_expression(image, context, false)?;
2034                write!(self.out, ".{main_op}{comparison_op}(")?;
2035                self.put_expression(sampler, context, true)?;
2036                write!(self.out, ", ")?;
2037                self.put_expression(coordinate, context, true)?;
2038                if let Some(expr) = array_index {
2039                    write!(self.out, ", ")?;
2040                    self.put_expression(expr, context, true)?;
2041                }
2042                if let Some(dref) = depth_ref {
2043                    write!(self.out, ", ")?;
2044                    self.put_expression(dref, context, true)?;
2045                }
2046
2047                self.put_image_sample_level(image, level, context)?;
2048
2049                if let Some(offset) = offset {
2050                    write!(self.out, ", ")?;
2051                    self.put_expression(offset, context, true)?;
2052                }
2053
2054                match gather {
2055                    None | Some(crate::SwizzleComponent::X) => {}
2056                    Some(component) => {
2057                        let is_cube_map = match *context.resolve_type(image) {
2058                            crate::TypeInner::Image {
2059                                dim: crate::ImageDimension::Cube,
2060                                ..
2061                            } => true,
2062                            _ => false,
2063                        };
2064                        // Offset always comes before the gather, except
2065                        // in cube maps where it's not applicable
2066                        if offset.is_none() && !is_cube_map {
2067                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2068                        }
2069                        let letter = back::COMPONENTS[component as usize];
2070                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2071                    }
2072                }
2073                write!(self.out, ")")?;
2074            }
2075            crate::Expression::ImageLoad {
2076                image,
2077                coordinate,
2078                array_index,
2079                sample,
2080                level,
2081            } => {
2082                let address = TexelAddress {
2083                    coordinate,
2084                    array_index,
2085                    sample,
2086                    level: level.map(LevelOfDetail::Direct),
2087                };
2088                self.put_image_load(expr_handle, image, address, context)?;
2089            }
2090            //Note: for all the queries, the signed integers are expected,
2091            // so a conversion is needed.
2092            crate::Expression::ImageQuery { image, query } => match query {
2093                crate::ImageQuery::Size { level } => {
2094                    self.put_image_size_query(
2095                        image,
2096                        level.map(LevelOfDetail::Direct),
2097                        crate::ScalarKind::Uint,
2098                        context,
2099                    )?;
2100                }
2101                crate::ImageQuery::NumLevels => {
2102                    self.put_expression(image, context, false)?;
2103                    write!(self.out, ".get_num_mip_levels()")?;
2104                }
2105                crate::ImageQuery::NumLayers => {
2106                    self.put_expression(image, context, false)?;
2107                    write!(self.out, ".get_array_size()")?;
2108                }
2109                crate::ImageQuery::NumSamples => {
2110                    self.put_expression(image, context, false)?;
2111                    write!(self.out, ".get_num_samples()")?;
2112                }
2113            },
2114            crate::Expression::Unary { op, expr } => {
2115                let op_str = match op {
2116                    crate::UnaryOperator::Negate => {
2117                        match context.resolve_type(expr).scalar_kind() {
2118                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2119                            _ => "-",
2120                        }
2121                    }
2122                    crate::UnaryOperator::LogicalNot => "!",
2123                    crate::UnaryOperator::BitwiseNot => "~",
2124                };
2125                write!(self.out, "{op_str}(")?;
2126                self.put_expression(expr, context, false)?;
2127                write!(self.out, ")")?;
2128            }
2129            crate::Expression::Binary { op, left, right } => {
2130                let kind = context
2131                    .resolve_type(left)
2132                    .scalar_kind()
2133                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2134
2135                if op == crate::BinaryOperator::Divide
2136                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2137                {
2138                    write!(self.out, "{DIV_FUNCTION}(")?;
2139                    self.put_expression(left, context, true)?;
2140                    write!(self.out, ", ")?;
2141                    self.put_expression(right, context, true)?;
2142                    write!(self.out, ")")?;
2143                } else if op == crate::BinaryOperator::Modulo
2144                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2145                {
2146                    write!(self.out, "{MOD_FUNCTION}(")?;
2147                    self.put_expression(left, context, true)?;
2148                    write!(self.out, ", ")?;
2149                    self.put_expression(right, context, true)?;
2150                    write!(self.out, ")")?;
2151                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2152                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2153                    //
2154                    // float:
2155                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2156                    write!(self.out, "{NAMESPACE}::fmod(")?;
2157                    self.put_expression(left, context, true)?;
2158                    write!(self.out, ", ")?;
2159                    self.put_expression(right, context, true)?;
2160                    write!(self.out, ")")?;
2161                } else if (op == crate::BinaryOperator::Add
2162                    || op == crate::BinaryOperator::Subtract
2163                    || op == crate::BinaryOperator::Multiply)
2164                    && kind == crate::ScalarKind::Sint
2165                {
2166                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2167                        crate::TypeInner::Scalar(scalar) => {
2168                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2169                                kind: crate::ScalarKind::Uint,
2170                                ..scalar
2171                            }))
2172                        }
2173                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2174                            size,
2175                            scalar: crate::Scalar {
2176                                kind: crate::ScalarKind::Uint,
2177                                ..scalar
2178                            },
2179                        }),
2180                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2181                    };
2182
2183                    // Avoid undefined behaviour due to overflowing signed
2184                    // integer arithmetic. Cast the operands to unsigned prior
2185                    // to performing the operation, then cast the result back
2186                    // to signed.
2187                    self.put_bitcasted_expression(
2188                        context.resolve_type(expr_handle),
2189                        context,
2190                        &|writer, context, is_scoped| {
2191                            writer.put_binop(
2192                                op,
2193                                left,
2194                                right,
2195                                context,
2196                                is_scoped,
2197                                &|writer, expr, context, _is_scoped| {
2198                                    writer.put_bitcasted_expression(
2199                                        &to_unsigned(context.resolve_type(expr))?,
2200                                        context,
2201                                        &|writer, context, is_scoped| {
2202                                            writer.put_expression(expr, context, is_scoped)
2203                                        },
2204                                    )
2205                                },
2206                            )
2207                        },
2208                    )?;
2209                } else {
2210                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2211                }
2212            }
2213            crate::Expression::Select {
2214                condition,
2215                accept,
2216                reject,
2217            } => match *context.resolve_type(condition) {
2218                crate::TypeInner::Scalar(crate::Scalar {
2219                    kind: crate::ScalarKind::Bool,
2220                    ..
2221                }) => {
2222                    if !is_scoped {
2223                        write!(self.out, "(")?;
2224                    }
2225                    self.put_expression(condition, context, false)?;
2226                    write!(self.out, " ? ")?;
2227                    self.put_expression(accept, context, false)?;
2228                    write!(self.out, " : ")?;
2229                    self.put_expression(reject, context, false)?;
2230                    if !is_scoped {
2231                        write!(self.out, ")")?;
2232                    }
2233                }
2234                crate::TypeInner::Vector {
2235                    scalar:
2236                        crate::Scalar {
2237                            kind: crate::ScalarKind::Bool,
2238                            ..
2239                        },
2240                    ..
2241                } => {
2242                    write!(self.out, "{NAMESPACE}::select(")?;
2243                    self.put_expression(reject, context, true)?;
2244                    write!(self.out, ", ")?;
2245                    self.put_expression(accept, context, true)?;
2246                    write!(self.out, ", ")?;
2247                    self.put_expression(condition, context, true)?;
2248                    write!(self.out, ")")?;
2249                }
2250                ref ty => {
2251                    return Err(Error::GenericValidation(format!(
2252                        "Expected select condition to be a non-bool type, got {ty:?}",
2253                    )))
2254                }
2255            },
2256            crate::Expression::Derivative { axis, expr, .. } => {
2257                use crate::DerivativeAxis as Axis;
2258                let op = match axis {
2259                    Axis::X => "dfdx",
2260                    Axis::Y => "dfdy",
2261                    Axis::Width => "fwidth",
2262                };
2263                write!(self.out, "{NAMESPACE}::{op}")?;
2264                self.put_call_parameters(iter::once(expr), context)?;
2265            }
2266            crate::Expression::Relational { fun, argument } => {
2267                let op = match fun {
2268                    crate::RelationalFunction::Any => "any",
2269                    crate::RelationalFunction::All => "all",
2270                    crate::RelationalFunction::IsNan => "isnan",
2271                    crate::RelationalFunction::IsInf => "isinf",
2272                };
2273                write!(self.out, "{NAMESPACE}::{op}")?;
2274                self.put_call_parameters(iter::once(argument), context)?;
2275            }
2276            crate::Expression::Math {
2277                fun,
2278                arg,
2279                arg1,
2280                arg2,
2281                arg3,
2282            } => {
2283                use crate::MathFunction as Mf;
2284
2285                let arg_type = context.resolve_type(arg);
2286                let scalar_argument = match arg_type {
2287                    &crate::TypeInner::Scalar(_) => true,
2288                    _ => false,
2289                };
2290
2291                let fun_name = match fun {
2292                    // comparison
2293                    Mf::Abs => "abs",
2294                    Mf::Min => "min",
2295                    Mf::Max => "max",
2296                    Mf::Clamp => "clamp",
2297                    Mf::Saturate => "saturate",
2298                    // trigonometry
2299                    Mf::Cos => "cos",
2300                    Mf::Cosh => "cosh",
2301                    Mf::Sin => "sin",
2302                    Mf::Sinh => "sinh",
2303                    Mf::Tan => "tan",
2304                    Mf::Tanh => "tanh",
2305                    Mf::Acos => "acos",
2306                    Mf::Asin => "asin",
2307                    Mf::Atan => "atan",
2308                    Mf::Atan2 => "atan2",
2309                    Mf::Asinh => "asinh",
2310                    Mf::Acosh => "acosh",
2311                    Mf::Atanh => "atanh",
2312                    Mf::Radians => "",
2313                    Mf::Degrees => "",
2314                    // decomposition
2315                    Mf::Ceil => "ceil",
2316                    Mf::Floor => "floor",
2317                    Mf::Round => "rint",
2318                    Mf::Fract => "fract",
2319                    Mf::Trunc => "trunc",
2320                    Mf::Modf => MODF_FUNCTION,
2321                    Mf::Frexp => FREXP_FUNCTION,
2322                    Mf::Ldexp => "ldexp",
2323                    // exponent
2324                    Mf::Exp => "exp",
2325                    Mf::Exp2 => "exp2",
2326                    Mf::Log => "log",
2327                    Mf::Log2 => "log2",
2328                    Mf::Pow => "pow",
2329                    // geometry
2330                    Mf::Dot => match *context.resolve_type(arg) {
2331                        crate::TypeInner::Vector {
2332                            scalar:
2333                                crate::Scalar {
2334                                    kind: crate::ScalarKind::Float,
2335                                    ..
2336                                },
2337                            ..
2338                        } => "dot",
2339                        crate::TypeInner::Vector { size, .. } => {
2340                            return self.put_dot_product(
2341                                arg,
2342                                arg1.unwrap(),
2343                                size as usize,
2344                                |writer, arg, index| {
2345                                    // Write the vector expression; this expression is marked to be
2346                                    // cached so unless it can't be cached (for example, it's a Constant)
2347                                    // it shouldn't produce large expressions.
2348                                    writer.put_expression(arg, context, true)?;
2349                                    // Access the current component on the vector.
2350                                    write!(writer.out, ".{}", back::COMPONENTS[index])?;
2351                                    Ok(())
2352                                },
2353                            );
2354                        }
2355                        _ => unreachable!(
2356                            "Correct TypeInner for dot product should be already validated"
2357                        ),
2358                    },
2359                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2360                        if context.lang_version >= (2, 1) {
2361                            // Write potentially optimizable code using `packed_(u?)char4`.
2362                            // The two function arguments were already reinterpreted as packed (signed
2363                            // or unsigned) chars in `Self::put_block`.
2364                            let packed_type = match fun {
2365                                Mf::Dot4I8Packed => "packed_char4",
2366                                Mf::Dot4U8Packed => "packed_uchar4",
2367                                _ => unreachable!(),
2368                            };
2369
2370                            return self.put_dot_product(
2371                                Reinterpreted::new(packed_type, arg),
2372                                Reinterpreted::new(packed_type, arg1.unwrap()),
2373                                4,
2374                                |writer, arg, index| {
2375                                    // MSL implicitly promotes these (signed or unsigned) chars to
2376                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2377                                    write!(writer.out, "{arg}[{index}]")?;
2378                                    Ok(())
2379                                },
2380                            );
2381                        } else {
2382                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2383                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2384                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2385                            let conversion = match fun {
2386                                Mf::Dot4I8Packed => "int",
2387                                Mf::Dot4U8Packed => "",
2388                                _ => unreachable!(),
2389                            };
2390
2391                            return self.put_dot_product(
2392                                arg,
2393                                arg1.unwrap(),
2394                                4,
2395                                |writer, arg, index| {
2396                                    write!(writer.out, "({conversion}(")?;
2397                                    writer.put_expression(arg, context, true)?;
2398                                    if index == 3 {
2399                                        write!(writer.out, ") >> 24)")?;
2400                                    } else {
2401                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2402                                    }
2403                                    Ok(())
2404                                },
2405                            );
2406                        }
2407                    }
2408                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2409                    Mf::Cross => "cross",
2410                    Mf::Distance => "distance",
2411                    Mf::Length if scalar_argument => "abs",
2412                    Mf::Length => "length",
2413                    Mf::Normalize => "normalize",
2414                    Mf::FaceForward => "faceforward",
2415                    Mf::Reflect => "reflect",
2416                    Mf::Refract => "refract",
2417                    // computational
2418                    Mf::Sign => match arg_type.scalar_kind() {
2419                        Some(crate::ScalarKind::Sint) => {
2420                            return self.put_isign(arg, context);
2421                        }
2422                        _ => "sign",
2423                    },
2424                    Mf::Fma => "fma",
2425                    Mf::Mix => "mix",
2426                    Mf::Step => "step",
2427                    Mf::SmoothStep => "smoothstep",
2428                    Mf::Sqrt => "sqrt",
2429                    Mf::InverseSqrt => "rsqrt",
2430                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2431                    Mf::Transpose => "transpose",
2432                    Mf::Determinant => "determinant",
2433                    Mf::QuantizeToF16 => "",
2434                    // bits
2435                    Mf::CountTrailingZeros => "ctz",
2436                    Mf::CountLeadingZeros => "clz",
2437                    Mf::CountOneBits => "popcount",
2438                    Mf::ReverseBits => "reverse_bits",
2439                    Mf::ExtractBits => "",
2440                    Mf::InsertBits => "",
2441                    Mf::FirstTrailingBit => "",
2442                    Mf::FirstLeadingBit => "",
2443                    // data packing
2444                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2445                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2446                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2447                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2448                    Mf::Pack2x16float => "",
2449                    Mf::Pack4xI8 => "",
2450                    Mf::Pack4xU8 => "",
2451                    Mf::Pack4xI8Clamp => "",
2452                    Mf::Pack4xU8Clamp => "",
2453                    // data unpacking
2454                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2455                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2456                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2457                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2458                    Mf::Unpack2x16float => "",
2459                    Mf::Unpack4xI8 => "",
2460                    Mf::Unpack4xU8 => "",
2461                };
2462
2463                match fun {
2464                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2465                        // reverse_bits is listed as requiring MSL 2.1 but that
2466                        // is a copy/paste error. Looking at previous snapshots
2467                        // on web.archive.org it's present in MSL 1.2.
2468                        //
2469                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2470                        // also talks about MSL 1.2 adding "New integer
2471                        // functions to extract, insert, and reverse bits, as
2472                        // described in Integer Functions."
2473                        if context.lang_version < (1, 2) {
2474                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2475                        }
2476                    }
2477                    _ => {}
2478                }
2479
2480                match fun {
2481                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2482                        write!(self.out, "{ABS_FUNCTION}(")?;
2483                        self.put_expression(arg, context, true)?;
2484                        write!(self.out, ")")?;
2485                    }
2486                    Mf::Distance if scalar_argument => {
2487                        write!(self.out, "{NAMESPACE}::abs(")?;
2488                        self.put_expression(arg, context, false)?;
2489                        write!(self.out, " - ")?;
2490                        self.put_expression(arg1.unwrap(), context, false)?;
2491                        write!(self.out, ")")?;
2492                    }
2493                    Mf::FirstTrailingBit => {
2494                        let scalar = context.resolve_type(arg).scalar().unwrap();
2495                        let constant = scalar.width * 8 + 1;
2496
2497                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2498                        self.put_expression(arg, context, true)?;
2499                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2500                    }
2501                    Mf::FirstLeadingBit => {
2502                        let inner = context.resolve_type(arg);
2503                        let scalar = inner.scalar().unwrap();
2504                        let constant = scalar.width * 8 - 1;
2505
2506                        write!(
2507                            self.out,
2508                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2509                        )?;
2510
2511                        if scalar.kind == crate::ScalarKind::Sint {
2512                            write!(self.out, "{NAMESPACE}::select(")?;
2513                            self.put_expression(arg, context, true)?;
2514                            write!(self.out, ", ~")?;
2515                            self.put_expression(arg, context, true)?;
2516                            write!(self.out, ", ")?;
2517                            self.put_expression(arg, context, true)?;
2518                            write!(self.out, " < 0)")?;
2519                        } else {
2520                            self.put_expression(arg, context, true)?;
2521                        }
2522
2523                        write!(self.out, "), ")?;
2524
2525                        // or metal will complain that select is ambiguous
2526                        match *inner {
2527                            crate::TypeInner::Vector { size, scalar } => {
2528                                let size = common::vector_size_str(size);
2529                                let name = scalar.to_msl_name();
2530                                write!(self.out, "{name}{size}")?;
2531                            }
2532                            crate::TypeInner::Scalar(scalar) => {
2533                                let name = scalar.to_msl_name();
2534                                write!(self.out, "{name}")?;
2535                            }
2536                            _ => (),
2537                        }
2538
2539                        write!(self.out, "(-1), ")?;
2540                        self.put_expression(arg, context, true)?;
2541                        write!(self.out, " == 0 || ")?;
2542                        self.put_expression(arg, context, true)?;
2543                        write!(self.out, " == -1)")?;
2544                    }
2545                    Mf::Unpack2x16float => {
2546                        write!(self.out, "float2(as_type<half2>(")?;
2547                        self.put_expression(arg, context, false)?;
2548                        write!(self.out, "))")?;
2549                    }
2550                    Mf::Pack2x16float => {
2551                        write!(self.out, "as_type<uint>(half2(")?;
2552                        self.put_expression(arg, context, false)?;
2553                        write!(self.out, "))")?;
2554                    }
2555                    Mf::ExtractBits => {
2556                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2557                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2558                        // will return out-of-spec values if the extracted range is not within the bit width.
2559                        //
2560                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2561                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2562                        //
2563                        // w = sizeof(x) * 8
2564                        // o = min(offset, w)
2565                        // tmp = w - o
2566                        // c = min(count, tmp)
2567                        //
2568                        // bitfieldExtract(x, o, c)
2569                        //
2570                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2571
2572                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2573
2574                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2575                        self.put_expression(arg, context, true)?;
2576                        write!(self.out, ", {NAMESPACE}::min(")?;
2577                        self.put_expression(arg1.unwrap(), context, true)?;
2578                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2579                        self.put_expression(arg2.unwrap(), context, true)?;
2580                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2581                        self.put_expression(arg1.unwrap(), context, true)?;
2582                        write!(self.out, ", {scalar_bits}u)))")?;
2583                    }
2584                    Mf::InsertBits => {
2585                        // The behavior of InsertBits has the same issue as ExtractBits.
2586                        //
2587                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2588
2589                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2590
2591                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2592                        self.put_expression(arg, context, true)?;
2593                        write!(self.out, ", ")?;
2594                        self.put_expression(arg1.unwrap(), context, true)?;
2595                        write!(self.out, ", {NAMESPACE}::min(")?;
2596                        self.put_expression(arg2.unwrap(), context, true)?;
2597                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2598                        self.put_expression(arg3.unwrap(), context, true)?;
2599                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2600                        self.put_expression(arg2.unwrap(), context, true)?;
2601                        write!(self.out, ", {scalar_bits}u)))")?;
2602                    }
2603                    Mf::Radians => {
2604                        write!(self.out, "((")?;
2605                        self.put_expression(arg, context, false)?;
2606                        write!(self.out, ") * 0.017453292519943295474)")?;
2607                    }
2608                    Mf::Degrees => {
2609                        write!(self.out, "((")?;
2610                        self.put_expression(arg, context, false)?;
2611                        write!(self.out, ") * 57.295779513082322865)")?;
2612                    }
2613                    Mf::Modf | Mf::Frexp => {
2614                        write!(self.out, "{fun_name}")?;
2615                        self.put_call_parameters(iter::once(arg), context)?;
2616                    }
2617                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2618                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2619                    Mf::Pack4xI8Clamp => {
2620                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2621                    }
2622                    Mf::Pack4xU8Clamp => {
2623                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2624                    }
2625                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2626                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2627                            "u"
2628                        } else {
2629                            ""
2630                        };
2631
2632                        if context.lang_version >= (2, 1) {
2633                            // Metal uses little endian byte order, which matches what WGSL expects here.
2634                            write!(
2635                                self.out,
2636                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2637                            )?;
2638                            self.put_expression(arg, context, true)?;
2639                            write!(self.out, "))")?;
2640                        } else {
2641                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2642                            write!(self.out, "({sign_prefix}int4(")?;
2643                            self.put_expression(arg, context, true)?;
2644                            write!(self.out, ", ")?;
2645                            self.put_expression(arg, context, true)?;
2646                            write!(self.out, " >> 8, ")?;
2647                            self.put_expression(arg, context, true)?;
2648                            write!(self.out, " >> 16, ")?;
2649                            self.put_expression(arg, context, true)?;
2650                            write!(self.out, " >> 24) << 24 >> 24)")?;
2651                        }
2652                    }
2653                    Mf::QuantizeToF16 => {
2654                        match *context.resolve_type(arg) {
2655                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2656                            crate::TypeInner::Vector { size, .. } => write!(
2657                                self.out,
2658                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2659                                size = common::vector_size_str(size),
2660                            )?,
2661                            _ => unreachable!(
2662                                "Correct TypeInner for QuantizeToF16 should be already validated"
2663                            ),
2664                        };
2665
2666                        self.put_expression(arg, context, true)?;
2667                        write!(self.out, "))")?;
2668                    }
2669                    _ => {
2670                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2671                        self.put_call_parameters(
2672                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2673                            context,
2674                        )?;
2675                    }
2676                }
2677            }
2678            crate::Expression::As {
2679                expr,
2680                kind,
2681                convert,
2682            } => match *context.resolve_type(expr) {
2683                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2684                    if src.kind == crate::ScalarKind::Float
2685                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2686                        && convert.is_some()
2687                    {
2688                        // Use helper functions for float to int casts in order to avoid
2689                        // undefined behaviour when value is out of range for the target
2690                        // type.
2691                        let fun_name = match (kind, convert) {
2692                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2693                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2694                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2695                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2696                            _ => unreachable!(),
2697                        };
2698                        write!(self.out, "{fun_name}(")?;
2699                        self.put_expression(expr, context, true)?;
2700                        write!(self.out, ")")?;
2701                    } else {
2702                        let target_scalar = crate::Scalar {
2703                            kind,
2704                            width: convert.unwrap_or(src.width),
2705                        };
2706                        let op = match convert {
2707                            Some(_) => "static_cast",
2708                            None => "as_type",
2709                        };
2710                        write!(self.out, "{op}<")?;
2711                        match *context.resolve_type(expr) {
2712                            crate::TypeInner::Vector { size, .. } => {
2713                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2714                            }
2715                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2716                        };
2717                        write!(self.out, ">(")?;
2718                        self.put_expression(expr, context, true)?;
2719                        write!(self.out, ")")?;
2720                    }
2721                }
2722                crate::TypeInner::Matrix {
2723                    columns,
2724                    rows,
2725                    scalar,
2726                } => {
2727                    let target_scalar = crate::Scalar {
2728                        kind,
2729                        width: convert.unwrap_or(scalar.width),
2730                    };
2731                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2732                    write!(self.out, "(")?;
2733                    self.put_expression(expr, context, true)?;
2734                    write!(self.out, ")")?;
2735                }
2736                ref ty => {
2737                    return Err(Error::GenericValidation(format!(
2738                        "Unsupported type for As: {ty:?}"
2739                    )))
2740                }
2741            },
2742            // has to be a named expression
2743            crate::Expression::CallResult(_)
2744            | crate::Expression::AtomicResult { .. }
2745            | crate::Expression::WorkGroupUniformLoadResult { .. }
2746            | crate::Expression::SubgroupBallotResult
2747            | crate::Expression::SubgroupOperationResult { .. }
2748            | crate::Expression::RayQueryProceedResult => {
2749                unreachable!()
2750            }
2751            crate::Expression::ArrayLength(expr) => {
2752                // Find the global to which the array belongs.
2753                let global = match context.function.expressions[expr] {
2754                    crate::Expression::AccessIndex { base, .. } => {
2755                        match context.function.expressions[base] {
2756                            crate::Expression::GlobalVariable(handle) => handle,
2757                            ref ex => {
2758                                return Err(Error::GenericValidation(format!(
2759                                    "Expected global variable in AccessIndex, got {ex:?}"
2760                                )))
2761                            }
2762                        }
2763                    }
2764                    crate::Expression::GlobalVariable(handle) => handle,
2765                    ref ex => {
2766                        return Err(Error::GenericValidation(format!(
2767                            "Unexpected expression in ArrayLength, got {ex:?}"
2768                        )))
2769                    }
2770                };
2771
2772                if !is_scoped {
2773                    write!(self.out, "(")?;
2774                }
2775                write!(self.out, "1 + ")?;
2776                self.put_dynamic_array_max_index(global, context)?;
2777                if !is_scoped {
2778                    write!(self.out, ")")?;
2779                }
2780            }
2781            crate::Expression::RayQueryVertexPositions { .. } => {
2782                unimplemented!()
2783            }
2784            crate::Expression::RayQueryGetIntersection {
2785                query,
2786                committed: _,
2787            } => {
2788                if context.lang_version < (2, 4) {
2789                    return Err(Error::UnsupportedRayTracing);
2790                }
2791
2792                let ty = context.module.special_types.ray_intersection.unwrap();
2793                let type_name = &self.names[&NameKey::Type(ty)];
2794                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2795                self.put_expression(query, context, true)?;
2796                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2797                let fields = [
2798                    "distance",
2799                    "user_instance_id", // req Metal 2.4
2800                    "instance_id",
2801                    "", // SBT offset
2802                    "geometry_id",
2803                    "primitive_id",
2804                    "triangle_barycentric_coord",
2805                    "triangle_front_facing",
2806                    "",                          // padding
2807                    "object_to_world_transform", // req Metal 2.4
2808                    "world_to_object_transform", // req Metal 2.4
2809                ];
2810                for field in fields {
2811                    write!(self.out, ", ")?;
2812                    if field.is_empty() {
2813                        write!(self.out, "{{}}")?;
2814                    } else {
2815                        self.put_expression(query, context, true)?;
2816                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2817                    }
2818                }
2819                write!(self.out, "}}")?;
2820            }
2821        }
2822        Ok(())
2823    }
2824
2825    /// Emits code for a binary operation, using the provided callback to emit
2826    /// the left and right operands.
2827    fn put_binop<F>(
2828        &mut self,
2829        op: crate::BinaryOperator,
2830        left: Handle<crate::Expression>,
2831        right: Handle<crate::Expression>,
2832        context: &ExpressionContext,
2833        is_scoped: bool,
2834        put_expression: &F,
2835    ) -> BackendResult
2836    where
2837        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2838    {
2839        let op_str = back::binary_operation_str(op);
2840
2841        if !is_scoped {
2842            write!(self.out, "(")?;
2843        }
2844
2845        // Cast packed vector if necessary
2846        // Packed vector - matrix multiplications are not supported in MSL
2847        if op == crate::BinaryOperator::Multiply
2848            && matches!(
2849                context.resolve_type(right),
2850                &crate::TypeInner::Matrix { .. }
2851            )
2852        {
2853            self.put_wrapped_expression_for_packed_vec3_access(
2854                left,
2855                context,
2856                false,
2857                put_expression,
2858            )?;
2859        } else {
2860            put_expression(self, left, context, false)?;
2861        }
2862
2863        write!(self.out, " {op_str} ")?;
2864
2865        // See comment above
2866        if op == crate::BinaryOperator::Multiply
2867            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2868        {
2869            self.put_wrapped_expression_for_packed_vec3_access(
2870                right,
2871                context,
2872                false,
2873                put_expression,
2874            )?;
2875        } else {
2876            put_expression(self, right, context, false)?;
2877        }
2878
2879        if !is_scoped {
2880            write!(self.out, ")")?;
2881        }
2882
2883        Ok(())
2884    }
2885
2886    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2887    fn put_wrapped_expression_for_packed_vec3_access<F>(
2888        &mut self,
2889        expr_handle: Handle<crate::Expression>,
2890        context: &ExpressionContext,
2891        is_scoped: bool,
2892        put_expression: &F,
2893    ) -> BackendResult
2894    where
2895        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2896    {
2897        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2898            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2899            put_expression(self, expr_handle, context, is_scoped)?;
2900            write!(self.out, ")")?;
2901        } else {
2902            put_expression(self, expr_handle, context, is_scoped)?;
2903        }
2904        Ok(())
2905    }
2906
2907    /// Emits code for an expression using the provided callback, wrapping the
2908    /// result in a bitcast to the type `cast_to`.
2909    fn put_bitcasted_expression<F>(
2910        &mut self,
2911        cast_to: &crate::TypeInner,
2912        context: &ExpressionContext,
2913        put_expression: &F,
2914    ) -> BackendResult
2915    where
2916        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2917    {
2918        write!(self.out, "as_type<")?;
2919        match *cast_to {
2920            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2921            crate::TypeInner::Vector { size, scalar } => {
2922                put_numeric_type(&mut self.out, scalar, &[size])?
2923            }
2924            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2925        };
2926        write!(self.out, ">(")?;
2927        put_expression(self, context, true)?;
2928        write!(self.out, ")")?;
2929
2930        Ok(())
2931    }
2932
2933    /// Write a `GuardedIndex` as a Metal expression.
2934    fn put_index(
2935        &mut self,
2936        index: index::GuardedIndex,
2937        context: &ExpressionContext,
2938        is_scoped: bool,
2939    ) -> BackendResult {
2940        match index {
2941            index::GuardedIndex::Expression(expr) => {
2942                self.put_expression(expr, context, is_scoped)?
2943            }
2944            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2945        }
2946        Ok(())
2947    }
2948
2949    /// Emit an index bounds check condition for `chain`, if required.
2950    ///
2951    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
2952    /// operating either on a pointer to a value, or on a value directly. If we cannot
2953    /// statically determine that all indexing operations in `chain` are within
2954    /// bounds, then write a conditional expression to check them dynamically,
2955    /// and return true. All accesses in the chain are checked by the generated
2956    /// expression.
2957    ///
2958    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
2959    ///
2960    /// The text written is of the form:
2961    ///
2962    /// ```ignore
2963    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
2964    /// ```
2965    ///
2966    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
2967    /// statements, presumably these arguments start an indented `if` statement; for
2968    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
2969    /// expression. In either case, what is written is not a complete syntactic structure
2970    /// in its own right, and the caller will have to finish it off if we return `true`.
2971    ///
2972    /// If no expression is written, return false.
2973    ///
2974    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
2975    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
2976    /// [`Store`]: crate::Statement::Store
2977    /// [`Load`]: crate::Expression::Load
2978    #[allow(unused_variables)]
2979    fn put_bounds_checks(
2980        &mut self,
2981        chain: Handle<crate::Expression>,
2982        context: &ExpressionContext,
2983        level: back::Level,
2984        prefix: &'static str,
2985    ) -> Result<bool, Error> {
2986        let mut check_written = false;
2987
2988        // Iterate over the access chain, handling each required bounds check.
2989        for item in context.bounds_check_iter(chain) {
2990            let BoundsCheck {
2991                base,
2992                index,
2993                length,
2994            } = item;
2995
2996            if check_written {
2997                write!(self.out, " && ")?;
2998            } else {
2999                write!(self.out, "{level}{prefix}")?;
3000                check_written = true;
3001            }
3002
3003            // Check that the index falls within bounds. Do this with a single
3004            // comparison, by casting the index to `uint` first, so that negative
3005            // indices become large positive values.
3006            write!(self.out, "uint(")?;
3007            self.put_index(index, context, true)?;
3008            self.out.write_str(") < ")?;
3009            match length {
3010                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3011                index::IndexableLength::Dynamic => {
3012                    let global = context.function.originating_global(base).ok_or_else(|| {
3013                        Error::GenericValidation("Could not find originating global".into())
3014                    })?;
3015                    write!(self.out, "1 + ")?;
3016                    self.put_dynamic_array_max_index(global, context)?
3017                }
3018            }
3019        }
3020
3021        Ok(check_written)
3022    }
3023
3024    /// Write the access chain `chain`.
3025    ///
3026    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3027    /// operating either on a pointer to a value, or on a value directly.
3028    ///
3029    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3030    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3031    /// that must be handled in the caller.
3032    ///
3033    /// Handle the entire chain, recursing back into `put_expression` only for index
3034    /// expressions and the base expression that originates the pointer or composite value
3035    /// being accessed. This allows `put_expression` to assume that any `Access` or
3036    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3037    /// `ReadZeroSkipWrite` checks.
3038    ///
3039    /// [`Access`]: crate::Expression::Access
3040    /// [`AccessIndex`]: crate::Expression::AccessIndex
3041    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3042    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3043    fn put_access_chain(
3044        &mut self,
3045        chain: Handle<crate::Expression>,
3046        policy: index::BoundsCheckPolicy,
3047        context: &ExpressionContext,
3048    ) -> BackendResult {
3049        match context.function.expressions[chain] {
3050            crate::Expression::Access { base, index } => {
3051                let mut base_ty = context.resolve_type(base);
3052
3053                // Look through any pointers to see what we're really indexing.
3054                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3055                    base_ty = &context.module.types[base].inner;
3056                }
3057
3058                self.put_subscripted_access_chain(
3059                    base,
3060                    base_ty,
3061                    index::GuardedIndex::Expression(index),
3062                    policy,
3063                    context,
3064                )?;
3065            }
3066            crate::Expression::AccessIndex { base, index } => {
3067                let base_resolution = &context.info[base].ty;
3068                let mut base_ty = base_resolution.inner_with(&context.module.types);
3069                let mut base_ty_handle = base_resolution.handle();
3070
3071                // Look through any pointers to see what we're really indexing.
3072                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3073                    base_ty = &context.module.types[base].inner;
3074                    base_ty_handle = Some(base);
3075                }
3076
3077                // Handle structs and anything else that can use `.x` syntax here, so
3078                // `put_subscripted_access_chain` won't have to handle the absurd case of
3079                // indexing a struct with an expression.
3080                match *base_ty {
3081                    crate::TypeInner::Struct { .. } => {
3082                        let base_ty = base_ty_handle.unwrap();
3083                        self.put_access_chain(base, policy, context)?;
3084                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3085                        write!(self.out, ".{name}")?;
3086                    }
3087                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3088                        self.put_access_chain(base, policy, context)?;
3089                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3090                        // however array indexing is
3091                        if context.get_packed_vec_kind(base).is_some() {
3092                            write!(self.out, "[{index}]")?;
3093                        } else {
3094                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3095                        }
3096                    }
3097                    _ => {
3098                        self.put_subscripted_access_chain(
3099                            base,
3100                            base_ty,
3101                            index::GuardedIndex::Known(index),
3102                            policy,
3103                            context,
3104                        )?;
3105                    }
3106                }
3107            }
3108            _ => self.put_expression(chain, context, false)?,
3109        }
3110
3111        Ok(())
3112    }
3113
3114    /// Write a `[]`-style access of `base` by `index`.
3115    ///
3116    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3117    /// values within bounds.
3118    ///
3119    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3120    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3121    /// removed. Our callers often already have this handy.
3122    ///
3123    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3124    /// referencing vector components by name.
3125    ///
3126    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3127    /// [`Array`]: crate::TypeInner::Array
3128    /// [`Vector`]: crate::TypeInner::Vector
3129    /// [`Pointer`]: crate::TypeInner::Pointer
3130    fn put_subscripted_access_chain(
3131        &mut self,
3132        base: Handle<crate::Expression>,
3133        base_ty: &crate::TypeInner,
3134        index: index::GuardedIndex,
3135        policy: index::BoundsCheckPolicy,
3136        context: &ExpressionContext,
3137    ) -> BackendResult {
3138        let accessing_wrapped_array = match *base_ty {
3139            crate::TypeInner::Array {
3140                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3141                ..
3142            } => true,
3143            _ => false,
3144        };
3145        let accessing_wrapped_binding_array =
3146            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3147
3148        self.put_access_chain(base, policy, context)?;
3149        if accessing_wrapped_array {
3150            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3151        }
3152        write!(self.out, "[")?;
3153
3154        // Decide whether this index needs to be clamped to fall within range.
3155        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3156            context.access_needs_check(base, index)
3157        } else {
3158            None
3159        };
3160        if let Some(limit) = restriction_needed {
3161            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3162            self.put_index(index, context, true)?;
3163            write!(self.out, "), ")?;
3164            match limit {
3165                index::IndexableLength::Known(limit) => {
3166                    write!(self.out, "{}u", limit - 1)?;
3167                }
3168                index::IndexableLength::Dynamic => {
3169                    let global = context.function.originating_global(base).ok_or_else(|| {
3170                        Error::GenericValidation("Could not find originating global".into())
3171                    })?;
3172                    self.put_dynamic_array_max_index(global, context)?;
3173                }
3174            }
3175            write!(self.out, ")")?;
3176        } else {
3177            self.put_index(index, context, true)?;
3178        }
3179
3180        write!(self.out, "]")?;
3181
3182        if accessing_wrapped_binding_array {
3183            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3184        }
3185
3186        Ok(())
3187    }
3188
3189    fn put_load(
3190        &mut self,
3191        pointer: Handle<crate::Expression>,
3192        context: &ExpressionContext,
3193        is_scoped: bool,
3194    ) -> BackendResult {
3195        // Since access chains never cross between address spaces, we can just
3196        // check the index bounds check policy once at the top.
3197        let policy = context.choose_bounds_check_policy(pointer);
3198        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3199            && self.put_bounds_checks(
3200                pointer,
3201                context,
3202                back::Level(0),
3203                if is_scoped { "" } else { "(" },
3204            )?
3205        {
3206            write!(self.out, " ? ")?;
3207            self.put_unchecked_load(pointer, policy, context)?;
3208            write!(self.out, " : DefaultConstructible()")?;
3209
3210            if !is_scoped {
3211                write!(self.out, ")")?;
3212            }
3213        } else {
3214            self.put_unchecked_load(pointer, policy, context)?;
3215        }
3216
3217        Ok(())
3218    }
3219
3220    fn put_unchecked_load(
3221        &mut self,
3222        pointer: Handle<crate::Expression>,
3223        policy: index::BoundsCheckPolicy,
3224        context: &ExpressionContext,
3225    ) -> BackendResult {
3226        let is_atomic_pointer = context
3227            .resolve_type(pointer)
3228            .is_atomic_pointer(&context.module.types);
3229
3230        if is_atomic_pointer {
3231            write!(
3232                self.out,
3233                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3234            )?;
3235            self.put_access_chain(pointer, policy, context)?;
3236            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3237        } else {
3238            // We don't do any dereferencing with `*` here as pointer arguments to functions
3239            // are done by `&` references and not `*` pointers. These do not need to be
3240            // dereferenced.
3241            self.put_access_chain(pointer, policy, context)?;
3242        }
3243
3244        Ok(())
3245    }
3246
3247    fn put_return_value(
3248        &mut self,
3249        level: back::Level,
3250        expr_handle: Handle<crate::Expression>,
3251        result_struct: Option<&str>,
3252        context: &ExpressionContext,
3253    ) -> BackendResult {
3254        match result_struct {
3255            Some(struct_name) => {
3256                let mut has_point_size = false;
3257                let result_ty = context.function.result.as_ref().unwrap().ty;
3258                match context.module.types[result_ty].inner {
3259                    crate::TypeInner::Struct { ref members, .. } => {
3260                        let tmp = "_tmp";
3261                        write!(self.out, "{level}const auto {tmp} = ")?;
3262                        self.put_expression(expr_handle, context, true)?;
3263                        writeln!(self.out, ";")?;
3264                        write!(self.out, "{level}return {struct_name} {{")?;
3265
3266                        let mut is_first = true;
3267
3268                        for (index, member) in members.iter().enumerate() {
3269                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3270                                member.binding
3271                            {
3272                                has_point_size = true;
3273                                if !context.pipeline_options.allow_and_force_point_size {
3274                                    continue;
3275                                }
3276                            }
3277
3278                            let comma = if is_first { "" } else { "," };
3279                            is_first = false;
3280                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3281                            // HACK: we are forcefully deduplicating the expression here
3282                            // to convert from a wrapped struct to a raw array, e.g.
3283                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3284                            if let crate::TypeInner::Array {
3285                                size: crate::ArraySize::Constant(size),
3286                                ..
3287                            } = context.module.types[member.ty].inner
3288                            {
3289                                write!(self.out, "{comma} {{")?;
3290                                for j in 0..size.get() {
3291                                    if j != 0 {
3292                                        write!(self.out, ",")?;
3293                                    }
3294                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3295                                }
3296                                write!(self.out, "}}")?;
3297                            } else {
3298                                write!(self.out, "{comma} {tmp}.{name}")?;
3299                            }
3300                        }
3301                    }
3302                    _ => {
3303                        write!(self.out, "{level}return {struct_name} {{ ")?;
3304                        self.put_expression(expr_handle, context, true)?;
3305                    }
3306                }
3307
3308                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3309                    let stage = context.module.entry_points[ep_index as usize].stage;
3310                    if context.pipeline_options.allow_and_force_point_size
3311                        && stage == crate::ShaderStage::Vertex
3312                        && !has_point_size
3313                    {
3314                        // point size was injected and comes last
3315                        write!(self.out, ", 1.0")?;
3316                    }
3317                }
3318                write!(self.out, " }}")?;
3319            }
3320            None => {
3321                write!(self.out, "{level}return ")?;
3322                self.put_expression(expr_handle, context, true)?;
3323            }
3324        }
3325        writeln!(self.out, ";")?;
3326        Ok(())
3327    }
3328
3329    /// Helper method used to find which expressions of a given function require baking
3330    ///
3331    /// # Notes
3332    /// This function overwrites the contents of `self.need_bake_expressions`
3333    fn update_expressions_to_bake(
3334        &mut self,
3335        func: &crate::Function,
3336        info: &valid::FunctionInfo,
3337        context: &ExpressionContext,
3338    ) {
3339        use crate::Expression;
3340        self.need_bake_expressions.clear();
3341
3342        for (expr_handle, expr) in func.expressions.iter() {
3343            // Expressions whose reference count is above the
3344            // threshold should always be stored in temporaries.
3345            let expr_info = &info[expr_handle];
3346            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3347            if min_ref_count <= expr_info.ref_count {
3348                self.need_bake_expressions.insert(expr_handle);
3349            } else {
3350                match expr_info.ty {
3351                    // force ray desc to be baked: it's used multiple times internally
3352                    TypeResolution::Handle(h)
3353                        if Some(h) == context.module.special_types.ray_desc =>
3354                    {
3355                        self.need_bake_expressions.insert(expr_handle);
3356                    }
3357                    _ => {}
3358                }
3359            }
3360
3361            if let Expression::Math {
3362                fun,
3363                arg,
3364                arg1,
3365                arg2,
3366                ..
3367            } = *expr
3368            {
3369                match fun {
3370                    crate::MathFunction::Dot => {
3371                        // WGSL's `dot` function works on any `vecN` type, but Metal's only
3372                        // works on floating-point vectors, so we emit inline code for
3373                        // integer vector `dot` calls. But that code uses each argument `N`
3374                        // times, once for each component (see `put_dot_product`), so to
3375                        // avoid duplicated evaluation, we must bake integer operands.
3376
3377                        // check what kind of product this is depending
3378                        // on the resolve type of the Dot function itself
3379                        let inner = context.resolve_type(expr_handle);
3380                        if let crate::TypeInner::Scalar(scalar) = *inner {
3381                            match scalar.kind {
3382                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3383                                    self.need_bake_expressions.insert(arg);
3384                                    self.need_bake_expressions.insert(arg1.unwrap());
3385                                }
3386                                _ => {}
3387                            }
3388                        }
3389                    }
3390                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3391                        self.need_bake_expressions.insert(arg);
3392                        self.need_bake_expressions.insert(arg1.unwrap());
3393                    }
3394                    crate::MathFunction::FirstLeadingBit => {
3395                        self.need_bake_expressions.insert(arg);
3396                    }
3397                    crate::MathFunction::Pack4xI8
3398                    | crate::MathFunction::Pack4xU8
3399                    | crate::MathFunction::Pack4xI8Clamp
3400                    | crate::MathFunction::Pack4xU8Clamp
3401                    | crate::MathFunction::Unpack4xI8
3402                    | crate::MathFunction::Unpack4xU8 => {
3403                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3404                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3405                        if context.lang_version < (2, 1) {
3406                            self.need_bake_expressions.insert(arg);
3407                        }
3408                    }
3409                    crate::MathFunction::ExtractBits => {
3410                        // Only argument 1 is re-used.
3411                        self.need_bake_expressions.insert(arg1.unwrap());
3412                    }
3413                    crate::MathFunction::InsertBits => {
3414                        // Only argument 2 is re-used.
3415                        self.need_bake_expressions.insert(arg2.unwrap());
3416                    }
3417                    crate::MathFunction::Sign => {
3418                        // WGSL's `sign` function works also on signed ints, but Metal's only
3419                        // works on floating points, so we emit inline code for integer `sign`
3420                        // calls. But that code uses each argument 2 times (see `put_isign`),
3421                        // so to avoid duplicated evaluation, we must bake the argument.
3422                        let inner = context.resolve_type(expr_handle);
3423                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3424                            self.need_bake_expressions.insert(arg);
3425                        }
3426                    }
3427                    _ => {}
3428                }
3429            }
3430        }
3431    }
3432
3433    fn start_baking_expression(
3434        &mut self,
3435        handle: Handle<crate::Expression>,
3436        context: &ExpressionContext,
3437        name: &str,
3438    ) -> BackendResult {
3439        match context.info[handle].ty {
3440            TypeResolution::Handle(ty_handle) => {
3441                let ty_name = TypeContext {
3442                    handle: ty_handle,
3443                    gctx: context.module.to_ctx(),
3444                    names: &self.names,
3445                    access: crate::StorageAccess::empty(),
3446                    first_time: false,
3447                };
3448                write!(self.out, "{ty_name}")?;
3449            }
3450            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3451                put_numeric_type(&mut self.out, scalar, &[])?;
3452            }
3453            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3454                put_numeric_type(&mut self.out, scalar, &[size])?;
3455            }
3456            TypeResolution::Value(crate::TypeInner::Matrix {
3457                columns,
3458                rows,
3459                scalar,
3460            }) => {
3461                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3462            }
3463            TypeResolution::Value(ref other) => {
3464                log::warn!("Type {other:?} isn't a known local"); //TEMP!
3465                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3466            }
3467        }
3468
3469        //TODO: figure out the naming scheme that wouldn't collide with user names.
3470        write!(self.out, " {name} = ")?;
3471
3472        Ok(())
3473    }
3474
3475    /// Cache a clamped level of detail value, if necessary.
3476    ///
3477    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3478    /// properly clamped level of detail value both in the access itself, and
3479    /// for fetching the size of the requested MIP level, needed to clamp the
3480    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3481    /// it in a temporary variable, as part of the [`Emit`] statement covering
3482    /// the [`ImageLoad`] expression.
3483    ///
3484    /// [`ImageLoad`]: crate::Expression::ImageLoad
3485    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3486    /// [`Emit`]: crate::Statement::Emit
3487    fn put_cache_restricted_level(
3488        &mut self,
3489        load: Handle<crate::Expression>,
3490        image: Handle<crate::Expression>,
3491        mip_level: Option<Handle<crate::Expression>>,
3492        indent: back::Level,
3493        context: &StatementContext,
3494    ) -> BackendResult {
3495        // Does this image access actually require (or even permit) a
3496        // level-of-detail, and does the policy require us to restrict it?
3497        let level_of_detail = match mip_level {
3498            Some(level) => level,
3499            None => return Ok(()),
3500        };
3501
3502        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3503            || !context.expression.image_needs_lod(image)
3504        {
3505            return Ok(());
3506        }
3507
3508        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3509        self.put_restricted_scalar_image_index(
3510            image,
3511            level_of_detail,
3512            "get_num_mip_levels",
3513            &context.expression,
3514        )?;
3515        writeln!(self.out, ";")?;
3516
3517        Ok(())
3518    }
3519
3520    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3521    ///
3522    /// Caches the results in temporary variables (whose names are derived from
3523    /// the original variable names). This caching avoids the need to redo the
3524    /// casting for each vector component when emitting the dot product.
3525    fn put_casting_to_packed_chars(
3526        &mut self,
3527        fun: crate::MathFunction,
3528        arg0: Handle<crate::Expression>,
3529        arg1: Handle<crate::Expression>,
3530        indent: back::Level,
3531        context: &StatementContext<'_>,
3532    ) -> Result<(), Error> {
3533        let packed_type = match fun {
3534            crate::MathFunction::Dot4I8Packed => "packed_char4",
3535            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3536            _ => unreachable!(),
3537        };
3538
3539        for arg in [arg0, arg1] {
3540            write!(
3541                self.out,
3542                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3543                Reinterpreted::new(packed_type, arg)
3544            )?;
3545            self.put_expression(arg, &context.expression, true)?;
3546            writeln!(self.out, ");")?;
3547        }
3548
3549        Ok(())
3550    }
3551
3552    fn put_block(
3553        &mut self,
3554        level: back::Level,
3555        statements: &[crate::Statement],
3556        context: &StatementContext,
3557    ) -> BackendResult {
3558        // Add to the set in order to track the stack size.
3559        #[cfg(test)]
3560        self.put_block_stack_pointers
3561            .insert(ptr::from_ref(&level).cast());
3562
3563        for statement in statements {
3564            log::trace!("statement[{}] {:?}", level.0, statement);
3565            match *statement {
3566                crate::Statement::Emit(ref range) => {
3567                    for handle in range.clone() {
3568                        use crate::MathFunction as Mf;
3569
3570                        match context.expression.function.expressions[handle] {
3571                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3572                            // may need to cache a clamped version of their level-of-detail argument.
3573                            crate::Expression::ImageLoad {
3574                                image,
3575                                level: mip_level,
3576                                ..
3577                            } => {
3578                                self.put_cache_restricted_level(
3579                                    handle, image, mip_level, level, context,
3580                                )?;
3581                            }
3582
3583                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3584                            // 2.1+ then we introduce two intermediate variables that recast the two
3585                            // arguments as packed (signed or unsigned) chars. The actual dot product
3586                            // is implemented in `Self::put_expression`, and it uses both of these
3587                            // intermediate variables multiple times. There's no danger that the
3588                            // original arguments get modified between the definition of these
3589                            // intermediate variables and the implementation of the actual dot
3590                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3591                            crate::Expression::Math {
3592                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3593                                arg,
3594                                arg1,
3595                                ..
3596                            } if context.expression.lang_version >= (2, 1) => {
3597                                self.put_casting_to_packed_chars(
3598                                    fun,
3599                                    arg,
3600                                    arg1.unwrap(),
3601                                    level,
3602                                    context,
3603                                )?;
3604                            }
3605
3606                            _ => (),
3607                        }
3608
3609                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3610                        let expr_name = if ptr_class.is_some() {
3611                            None // don't bake pointer expressions (just yet)
3612                        } else if let Some(name) =
3613                            context.expression.function.named_expressions.get(&handle)
3614                        {
3615                            // The `crate::Function::named_expressions` table holds
3616                            // expressions that should be saved in temporaries once they
3617                            // are `Emit`ted. We only add them to `self.named_expressions`
3618                            // when we reach the `Emit` that covers them, so that we don't
3619                            // try to use their names before we've actually initialized
3620                            // the temporary that holds them.
3621                            //
3622                            // Don't assume the names in `named_expressions` are unique,
3623                            // or even valid. Use the `Namer`.
3624                            Some(self.namer.call(name))
3625                        } else {
3626                            // If this expression is an index that we're going to first compare
3627                            // against a limit, and then actually use as an index, then we may
3628                            // want to cache it in a temporary, to avoid evaluating it twice.
3629                            let bake = if context.expression.guarded_indices.contains(handle) {
3630                                true
3631                            } else {
3632                                self.need_bake_expressions.contains(&handle)
3633                            };
3634
3635                            if bake {
3636                                Some(Baked(handle).to_string())
3637                            } else {
3638                                None
3639                            }
3640                        };
3641
3642                        if let Some(name) = expr_name {
3643                            write!(self.out, "{level}")?;
3644                            self.start_baking_expression(handle, &context.expression, &name)?;
3645                            self.put_expression(handle, &context.expression, true)?;
3646                            self.named_expressions.insert(handle, name);
3647                            writeln!(self.out, ";")?;
3648                        }
3649                    }
3650                }
3651                crate::Statement::Block(ref block) => {
3652                    if !block.is_empty() {
3653                        writeln!(self.out, "{level}{{")?;
3654                        self.put_block(level.next(), block, context)?;
3655                        writeln!(self.out, "{level}}}")?;
3656                    }
3657                }
3658                crate::Statement::If {
3659                    condition,
3660                    ref accept,
3661                    ref reject,
3662                } => {
3663                    write!(self.out, "{level}if (")?;
3664                    self.put_expression(condition, &context.expression, true)?;
3665                    writeln!(self.out, ") {{")?;
3666                    self.put_block(level.next(), accept, context)?;
3667                    if !reject.is_empty() {
3668                        writeln!(self.out, "{level}}} else {{")?;
3669                        self.put_block(level.next(), reject, context)?;
3670                    }
3671                    writeln!(self.out, "{level}}}")?;
3672                }
3673                crate::Statement::Switch {
3674                    selector,
3675                    ref cases,
3676                } => {
3677                    write!(self.out, "{level}switch(")?;
3678                    self.put_expression(selector, &context.expression, true)?;
3679                    writeln!(self.out, ") {{")?;
3680                    let lcase = level.next();
3681                    for case in cases.iter() {
3682                        match case.value {
3683                            crate::SwitchValue::I32(value) => {
3684                                write!(self.out, "{lcase}case {value}:")?;
3685                            }
3686                            crate::SwitchValue::U32(value) => {
3687                                write!(self.out, "{lcase}case {value}u:")?;
3688                            }
3689                            crate::SwitchValue::Default => {
3690                                write!(self.out, "{lcase}default:")?;
3691                            }
3692                        }
3693
3694                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3695                        if write_block_braces {
3696                            writeln!(self.out, " {{")?;
3697                        } else {
3698                            writeln!(self.out)?;
3699                        }
3700
3701                        self.put_block(lcase.next(), &case.body, context)?;
3702                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3703                        {
3704                            writeln!(self.out, "{}break;", lcase.next())?;
3705                        }
3706
3707                        if write_block_braces {
3708                            writeln!(self.out, "{lcase}}}")?;
3709                        }
3710                    }
3711                    writeln!(self.out, "{level}}}")?;
3712                }
3713                crate::Statement::Loop {
3714                    ref body,
3715                    ref continuing,
3716                    break_if,
3717                } => {
3718                    let force_loop_bound_statements =
3719                        self.gen_force_bounded_loop_statements(level, context);
3720                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3721                        .then(|| self.namer.call("loop_init"));
3722
3723                    if let Some((ref decl, _)) = force_loop_bound_statements {
3724                        writeln!(self.out, "{decl}")?;
3725                    }
3726                    if let Some(ref gate_name) = gate_name {
3727                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3728                    }
3729
3730                    writeln!(self.out, "{level}while(true) {{",)?;
3731                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3732                        writeln!(self.out, "{break_and_inc}")?;
3733                    }
3734                    if let Some(ref gate_name) = gate_name {
3735                        let lif = level.next();
3736                        let lcontinuing = lif.next();
3737                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3738                        self.put_block(lcontinuing, continuing, context)?;
3739                        if let Some(condition) = break_if {
3740                            write!(self.out, "{lcontinuing}if (")?;
3741                            self.put_expression(condition, &context.expression, true)?;
3742                            writeln!(self.out, ") {{")?;
3743                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3744                            writeln!(self.out, "{lcontinuing}}}")?;
3745                        }
3746                        writeln!(self.out, "{lif}}}")?;
3747                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3748                    }
3749                    self.put_block(level.next(), body, context)?;
3750
3751                    writeln!(self.out, "{level}}}")?;
3752                }
3753                crate::Statement::Break => {
3754                    writeln!(self.out, "{level}break;")?;
3755                }
3756                crate::Statement::Continue => {
3757                    writeln!(self.out, "{level}continue;")?;
3758                }
3759                crate::Statement::Return {
3760                    value: Some(expr_handle),
3761                } => {
3762                    self.put_return_value(
3763                        level,
3764                        expr_handle,
3765                        context.result_struct,
3766                        &context.expression,
3767                    )?;
3768                }
3769                crate::Statement::Return { value: None } => {
3770                    writeln!(self.out, "{level}return;")?;
3771                }
3772                crate::Statement::Kill => {
3773                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3774                }
3775                crate::Statement::ControlBarrier(flags)
3776                | crate::Statement::MemoryBarrier(flags) => {
3777                    self.write_barrier(flags, level)?;
3778                }
3779                crate::Statement::Store { pointer, value } => {
3780                    self.put_store(pointer, value, level, context)?
3781                }
3782                crate::Statement::ImageStore {
3783                    image,
3784                    coordinate,
3785                    array_index,
3786                    value,
3787                } => {
3788                    let address = TexelAddress {
3789                        coordinate,
3790                        array_index,
3791                        sample: None,
3792                        level: None,
3793                    };
3794                    self.put_image_store(level, image, &address, value, context)?
3795                }
3796                crate::Statement::Call {
3797                    function,
3798                    ref arguments,
3799                    result,
3800                } => {
3801                    write!(self.out, "{level}")?;
3802                    if let Some(expr) = result {
3803                        let name = Baked(expr).to_string();
3804                        self.start_baking_expression(expr, &context.expression, &name)?;
3805                        self.named_expressions.insert(expr, name);
3806                    }
3807                    let fun_name = &self.names[&NameKey::Function(function)];
3808                    write!(self.out, "{fun_name}(")?;
3809                    // first, write down the actual arguments
3810                    for (i, &handle) in arguments.iter().enumerate() {
3811                        if i != 0 {
3812                            write!(self.out, ", ")?;
3813                        }
3814                        self.put_expression(handle, &context.expression, true)?;
3815                    }
3816                    // follow-up with any global resources used
3817                    let mut separate = !arguments.is_empty();
3818                    let fun_info = &context.expression.mod_info[function];
3819                    let mut needs_buffer_sizes = false;
3820                    for (handle, var) in context.expression.module.global_variables.iter() {
3821                        if fun_info[handle].is_empty() {
3822                            continue;
3823                        }
3824                        if var.space.needs_pass_through() {
3825                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3826                            if separate {
3827                                write!(self.out, ", ")?;
3828                            } else {
3829                                separate = true;
3830                            }
3831                            write!(self.out, "{name}")?;
3832                        }
3833                        needs_buffer_sizes |=
3834                            needs_array_length(var.ty, &context.expression.module.types);
3835                    }
3836                    if needs_buffer_sizes {
3837                        if separate {
3838                            write!(self.out, ", ")?;
3839                        }
3840                        write!(self.out, "_buffer_sizes")?;
3841                    }
3842
3843                    // done
3844                    writeln!(self.out, ");")?;
3845                }
3846                crate::Statement::Atomic {
3847                    pointer,
3848                    ref fun,
3849                    value,
3850                    result,
3851                } => {
3852                    let context = &context.expression;
3853
3854                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
3855                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
3856                    // `Some`, we are not operating on a 64-bit value, and that if we are
3857                    // operating on a 64-bit value, `result` is `None`.
3858                    write!(self.out, "{level}")?;
3859                    let fun_key = if let Some(result) = result {
3860                        let res_name = Baked(result).to_string();
3861                        self.start_baking_expression(result, context, &res_name)?;
3862                        self.named_expressions.insert(result, res_name);
3863                        fun.to_msl()
3864                    } else if context.resolve_type(value).scalar_width() == Some(8) {
3865                        fun.to_msl_64_bit()?
3866                    } else {
3867                        fun.to_msl()
3868                    };
3869
3870                    // If the pointer we're passing to the atomic operation needs to be conditional
3871                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
3872                    // the pointer operand should be unchecked.
3873                    let policy = context.choose_bounds_check_policy(pointer);
3874                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3875                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3876
3877                    // If requested and successfully put bounds checks, continue the ternary expression.
3878                    if checked {
3879                        write!(self.out, " ? ")?;
3880                    }
3881
3882                    // Put the atomic function invocation.
3883                    match *fun {
3884                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3885                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3886                            self.put_access_chain(pointer, policy, context)?;
3887                            write!(self.out, ", ")?;
3888                            self.put_expression(cmp, context, true)?;
3889                            write!(self.out, ", ")?;
3890                            self.put_expression(value, context, true)?;
3891                            write!(self.out, ")")?;
3892                        }
3893                        _ => {
3894                            write!(
3895                                self.out,
3896                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3897                            )?;
3898                            self.put_access_chain(pointer, policy, context)?;
3899                            write!(self.out, ", ")?;
3900                            self.put_expression(value, context, true)?;
3901                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3902                        }
3903                    }
3904
3905                    // Finish the ternary expression.
3906                    if checked {
3907                        write!(self.out, " : DefaultConstructible()")?;
3908                    }
3909
3910                    // Done
3911                    writeln!(self.out, ";")?;
3912                }
3913                crate::Statement::ImageAtomic {
3914                    image,
3915                    coordinate,
3916                    array_index,
3917                    fun,
3918                    value,
3919                } => {
3920                    let address = TexelAddress {
3921                        coordinate,
3922                        array_index,
3923                        sample: None,
3924                        level: None,
3925                    };
3926                    self.put_image_atomic(level, image, &address, fun, value, context)?
3927                }
3928                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3929                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3930
3931                    write!(self.out, "{level}")?;
3932                    let name = self.namer.call("");
3933                    self.start_baking_expression(result, &context.expression, &name)?;
3934                    self.put_load(pointer, &context.expression, true)?;
3935                    self.named_expressions.insert(result, name);
3936
3937                    writeln!(self.out, ";")?;
3938                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3939                }
3940                crate::Statement::RayQuery { query, ref fun } => {
3941                    if context.expression.lang_version < (2, 4) {
3942                        return Err(Error::UnsupportedRayTracing);
3943                    }
3944
3945                    match *fun {
3946                        crate::RayQueryFunction::Initialize {
3947                            acceleration_structure,
3948                            descriptor,
3949                        } => {
3950                            //TODO: how to deal with winding?
3951                            write!(self.out, "{level}")?;
3952                            self.put_expression(query, &context.expression, true)?;
3953                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3954                            {
3955                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3956                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3957                                write!(self.out, "{level}")?;
3958                                self.put_expression(query, &context.expression, true)?;
3959                                write!(
3960                                    self.out,
3961                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3962                                )?;
3963                                self.put_expression(descriptor, &context.expression, true)?;
3964                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3965                                self.put_expression(descriptor, &context.expression, true)?;
3966                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3967                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3968                            }
3969                            {
3970                                let f_opaque = back::RayFlag::OPAQUE.bits();
3971                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3972                                write!(self.out, "{level}")?;
3973                                self.put_expression(query, &context.expression, true)?;
3974                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3975                                self.put_expression(descriptor, &context.expression, true)?;
3976                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3977                                self.put_expression(descriptor, &context.expression, true)?;
3978                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3979                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3980                            }
3981                            {
3982                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3983                                write!(self.out, "{level}")?;
3984                                self.put_expression(query, &context.expression, true)?;
3985                                write!(
3986                                    self.out,
3987                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3988                                )?;
3989                                self.put_expression(descriptor, &context.expression, true)?;
3990                                writeln!(self.out, ".flags & {flag}) != 0);")?;
3991                            }
3992
3993                            write!(self.out, "{level}")?;
3994                            self.put_expression(query, &context.expression, true)?;
3995                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3996                            self.put_expression(query, &context.expression, true)?;
3997                            write!(
3998                                self.out,
3999                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4000                            )?;
4001                            self.put_expression(descriptor, &context.expression, true)?;
4002                            write!(self.out, ".origin, ")?;
4003                            self.put_expression(descriptor, &context.expression, true)?;
4004                            write!(self.out, ".dir, ")?;
4005                            self.put_expression(descriptor, &context.expression, true)?;
4006                            write!(self.out, ".tmin, ")?;
4007                            self.put_expression(descriptor, &context.expression, true)?;
4008                            write!(self.out, ".tmax), ")?;
4009                            self.put_expression(acceleration_structure, &context.expression, true)?;
4010                            write!(self.out, ", ")?;
4011                            self.put_expression(descriptor, &context.expression, true)?;
4012                            write!(self.out, ".cull_mask);")?;
4013
4014                            write!(self.out, "{level}")?;
4015                            self.put_expression(query, &context.expression, true)?;
4016                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4017                        }
4018                        crate::RayQueryFunction::Proceed { result } => {
4019                            write!(self.out, "{level}")?;
4020                            let name = Baked(result).to_string();
4021                            self.start_baking_expression(result, &context.expression, &name)?;
4022                            self.named_expressions.insert(result, name);
4023                            self.put_expression(query, &context.expression, true)?;
4024                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4025                            if RAY_QUERY_MODERN_SUPPORT {
4026                                write!(self.out, "{level}")?;
4027                                self.put_expression(query, &context.expression, true)?;
4028                                writeln!(self.out, ".?.next();")?;
4029                            }
4030                        }
4031                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4032                            if RAY_QUERY_MODERN_SUPPORT {
4033                                write!(self.out, "{level}")?;
4034                                self.put_expression(query, &context.expression, true)?;
4035                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4036                                self.put_expression(hit_t, &context.expression, true)?;
4037                                writeln!(self.out, ");")?;
4038                            } else {
4039                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4040                            }
4041                        }
4042                        crate::RayQueryFunction::ConfirmIntersection => {
4043                            if RAY_QUERY_MODERN_SUPPORT {
4044                                write!(self.out, "{level}")?;
4045                                self.put_expression(query, &context.expression, true)?;
4046                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4047                            } else {
4048                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4049                            }
4050                        }
4051                        crate::RayQueryFunction::Terminate => {
4052                            if RAY_QUERY_MODERN_SUPPORT {
4053                                write!(self.out, "{level}")?;
4054                                self.put_expression(query, &context.expression, true)?;
4055                                writeln!(self.out, ".?.abort();")?;
4056                            }
4057                            write!(self.out, "{level}")?;
4058                            self.put_expression(query, &context.expression, true)?;
4059                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4060                        }
4061                    }
4062                }
4063                crate::Statement::SubgroupBallot { result, predicate } => {
4064                    write!(self.out, "{level}")?;
4065                    let name = self.namer.call("");
4066                    self.start_baking_expression(result, &context.expression, &name)?;
4067                    self.named_expressions.insert(result, name);
4068                    write!(
4069                        self.out,
4070                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4071                    )?;
4072                    if let Some(predicate) = predicate {
4073                        self.put_expression(predicate, &context.expression, true)?;
4074                    } else {
4075                        write!(self.out, "true")?;
4076                    }
4077                    writeln!(self.out, "), 0, 0, 0);")?;
4078                }
4079                crate::Statement::SubgroupCollectiveOperation {
4080                    op,
4081                    collective_op,
4082                    argument,
4083                    result,
4084                } => {
4085                    write!(self.out, "{level}")?;
4086                    let name = self.namer.call("");
4087                    self.start_baking_expression(result, &context.expression, &name)?;
4088                    self.named_expressions.insert(result, name);
4089                    match (collective_op, op) {
4090                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4091                            write!(self.out, "{NAMESPACE}::simd_all(")?
4092                        }
4093                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4094                            write!(self.out, "{NAMESPACE}::simd_any(")?
4095                        }
4096                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4097                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4098                        }
4099                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4100                            write!(self.out, "{NAMESPACE}::simd_product(")?
4101                        }
4102                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4103                            write!(self.out, "{NAMESPACE}::simd_max(")?
4104                        }
4105                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4106                            write!(self.out, "{NAMESPACE}::simd_min(")?
4107                        }
4108                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4109                            write!(self.out, "{NAMESPACE}::simd_and(")?
4110                        }
4111                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4112                            write!(self.out, "{NAMESPACE}::simd_or(")?
4113                        }
4114                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4115                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4116                        }
4117                        (
4118                            crate::CollectiveOperation::ExclusiveScan,
4119                            crate::SubgroupOperation::Add,
4120                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4121                        (
4122                            crate::CollectiveOperation::ExclusiveScan,
4123                            crate::SubgroupOperation::Mul,
4124                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4125                        (
4126                            crate::CollectiveOperation::InclusiveScan,
4127                            crate::SubgroupOperation::Add,
4128                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4129                        (
4130                            crate::CollectiveOperation::InclusiveScan,
4131                            crate::SubgroupOperation::Mul,
4132                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4133                        _ => unimplemented!(),
4134                    }
4135                    self.put_expression(argument, &context.expression, true)?;
4136                    writeln!(self.out, ");")?;
4137                }
4138                crate::Statement::SubgroupGather {
4139                    mode,
4140                    argument,
4141                    result,
4142                } => {
4143                    write!(self.out, "{level}")?;
4144                    let name = self.namer.call("");
4145                    self.start_baking_expression(result, &context.expression, &name)?;
4146                    self.named_expressions.insert(result, name);
4147                    match mode {
4148                        crate::GatherMode::BroadcastFirst => {
4149                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4150                        }
4151                        crate::GatherMode::Broadcast(_) => {
4152                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4153                        }
4154                        crate::GatherMode::Shuffle(_) => {
4155                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4156                        }
4157                        crate::GatherMode::ShuffleDown(_) => {
4158                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4159                        }
4160                        crate::GatherMode::ShuffleUp(_) => {
4161                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4162                        }
4163                        crate::GatherMode::ShuffleXor(_) => {
4164                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4165                        }
4166                        crate::GatherMode::QuadBroadcast(_) => {
4167                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4168                        }
4169                        crate::GatherMode::QuadSwap(_) => {
4170                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4171                        }
4172                    }
4173                    self.put_expression(argument, &context.expression, true)?;
4174                    match mode {
4175                        crate::GatherMode::BroadcastFirst => {}
4176                        crate::GatherMode::Broadcast(index)
4177                        | crate::GatherMode::Shuffle(index)
4178                        | crate::GatherMode::ShuffleDown(index)
4179                        | crate::GatherMode::ShuffleUp(index)
4180                        | crate::GatherMode::ShuffleXor(index)
4181                        | crate::GatherMode::QuadBroadcast(index) => {
4182                            write!(self.out, ", ")?;
4183                            self.put_expression(index, &context.expression, true)?;
4184                        }
4185                        crate::GatherMode::QuadSwap(direction) => {
4186                            write!(self.out, ", ")?;
4187                            match direction {
4188                                crate::Direction::X => {
4189                                    write!(self.out, "1u")?;
4190                                }
4191                                crate::Direction::Y => {
4192                                    write!(self.out, "2u")?;
4193                                }
4194                                crate::Direction::Diagonal => {
4195                                    write!(self.out, "3u")?;
4196                                }
4197                            }
4198                        }
4199                    }
4200                    writeln!(self.out, ");")?;
4201                }
4202            }
4203        }
4204
4205        // un-emit expressions
4206        //TODO: take care of loop/continuing?
4207        for statement in statements {
4208            if let crate::Statement::Emit(ref range) = *statement {
4209                for handle in range.clone() {
4210                    self.named_expressions.shift_remove(&handle);
4211                }
4212            }
4213        }
4214        Ok(())
4215    }
4216
4217    fn put_store(
4218        &mut self,
4219        pointer: Handle<crate::Expression>,
4220        value: Handle<crate::Expression>,
4221        level: back::Level,
4222        context: &StatementContext,
4223    ) -> BackendResult {
4224        let policy = context.expression.choose_bounds_check_policy(pointer);
4225        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4226            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4227        {
4228            writeln!(self.out, ") {{")?;
4229            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4230            writeln!(self.out, "{level}}}")?;
4231        } else {
4232            self.put_unchecked_store(pointer, value, policy, level, context)?;
4233        }
4234
4235        Ok(())
4236    }
4237
4238    fn put_unchecked_store(
4239        &mut self,
4240        pointer: Handle<crate::Expression>,
4241        value: Handle<crate::Expression>,
4242        policy: index::BoundsCheckPolicy,
4243        level: back::Level,
4244        context: &StatementContext,
4245    ) -> BackendResult {
4246        let is_atomic_pointer = context
4247            .expression
4248            .resolve_type(pointer)
4249            .is_atomic_pointer(&context.expression.module.types);
4250
4251        if is_atomic_pointer {
4252            write!(
4253                self.out,
4254                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4255            )?;
4256            self.put_access_chain(pointer, policy, &context.expression)?;
4257            write!(self.out, ", ")?;
4258            self.put_expression(value, &context.expression, true)?;
4259            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4260        } else {
4261            write!(self.out, "{level}")?;
4262            self.put_access_chain(pointer, policy, &context.expression)?;
4263            write!(self.out, " = ")?;
4264            self.put_expression(value, &context.expression, true)?;
4265            writeln!(self.out, ";")?;
4266        }
4267
4268        Ok(())
4269    }
4270
4271    pub fn write(
4272        &mut self,
4273        module: &crate::Module,
4274        info: &valid::ModuleInfo,
4275        options: &Options,
4276        pipeline_options: &PipelineOptions,
4277    ) -> Result<TranslationInfo, Error> {
4278        self.names.clear();
4279        self.namer.reset(
4280            module,
4281            &super::keywords::RESERVED_SET,
4282            proc::CaseInsensitiveKeywordSet::empty(),
4283            &[CLAMPED_LOD_LOAD_PREFIX],
4284            &mut self.names,
4285        );
4286        self.wrapped_functions.clear();
4287        self.struct_member_pads.clear();
4288
4289        writeln!(
4290            self.out,
4291            "// language: metal{}.{}",
4292            options.lang_version.0, options.lang_version.1
4293        )?;
4294        writeln!(self.out, "#include <metal_stdlib>")?;
4295        writeln!(self.out, "#include <simd/simd.h>")?;
4296        writeln!(self.out)?;
4297        // Work around Metal bug where `uint` is not available by default
4298        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4299
4300        let mut uses_ray_query = false;
4301        for (_, ty) in module.types.iter() {
4302            match ty.inner {
4303                crate::TypeInner::AccelerationStructure { .. } => {
4304                    if options.lang_version < (2, 4) {
4305                        return Err(Error::UnsupportedRayTracing);
4306                    }
4307                }
4308                crate::TypeInner::RayQuery { .. } => {
4309                    if options.lang_version < (2, 4) {
4310                        return Err(Error::UnsupportedRayTracing);
4311                    }
4312                    uses_ray_query = true;
4313                }
4314                _ => (),
4315            }
4316        }
4317
4318        if module.special_types.ray_desc.is_some()
4319            || module.special_types.ray_intersection.is_some()
4320        {
4321            if options.lang_version < (2, 4) {
4322                return Err(Error::UnsupportedRayTracing);
4323            }
4324        }
4325
4326        if uses_ray_query {
4327            self.put_ray_query_type()?;
4328        }
4329
4330        if options
4331            .bounds_check_policies
4332            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4333        {
4334            self.put_default_constructible()?;
4335        }
4336        writeln!(self.out)?;
4337
4338        {
4339            // Make a `Vec` of all the `GlobalVariable`s that contain
4340            // runtime-sized arrays.
4341            let globals: Vec<Handle<crate::GlobalVariable>> = module
4342                .global_variables
4343                .iter()
4344                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4345                .map(|(handle, _)| handle)
4346                .collect();
4347
4348            let mut buffer_indices = vec![];
4349            for vbm in &pipeline_options.vertex_buffer_mappings {
4350                buffer_indices.push(vbm.id);
4351            }
4352
4353            if !globals.is_empty() || !buffer_indices.is_empty() {
4354                writeln!(self.out, "struct _mslBufferSizes {{")?;
4355
4356                for global in globals {
4357                    writeln!(
4358                        self.out,
4359                        "{}uint {};",
4360                        back::INDENT,
4361                        ArraySizeMember(global)
4362                    )?;
4363                }
4364
4365                for idx in buffer_indices {
4366                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4367                }
4368
4369                writeln!(self.out, "}};")?;
4370                writeln!(self.out)?;
4371            }
4372        };
4373
4374        self.write_type_defs(module)?;
4375        self.write_global_constants(module, info)?;
4376        self.write_functions(module, info, options, pipeline_options)
4377    }
4378
4379    /// Write the definition for the `DefaultConstructible` class.
4380    ///
4381    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4382    /// produce 'zero' values for any type, including structs, arrays, and so
4383    /// on. We could do this by emitting default constructor applications, but
4384    /// that would entail printing the name of the type, which is more trouble
4385    /// than you'd think. Instead, we just construct this magic C++14 class that
4386    /// can be converted to any type that can be default constructed, using
4387    /// template parameter inference to detect which type is needed, so we don't
4388    /// have to figure out the name.
4389    ///
4390    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4391    fn put_default_constructible(&mut self) -> BackendResult {
4392        let tab = back::INDENT;
4393        writeln!(self.out, "struct DefaultConstructible {{")?;
4394        writeln!(self.out, "{tab}template<typename T>")?;
4395        writeln!(self.out, "{tab}operator T() && {{")?;
4396        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4397        writeln!(self.out, "{tab}}}")?;
4398        writeln!(self.out, "}};")?;
4399        Ok(())
4400    }
4401
4402    fn put_ray_query_type(&mut self) -> BackendResult {
4403        let tab = back::INDENT;
4404        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4405        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4406        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4407        writeln!(
4408            self.out,
4409            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4410        )?;
4411        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4412        writeln!(self.out, "}};")?;
4413        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4414        let v_triangle = back::RayIntersectionType::Triangle as u32;
4415        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4416        writeln!(
4417            self.out,
4418            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4419        )?;
4420        writeln!(
4421            self.out,
4422            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4423        )?;
4424        writeln!(self.out, "}}")?;
4425        Ok(())
4426    }
4427
4428    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4429        let mut generated_argument_buffer_wrapper = false;
4430        let mut generated_external_texture_wrapper = false;
4431        for (handle, ty) in module.types.iter() {
4432            match ty.inner {
4433                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4434                    writeln!(self.out, "template <typename T>")?;
4435                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4436                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4437                    writeln!(self.out, "}};")?;
4438                    generated_argument_buffer_wrapper = true;
4439                }
4440                crate::TypeInner::Image {
4441                    class: crate::ImageClass::External,
4442                    ..
4443                } if !generated_external_texture_wrapper => {
4444                    let params_ty_name = &self.names
4445                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4446                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4447                    writeln!(
4448                        self.out,
4449                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4450                        back::INDENT
4451                    )?;
4452                    writeln!(
4453                        self.out,
4454                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4455                        back::INDENT
4456                    )?;
4457                    writeln!(
4458                        self.out,
4459                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4460                        back::INDENT
4461                    )?;
4462                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4463                    writeln!(self.out, "}};")?;
4464                    generated_external_texture_wrapper = true;
4465                }
4466                _ => {}
4467            }
4468
4469            if !ty.needs_alias() {
4470                continue;
4471            }
4472            let name = &self.names[&NameKey::Type(handle)];
4473            match ty.inner {
4474                // Naga IR can pass around arrays by value, but Metal, following
4475                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4476                // on expressions of array type, so assigning the array by value
4477                // isn't possible. However, Metal *does* assign structs by
4478                // value. So in our Metal output, we wrap all array types in
4479                // synthetic struct types:
4480                //
4481                //     struct type1 {
4482                //         float inner[10]
4483                //     };
4484                //
4485                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4486                // any expression that actually wants access to the array.
4487                crate::TypeInner::Array {
4488                    base,
4489                    size,
4490                    stride: _,
4491                } => {
4492                    let base_name = TypeContext {
4493                        handle: base,
4494                        gctx: module.to_ctx(),
4495                        names: &self.names,
4496                        access: crate::StorageAccess::empty(),
4497                        first_time: false,
4498                    };
4499
4500                    match size.resolve(module.to_ctx())? {
4501                        proc::IndexableLength::Known(size) => {
4502                            writeln!(self.out, "struct {name} {{")?;
4503                            writeln!(
4504                                self.out,
4505                                "{}{} {}[{}];",
4506                                back::INDENT,
4507                                base_name,
4508                                WRAPPED_ARRAY_FIELD,
4509                                size
4510                            )?;
4511                            writeln!(self.out, "}};")?;
4512                        }
4513                        proc::IndexableLength::Dynamic => {
4514                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4515                        }
4516                    }
4517                }
4518                crate::TypeInner::Struct {
4519                    ref members, span, ..
4520                } => {
4521                    writeln!(self.out, "struct {name} {{")?;
4522                    let mut last_offset = 0;
4523                    for (index, member) in members.iter().enumerate() {
4524                        if member.offset > last_offset {
4525                            self.struct_member_pads.insert((handle, index as u32));
4526                            let pad = member.offset - last_offset;
4527                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4528                        }
4529                        let ty_inner = &module.types[member.ty].inner;
4530                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4531
4532                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4533
4534                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4535                        match should_pack_struct_member(members, span, index, module) {
4536                            Some(scalar) => {
4537                                writeln!(
4538                                    self.out,
4539                                    "{}{}::packed_{}3 {};",
4540                                    back::INDENT,
4541                                    NAMESPACE,
4542                                    scalar.to_msl_name(),
4543                                    member_name
4544                                )?;
4545                            }
4546                            None => {
4547                                let base_name = TypeContext {
4548                                    handle: member.ty,
4549                                    gctx: module.to_ctx(),
4550                                    names: &self.names,
4551                                    access: crate::StorageAccess::empty(),
4552                                    first_time: false,
4553                                };
4554                                writeln!(
4555                                    self.out,
4556                                    "{}{} {};",
4557                                    back::INDENT,
4558                                    base_name,
4559                                    member_name
4560                                )?;
4561
4562                                // for 3-component vectors, add one component
4563                                if let crate::TypeInner::Vector {
4564                                    size: crate::VectorSize::Tri,
4565                                    scalar,
4566                                } = *ty_inner
4567                                {
4568                                    last_offset += scalar.width as u32;
4569                                }
4570                            }
4571                        }
4572                    }
4573                    if last_offset < span {
4574                        let pad = span - last_offset;
4575                        writeln!(
4576                            self.out,
4577                            "{}char _pad{}[{}];",
4578                            back::INDENT,
4579                            members.len(),
4580                            pad
4581                        )?;
4582                    }
4583                    writeln!(self.out, "}};")?;
4584                }
4585                _ => {
4586                    let ty_name = TypeContext {
4587                        handle,
4588                        gctx: module.to_ctx(),
4589                        names: &self.names,
4590                        access: crate::StorageAccess::empty(),
4591                        first_time: true,
4592                    };
4593                    writeln!(self.out, "typedef {ty_name} {name};")?;
4594                }
4595            }
4596        }
4597
4598        // Write functions to create special types.
4599        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4600            match type_key {
4601                &crate::PredeclaredType::ModfResult { size, scalar }
4602                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4603                    let arg_type_name_owner;
4604                    let arg_type_name = if let Some(size) = size {
4605                        arg_type_name_owner = format!(
4606                            "{NAMESPACE}::{}{}",
4607                            if scalar.width == 8 { "double" } else { "float" },
4608                            size as u8
4609                        );
4610                        &arg_type_name_owner
4611                    } else if scalar.width == 8 {
4612                        "double"
4613                    } else {
4614                        "float"
4615                    };
4616
4617                    let other_type_name_owner;
4618                    let (defined_func_name, called_func_name, other_type_name) =
4619                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4620                            (MODF_FUNCTION, "modf", arg_type_name)
4621                        } else {
4622                            let other_type_name = if let Some(size) = size {
4623                                other_type_name_owner = format!("int{}", size as u8);
4624                                &other_type_name_owner
4625                            } else {
4626                                "int"
4627                            };
4628                            (FREXP_FUNCTION, "frexp", other_type_name)
4629                        };
4630
4631                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4632
4633                    writeln!(self.out)?;
4634                    writeln!(
4635                        self.out,
4636                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4637    {other_type_name} other;
4638    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4639    return {struct_name}{{ fract, other }};
4640}}"
4641                    )?;
4642                }
4643                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4644                    let arg_type_name = scalar.to_msl_name();
4645                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4646                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4647                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4648
4649                    writeln!(self.out)?;
4650
4651                    for address_space_name in ["device", "threadgroup"] {
4652                        writeln!(
4653                            self.out,
4654                            "\
4655template <typename A>
4656{struct_name} {defined_func_name}(
4657    {address_space_name} A *atomic_ptr,
4658    {arg_type_name} cmp,
4659    {arg_type_name} v
4660) {{
4661    bool swapped = {NAMESPACE}::{called_func_name}(
4662        atomic_ptr, &cmp, v,
4663        metal::memory_order_relaxed, metal::memory_order_relaxed
4664    );
4665    return {struct_name}{{cmp, swapped}};
4666}}"
4667                        )?;
4668                    }
4669                }
4670            }
4671        }
4672
4673        Ok(())
4674    }
4675
4676    /// Writes all named constants
4677    fn write_global_constants(
4678        &mut self,
4679        module: &crate::Module,
4680        mod_info: &valid::ModuleInfo,
4681    ) -> BackendResult {
4682        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4683
4684        for (handle, constant) in constants {
4685            let ty_name = TypeContext {
4686                handle: constant.ty,
4687                gctx: module.to_ctx(),
4688                names: &self.names,
4689                access: crate::StorageAccess::empty(),
4690                first_time: false,
4691            };
4692            let name = &self.names[&NameKey::Constant(handle)];
4693            write!(self.out, "constant {ty_name} {name} = ")?;
4694            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4695            writeln!(self.out, ";")?;
4696        }
4697
4698        Ok(())
4699    }
4700
4701    fn put_inline_sampler_properties(
4702        &mut self,
4703        level: back::Level,
4704        sampler: &sm::InlineSampler,
4705    ) -> BackendResult {
4706        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4707            writeln!(
4708                self.out,
4709                "{}{}::{}_address::{},",
4710                level,
4711                NAMESPACE,
4712                letter,
4713                address.as_str(),
4714            )?;
4715        }
4716        writeln!(
4717            self.out,
4718            "{}{}::mag_filter::{},",
4719            level,
4720            NAMESPACE,
4721            sampler.mag_filter.as_str(),
4722        )?;
4723        writeln!(
4724            self.out,
4725            "{}{}::min_filter::{},",
4726            level,
4727            NAMESPACE,
4728            sampler.min_filter.as_str(),
4729        )?;
4730        if let Some(filter) = sampler.mip_filter {
4731            writeln!(
4732                self.out,
4733                "{}{}::mip_filter::{},",
4734                level,
4735                NAMESPACE,
4736                filter.as_str(),
4737            )?;
4738        }
4739        // avoid setting it on platforms that don't support it
4740        if sampler.border_color != sm::BorderColor::TransparentBlack {
4741            writeln!(
4742                self.out,
4743                "{}{}::border_color::{},",
4744                level,
4745                NAMESPACE,
4746                sampler.border_color.as_str(),
4747            )?;
4748        }
4749        //TODO: I'm not able to feed this in a way that MSL likes:
4750        //>error: use of undeclared identifier 'lod_clamp'
4751        //>error: no member named 'max_anisotropy' in namespace 'metal'
4752        if false {
4753            if let Some(ref lod) = sampler.lod_clamp {
4754                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4755            }
4756            if let Some(aniso) = sampler.max_anisotropy {
4757                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4758            }
4759        }
4760        if sampler.compare_func != sm::CompareFunc::Never {
4761            writeln!(
4762                self.out,
4763                "{}{}::compare_func::{},",
4764                level,
4765                NAMESPACE,
4766                sampler.compare_func.as_str(),
4767            )?;
4768        }
4769        writeln!(
4770            self.out,
4771            "{}{}::coord::{}",
4772            level,
4773            NAMESPACE,
4774            sampler.coord.as_str()
4775        )?;
4776        Ok(())
4777    }
4778
4779    fn write_unpacking_function(
4780        &mut self,
4781        format: back::msl::VertexFormat,
4782    ) -> Result<(String, u32, u32), Error> {
4783        use back::msl::VertexFormat::*;
4784        match format {
4785            Uint8 => {
4786                let name = self.namer.call("unpackUint8");
4787                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4788                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4789                writeln!(self.out, "}}")?;
4790                Ok((name, 1, 1))
4791            }
4792            Uint8x2 => {
4793                let name = self.namer.call("unpackUint8x2");
4794                writeln!(
4795                    self.out,
4796                    "metal::uint2 {name}(metal::uchar b0, \
4797                                         metal::uchar b1) {{"
4798                )?;
4799                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4800                writeln!(self.out, "}}")?;
4801                Ok((name, 2, 2))
4802            }
4803            Uint8x4 => {
4804                let name = self.namer.call("unpackUint8x4");
4805                writeln!(
4806                    self.out,
4807                    "metal::uint4 {name}(metal::uchar b0, \
4808                                         metal::uchar b1, \
4809                                         metal::uchar b2, \
4810                                         metal::uchar b3) {{"
4811                )?;
4812                writeln!(
4813                    self.out,
4814                    "{}return metal::uint4(b0, b1, b2, b3);",
4815                    back::INDENT
4816                )?;
4817                writeln!(self.out, "}}")?;
4818                Ok((name, 4, 4))
4819            }
4820            Sint8 => {
4821                let name = self.namer.call("unpackSint8");
4822                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4823                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4824                writeln!(self.out, "}}")?;
4825                Ok((name, 1, 1))
4826            }
4827            Sint8x2 => {
4828                let name = self.namer.call("unpackSint8x2");
4829                writeln!(
4830                    self.out,
4831                    "metal::int2 {name}(metal::uchar b0, \
4832                                        metal::uchar b1) {{"
4833                )?;
4834                writeln!(
4835                    self.out,
4836                    "{}return metal::int2(as_type<char>(b0), \
4837                                          as_type<char>(b1));",
4838                    back::INDENT
4839                )?;
4840                writeln!(self.out, "}}")?;
4841                Ok((name, 2, 2))
4842            }
4843            Sint8x4 => {
4844                let name = self.namer.call("unpackSint8x4");
4845                writeln!(
4846                    self.out,
4847                    "metal::int4 {name}(metal::uchar b0, \
4848                                        metal::uchar b1, \
4849                                        metal::uchar b2, \
4850                                        metal::uchar b3) {{"
4851                )?;
4852                writeln!(
4853                    self.out,
4854                    "{}return metal::int4(as_type<char>(b0), \
4855                                          as_type<char>(b1), \
4856                                          as_type<char>(b2), \
4857                                          as_type<char>(b3));",
4858                    back::INDENT
4859                )?;
4860                writeln!(self.out, "}}")?;
4861                Ok((name, 4, 4))
4862            }
4863            Unorm8 => {
4864                let name = self.namer.call("unpackUnorm8");
4865                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4866                writeln!(
4867                    self.out,
4868                    "{}return float(float(b0) / 255.0f);",
4869                    back::INDENT
4870                )?;
4871                writeln!(self.out, "}}")?;
4872                Ok((name, 1, 1))
4873            }
4874            Unorm8x2 => {
4875                let name = self.namer.call("unpackUnorm8x2");
4876                writeln!(
4877                    self.out,
4878                    "metal::float2 {name}(metal::uchar b0, \
4879                                          metal::uchar b1) {{"
4880                )?;
4881                writeln!(
4882                    self.out,
4883                    "{}return metal::float2(float(b0) / 255.0f, \
4884                                            float(b1) / 255.0f);",
4885                    back::INDENT
4886                )?;
4887                writeln!(self.out, "}}")?;
4888                Ok((name, 2, 2))
4889            }
4890            Unorm8x4 => {
4891                let name = self.namer.call("unpackUnorm8x4");
4892                writeln!(
4893                    self.out,
4894                    "metal::float4 {name}(metal::uchar b0, \
4895                                          metal::uchar b1, \
4896                                          metal::uchar b2, \
4897                                          metal::uchar b3) {{"
4898                )?;
4899                writeln!(
4900                    self.out,
4901                    "{}return metal::float4(float(b0) / 255.0f, \
4902                                            float(b1) / 255.0f, \
4903                                            float(b2) / 255.0f, \
4904                                            float(b3) / 255.0f);",
4905                    back::INDENT
4906                )?;
4907                writeln!(self.out, "}}")?;
4908                Ok((name, 4, 4))
4909            }
4910            Snorm8 => {
4911                let name = self.namer.call("unpackSnorm8");
4912                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4913                writeln!(
4914                    self.out,
4915                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4916                    back::INDENT
4917                )?;
4918                writeln!(self.out, "}}")?;
4919                Ok((name, 1, 1))
4920            }
4921            Snorm8x2 => {
4922                let name = self.namer.call("unpackSnorm8x2");
4923                writeln!(
4924                    self.out,
4925                    "metal::float2 {name}(metal::uchar b0, \
4926                                          metal::uchar b1) {{"
4927                )?;
4928                writeln!(
4929                    self.out,
4930                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4931                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4932                    back::INDENT
4933                )?;
4934                writeln!(self.out, "}}")?;
4935                Ok((name, 2, 2))
4936            }
4937            Snorm8x4 => {
4938                let name = self.namer.call("unpackSnorm8x4");
4939                writeln!(
4940                    self.out,
4941                    "metal::float4 {name}(metal::uchar b0, \
4942                                          metal::uchar b1, \
4943                                          metal::uchar b2, \
4944                                          metal::uchar b3) {{"
4945                )?;
4946                writeln!(
4947                    self.out,
4948                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4949                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4950                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4951                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4952                    back::INDENT
4953                )?;
4954                writeln!(self.out, "}}")?;
4955                Ok((name, 4, 4))
4956            }
4957            Uint16 => {
4958                let name = self.namer.call("unpackUint16");
4959                writeln!(
4960                    self.out,
4961                    "metal::uint {name}(metal::uint b0, \
4962                                        metal::uint b1) {{"
4963                )?;
4964                writeln!(
4965                    self.out,
4966                    "{}return metal::uint(b1 << 8 | b0);",
4967                    back::INDENT
4968                )?;
4969                writeln!(self.out, "}}")?;
4970                Ok((name, 2, 1))
4971            }
4972            Uint16x2 => {
4973                let name = self.namer.call("unpackUint16x2");
4974                writeln!(
4975                    self.out,
4976                    "metal::uint2 {name}(metal::uint b0, \
4977                                         metal::uint b1, \
4978                                         metal::uint b2, \
4979                                         metal::uint b3) {{"
4980                )?;
4981                writeln!(
4982                    self.out,
4983                    "{}return metal::uint2(b1 << 8 | b0, \
4984                                           b3 << 8 | b2);",
4985                    back::INDENT
4986                )?;
4987                writeln!(self.out, "}}")?;
4988                Ok((name, 4, 2))
4989            }
4990            Uint16x4 => {
4991                let name = self.namer.call("unpackUint16x4");
4992                writeln!(
4993                    self.out,
4994                    "metal::uint4 {name}(metal::uint b0, \
4995                                         metal::uint b1, \
4996                                         metal::uint b2, \
4997                                         metal::uint b3, \
4998                                         metal::uint b4, \
4999                                         metal::uint b5, \
5000                                         metal::uint b6, \
5001                                         metal::uint b7) {{"
5002                )?;
5003                writeln!(
5004                    self.out,
5005                    "{}return metal::uint4(b1 << 8 | b0, \
5006                                           b3 << 8 | b2, \
5007                                           b5 << 8 | b4, \
5008                                           b7 << 8 | b6);",
5009                    back::INDENT
5010                )?;
5011                writeln!(self.out, "}}")?;
5012                Ok((name, 8, 4))
5013            }
5014            Sint16 => {
5015                let name = self.namer.call("unpackSint16");
5016                writeln!(
5017                    self.out,
5018                    "int {name}(metal::ushort b0, \
5019                                metal::ushort b1) {{"
5020                )?;
5021                writeln!(
5022                    self.out,
5023                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5024                    back::INDENT
5025                )?;
5026                writeln!(self.out, "}}")?;
5027                Ok((name, 2, 1))
5028            }
5029            Sint16x2 => {
5030                let name = self.namer.call("unpackSint16x2");
5031                writeln!(
5032                    self.out,
5033                    "metal::int2 {name}(metal::ushort b0, \
5034                                        metal::ushort b1, \
5035                                        metal::ushort b2, \
5036                                        metal::ushort b3) {{"
5037                )?;
5038                writeln!(
5039                    self.out,
5040                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5041                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5042                    back::INDENT
5043                )?;
5044                writeln!(self.out, "}}")?;
5045                Ok((name, 4, 2))
5046            }
5047            Sint16x4 => {
5048                let name = self.namer.call("unpackSint16x4");
5049                writeln!(
5050                    self.out,
5051                    "metal::int4 {name}(metal::ushort b0, \
5052                                        metal::ushort b1, \
5053                                        metal::ushort b2, \
5054                                        metal::ushort b3, \
5055                                        metal::ushort b4, \
5056                                        metal::ushort b5, \
5057                                        metal::ushort b6, \
5058                                        metal::ushort b7) {{"
5059                )?;
5060                writeln!(
5061                    self.out,
5062                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5063                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5064                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5065                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5066                    back::INDENT
5067                )?;
5068                writeln!(self.out, "}}")?;
5069                Ok((name, 8, 4))
5070            }
5071            Unorm16 => {
5072                let name = self.namer.call("unpackUnorm16");
5073                writeln!(
5074                    self.out,
5075                    "float {name}(metal::ushort b0, \
5076                                  metal::ushort b1) {{"
5077                )?;
5078                writeln!(
5079                    self.out,
5080                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5081                    back::INDENT
5082                )?;
5083                writeln!(self.out, "}}")?;
5084                Ok((name, 2, 1))
5085            }
5086            Unorm16x2 => {
5087                let name = self.namer.call("unpackUnorm16x2");
5088                writeln!(
5089                    self.out,
5090                    "metal::float2 {name}(metal::ushort b0, \
5091                                          metal::ushort b1, \
5092                                          metal::ushort b2, \
5093                                          metal::ushort b3) {{"
5094                )?;
5095                writeln!(
5096                    self.out,
5097                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5098                                            float(b3 << 8 | b2) / 65535.0f);",
5099                    back::INDENT
5100                )?;
5101                writeln!(self.out, "}}")?;
5102                Ok((name, 4, 2))
5103            }
5104            Unorm16x4 => {
5105                let name = self.namer.call("unpackUnorm16x4");
5106                writeln!(
5107                    self.out,
5108                    "metal::float4 {name}(metal::ushort b0, \
5109                                          metal::ushort b1, \
5110                                          metal::ushort b2, \
5111                                          metal::ushort b3, \
5112                                          metal::ushort b4, \
5113                                          metal::ushort b5, \
5114                                          metal::ushort b6, \
5115                                          metal::ushort b7) {{"
5116                )?;
5117                writeln!(
5118                    self.out,
5119                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5120                                            float(b3 << 8 | b2) / 65535.0f, \
5121                                            float(b5 << 8 | b4) / 65535.0f, \
5122                                            float(b7 << 8 | b6) / 65535.0f);",
5123                    back::INDENT
5124                )?;
5125                writeln!(self.out, "}}")?;
5126                Ok((name, 8, 4))
5127            }
5128            Snorm16 => {
5129                let name = self.namer.call("unpackSnorm16");
5130                writeln!(
5131                    self.out,
5132                    "float {name}(metal::ushort b0, \
5133                                  metal::ushort b1) {{"
5134                )?;
5135                writeln!(
5136                    self.out,
5137                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5138                    back::INDENT
5139                )?;
5140                writeln!(self.out, "}}")?;
5141                Ok((name, 2, 1))
5142            }
5143            Snorm16x2 => {
5144                let name = self.namer.call("unpackSnorm16x2");
5145                writeln!(
5146                    self.out,
5147                    "metal::float2 {name}(metal::ushort b0, \
5148                                          metal::ushort b1, \
5149                                          metal::ushort b2, \
5150                                          metal::ushort b3) {{"
5151                )?;
5152                writeln!(
5153                    self.out,
5154                    "{}return metal::unpack_snorm2x16_to_float(b1 << 24 | b0 << 16 | b3 << 8 | b2);",
5155                    back::INDENT
5156                )?;
5157                writeln!(self.out, "}}")?;
5158                Ok((name, 4, 2))
5159            }
5160            Snorm16x4 => {
5161                let name = self.namer.call("unpackSnorm16x4");
5162                writeln!(
5163                    self.out,
5164                    "metal::float4 {name}(metal::ushort b0, \
5165                                          metal::ushort b1, \
5166                                          metal::ushort b2, \
5167                                          metal::ushort b3, \
5168                                          metal::ushort b4, \
5169                                          metal::ushort b5, \
5170                                          metal::ushort b6, \
5171                                          metal::ushort b7) {{"
5172                )?;
5173                writeln!(
5174                    self.out,
5175                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b1 << 24 | b0 << 16 | b3 << 8 | b2), \
5176                                            metal::unpack_snorm2x16_to_float(b5 << 24 | b4 << 16 | b7 << 8 | b6));",
5177                    back::INDENT
5178                )?;
5179                writeln!(self.out, "}}")?;
5180                Ok((name, 8, 4))
5181            }
5182            Float16 => {
5183                let name = self.namer.call("unpackFloat16");
5184                writeln!(
5185                    self.out,
5186                    "float {name}(metal::ushort b0, \
5187                                  metal::ushort b1) {{"
5188                )?;
5189                writeln!(
5190                    self.out,
5191                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5192                    back::INDENT
5193                )?;
5194                writeln!(self.out, "}}")?;
5195                Ok((name, 2, 1))
5196            }
5197            Float16x2 => {
5198                let name = self.namer.call("unpackFloat16x2");
5199                writeln!(
5200                    self.out,
5201                    "metal::float2 {name}(metal::ushort b0, \
5202                                          metal::ushort b1, \
5203                                          metal::ushort b2, \
5204                                          metal::ushort b3) {{"
5205                )?;
5206                writeln!(
5207                    self.out,
5208                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5209                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5210                    back::INDENT
5211                )?;
5212                writeln!(self.out, "}}")?;
5213                Ok((name, 4, 2))
5214            }
5215            Float16x4 => {
5216                let name = self.namer.call("unpackFloat16x4");
5217                writeln!(
5218                    self.out,
5219                    "metal::float4 {name}(metal::ushort b0, \
5220                                        metal::ushort b1, \
5221                                        metal::ushort b2, \
5222                                        metal::ushort b3, \
5223                                        metal::ushort b4, \
5224                                        metal::ushort b5, \
5225                                        metal::ushort b6, \
5226                                        metal::ushort b7) {{"
5227                )?;
5228                writeln!(
5229                    self.out,
5230                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5231                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5232                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5233                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5234                    back::INDENT
5235                )?;
5236                writeln!(self.out, "}}")?;
5237                Ok((name, 8, 4))
5238            }
5239            Float32 => {
5240                let name = self.namer.call("unpackFloat32");
5241                writeln!(
5242                    self.out,
5243                    "float {name}(uint b0, \
5244                                  uint b1, \
5245                                  uint b2, \
5246                                  uint b3) {{"
5247                )?;
5248                writeln!(
5249                    self.out,
5250                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5251                    back::INDENT
5252                )?;
5253                writeln!(self.out, "}}")?;
5254                Ok((name, 4, 1))
5255            }
5256            Float32x2 => {
5257                let name = self.namer.call("unpackFloat32x2");
5258                writeln!(
5259                    self.out,
5260                    "metal::float2 {name}(uint b0, \
5261                                          uint b1, \
5262                                          uint b2, \
5263                                          uint b3, \
5264                                          uint b4, \
5265                                          uint b5, \
5266                                          uint b6, \
5267                                          uint b7) {{"
5268                )?;
5269                writeln!(
5270                    self.out,
5271                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5272                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5273                    back::INDENT
5274                )?;
5275                writeln!(self.out, "}}")?;
5276                Ok((name, 8, 2))
5277            }
5278            Float32x3 => {
5279                let name = self.namer.call("unpackFloat32x3");
5280                writeln!(
5281                    self.out,
5282                    "metal::float3 {name}(uint b0, \
5283                                          uint b1, \
5284                                          uint b2, \
5285                                          uint b3, \
5286                                          uint b4, \
5287                                          uint b5, \
5288                                          uint b6, \
5289                                          uint b7, \
5290                                          uint b8, \
5291                                          uint b9, \
5292                                          uint b10, \
5293                                          uint b11) {{"
5294                )?;
5295                writeln!(
5296                    self.out,
5297                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5298                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5299                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5300                    back::INDENT
5301                )?;
5302                writeln!(self.out, "}}")?;
5303                Ok((name, 12, 3))
5304            }
5305            Float32x4 => {
5306                let name = self.namer.call("unpackFloat32x4");
5307                writeln!(
5308                    self.out,
5309                    "metal::float4 {name}(uint b0, \
5310                                          uint b1, \
5311                                          uint b2, \
5312                                          uint b3, \
5313                                          uint b4, \
5314                                          uint b5, \
5315                                          uint b6, \
5316                                          uint b7, \
5317                                          uint b8, \
5318                                          uint b9, \
5319                                          uint b10, \
5320                                          uint b11, \
5321                                          uint b12, \
5322                                          uint b13, \
5323                                          uint b14, \
5324                                          uint b15) {{"
5325                )?;
5326                writeln!(
5327                    self.out,
5328                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5329                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5330                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5331                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5332                    back::INDENT
5333                )?;
5334                writeln!(self.out, "}}")?;
5335                Ok((name, 16, 4))
5336            }
5337            Uint32 => {
5338                let name = self.namer.call("unpackUint32");
5339                writeln!(
5340                    self.out,
5341                    "uint {name}(uint b0, \
5342                                 uint b1, \
5343                                 uint b2, \
5344                                 uint b3) {{"
5345                )?;
5346                writeln!(
5347                    self.out,
5348                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5349                    back::INDENT
5350                )?;
5351                writeln!(self.out, "}}")?;
5352                Ok((name, 4, 1))
5353            }
5354            Uint32x2 => {
5355                let name = self.namer.call("unpackUint32x2");
5356                writeln!(
5357                    self.out,
5358                    "uint2 {name}(uint b0, \
5359                                  uint b1, \
5360                                  uint b2, \
5361                                  uint b3, \
5362                                  uint b4, \
5363                                  uint b5, \
5364                                  uint b6, \
5365                                  uint b7) {{"
5366                )?;
5367                writeln!(
5368                    self.out,
5369                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5370                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5371                    back::INDENT
5372                )?;
5373                writeln!(self.out, "}}")?;
5374                Ok((name, 8, 2))
5375            }
5376            Uint32x3 => {
5377                let name = self.namer.call("unpackUint32x3");
5378                writeln!(
5379                    self.out,
5380                    "uint3 {name}(uint b0, \
5381                                  uint b1, \
5382                                  uint b2, \
5383                                  uint b3, \
5384                                  uint b4, \
5385                                  uint b5, \
5386                                  uint b6, \
5387                                  uint b7, \
5388                                  uint b8, \
5389                                  uint b9, \
5390                                  uint b10, \
5391                                  uint b11) {{"
5392                )?;
5393                writeln!(
5394                    self.out,
5395                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5396                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5397                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5398                    back::INDENT
5399                )?;
5400                writeln!(self.out, "}}")?;
5401                Ok((name, 12, 3))
5402            }
5403            Uint32x4 => {
5404                let name = self.namer.call("unpackUint32x4");
5405                writeln!(
5406                    self.out,
5407                    "{NAMESPACE}::uint4 {name}(uint b0, \
5408                                  uint b1, \
5409                                  uint b2, \
5410                                  uint b3, \
5411                                  uint b4, \
5412                                  uint b5, \
5413                                  uint b6, \
5414                                  uint b7, \
5415                                  uint b8, \
5416                                  uint b9, \
5417                                  uint b10, \
5418                                  uint b11, \
5419                                  uint b12, \
5420                                  uint b13, \
5421                                  uint b14, \
5422                                  uint b15) {{"
5423                )?;
5424                writeln!(
5425                    self.out,
5426                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5427                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5428                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5429                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5430                    back::INDENT
5431                )?;
5432                writeln!(self.out, "}}")?;
5433                Ok((name, 16, 4))
5434            }
5435            Sint32 => {
5436                let name = self.namer.call("unpackSint32");
5437                writeln!(
5438                    self.out,
5439                    "int {name}(uint b0, \
5440                                uint b1, \
5441                                uint b2, \
5442                                uint b3) {{"
5443                )?;
5444                writeln!(
5445                    self.out,
5446                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5447                    back::INDENT
5448                )?;
5449                writeln!(self.out, "}}")?;
5450                Ok((name, 4, 1))
5451            }
5452            Sint32x2 => {
5453                let name = self.namer.call("unpackSint32x2");
5454                writeln!(
5455                    self.out,
5456                    "metal::int2 {name}(uint b0, \
5457                                        uint b1, \
5458                                        uint b2, \
5459                                        uint b3, \
5460                                        uint b4, \
5461                                        uint b5, \
5462                                        uint b6, \
5463                                        uint b7) {{"
5464                )?;
5465                writeln!(
5466                    self.out,
5467                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5468                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5469                    back::INDENT
5470                )?;
5471                writeln!(self.out, "}}")?;
5472                Ok((name, 8, 2))
5473            }
5474            Sint32x3 => {
5475                let name = self.namer.call("unpackSint32x3");
5476                writeln!(
5477                    self.out,
5478                    "metal::int3 {name}(uint b0, \
5479                                        uint b1, \
5480                                        uint b2, \
5481                                        uint b3, \
5482                                        uint b4, \
5483                                        uint b5, \
5484                                        uint b6, \
5485                                        uint b7, \
5486                                        uint b8, \
5487                                        uint b9, \
5488                                        uint b10, \
5489                                        uint b11) {{"
5490                )?;
5491                writeln!(
5492                    self.out,
5493                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5494                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5495                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5496                    back::INDENT
5497                )?;
5498                writeln!(self.out, "}}")?;
5499                Ok((name, 12, 3))
5500            }
5501            Sint32x4 => {
5502                let name = self.namer.call("unpackSint32x4");
5503                writeln!(
5504                    self.out,
5505                    "metal::int4 {name}(uint b0, \
5506                                        uint b1, \
5507                                        uint b2, \
5508                                        uint b3, \
5509                                        uint b4, \
5510                                        uint b5, \
5511                                        uint b6, \
5512                                        uint b7, \
5513                                        uint b8, \
5514                                        uint b9, \
5515                                        uint b10, \
5516                                        uint b11, \
5517                                        uint b12, \
5518                                        uint b13, \
5519                                        uint b14, \
5520                                        uint b15) {{"
5521                )?;
5522                writeln!(
5523                    self.out,
5524                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5525                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5526                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5527                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5528                    back::INDENT
5529                )?;
5530                writeln!(self.out, "}}")?;
5531                Ok((name, 16, 4))
5532            }
5533            Unorm10_10_10_2 => {
5534                let name = self.namer.call("unpackUnorm10_10_10_2");
5535                writeln!(
5536                    self.out,
5537                    "metal::float4 {name}(uint b0, \
5538                                          uint b1, \
5539                                          uint b2, \
5540                                          uint b3) {{"
5541                )?;
5542                writeln!(
5543                    self.out,
5544                    // The following is correct for RGBA packing, but our format seems to
5545                    // match ABGR, which can be fed into the Metal builtin function
5546                    // unpack_unorm10a2_to_float.
5547                    /*
5548                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5549                       uint r = (v & 0xFFC00000) >> 22; \
5550                       uint g = (v & 0x003FF000) >> 12; \
5551                       uint b = (v & 0x00000FFC) >> 2; \
5552                       uint a = (v & 0x00000003); \
5553                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5554                    */
5555                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5556                    back::INDENT
5557                )?;
5558                writeln!(self.out, "}}")?;
5559                Ok((name, 4, 4))
5560            }
5561            Unorm8x4Bgra => {
5562                let name = self.namer.call("unpackUnorm8x4Bgra");
5563                writeln!(
5564                    self.out,
5565                    "metal::float4 {name}(metal::uchar b0, \
5566                                          metal::uchar b1, \
5567                                          metal::uchar b2, \
5568                                          metal::uchar b3) {{"
5569                )?;
5570                writeln!(
5571                    self.out,
5572                    "{}return metal::float4(float(b2) / 255.0f, \
5573                                            float(b1) / 255.0f, \
5574                                            float(b0) / 255.0f, \
5575                                            float(b3) / 255.0f);",
5576                    back::INDENT
5577                )?;
5578                writeln!(self.out, "}}")?;
5579                Ok((name, 4, 4))
5580            }
5581        }
5582    }
5583
5584    fn write_wrapped_unary_op(
5585        &mut self,
5586        module: &crate::Module,
5587        func_ctx: &back::FunctionCtx,
5588        op: crate::UnaryOperator,
5589        operand: Handle<crate::Expression>,
5590    ) -> BackendResult {
5591        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5592        match op {
5593            // Negating the TYPE_MIN of a two's complement signed integer
5594            // type causes overflow, which is undefined behaviour in MSL. To
5595            // avoid this we bitcast the value to unsigned and negate it,
5596            // then bitcast back to signed.
5597            // This adheres to the WGSL spec in that the negative of the
5598            // type's minimum value should equal to the minimum value.
5599            crate::UnaryOperator::Negate
5600                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5601            {
5602                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5603                    return Ok(());
5604                };
5605                let wrapped = WrappedFunction::UnaryOp {
5606                    op,
5607                    ty: (vector_size, scalar),
5608                };
5609                if !self.wrapped_functions.insert(wrapped) {
5610                    return Ok(());
5611                }
5612
5613                let unsigned_scalar = crate::Scalar {
5614                    kind: crate::ScalarKind::Uint,
5615                    ..scalar
5616                };
5617                let mut type_name = String::new();
5618                let mut unsigned_type_name = String::new();
5619                match vector_size {
5620                    None => {
5621                        put_numeric_type(&mut type_name, scalar, &[])?;
5622                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5623                    }
5624                    Some(size) => {
5625                        put_numeric_type(&mut type_name, scalar, &[size])?;
5626                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5627                    }
5628                };
5629
5630                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5631                let level = back::Level(1);
5632                writeln!(
5633                    self.out,
5634                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5635                )?;
5636                writeln!(self.out, "}}")?;
5637                writeln!(self.out)?;
5638            }
5639            _ => {}
5640        }
5641        Ok(())
5642    }
5643
5644    fn write_wrapped_binary_op(
5645        &mut self,
5646        module: &crate::Module,
5647        func_ctx: &back::FunctionCtx,
5648        expr: Handle<crate::Expression>,
5649        op: crate::BinaryOperator,
5650        left: Handle<crate::Expression>,
5651        right: Handle<crate::Expression>,
5652    ) -> BackendResult {
5653        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5654        let left_ty = func_ctx.resolve_type(left, &module.types);
5655        let right_ty = func_ctx.resolve_type(right, &module.types);
5656        match (op, expr_ty.scalar_kind()) {
5657            // Signed integer division of TYPE_MIN / -1, or signed or
5658            // unsigned division by zero, gives an unspecified value in MSL.
5659            // We override the divisor to 1 in these cases.
5660            // This adheres to the WGSL spec in that:
5661            // * TYPE_MIN / -1 == TYPE_MIN
5662            // * x / 0 == x
5663            (
5664                crate::BinaryOperator::Divide,
5665                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5666            ) => {
5667                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5668                    return Ok(());
5669                };
5670                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5671                    return Ok(());
5672                };
5673                let wrapped = WrappedFunction::BinaryOp {
5674                    op,
5675                    left_ty: left_wrapped_ty,
5676                    right_ty: right_wrapped_ty,
5677                };
5678                if !self.wrapped_functions.insert(wrapped) {
5679                    return Ok(());
5680                }
5681
5682                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5683                    return Ok(());
5684                };
5685                let mut type_name = String::new();
5686                match vector_size {
5687                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5688                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5689                };
5690                writeln!(
5691                    self.out,
5692                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5693                )?;
5694                let level = back::Level(1);
5695                match scalar.kind {
5696                    crate::ScalarKind::Sint => {
5697                        let min_val = match scalar.width {
5698                            4 => crate::Literal::I32(i32::MIN),
5699                            8 => crate::Literal::I64(i64::MIN),
5700                            _ => {
5701                                return Err(Error::GenericValidation(format!(
5702                                    "Unexpected width for scalar {scalar:?}"
5703                                )));
5704                            }
5705                        };
5706                        write!(
5707                            self.out,
5708                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5709                        )?;
5710                        self.put_literal(min_val)?;
5711                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5712                    }
5713                    crate::ScalarKind::Uint => writeln!(
5714                        self.out,
5715                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5716                    )?,
5717                    _ => unreachable!(),
5718                }
5719                writeln!(self.out, "}}")?;
5720                writeln!(self.out)?;
5721            }
5722            // Integer modulo where one or both operands are negative, or the
5723            // divisor is zero, is undefined behaviour in MSL. To avoid this
5724            // we use the following equation:
5725            //
5726            // dividend - (dividend / divisor) * divisor
5727            //
5728            // overriding the divisor to 1 if either it is 0, or it is -1
5729            // and the dividend is TYPE_MIN.
5730            //
5731            // This adheres to the WGSL spec in that:
5732            // * TYPE_MIN % -1 == 0
5733            // * x % 0 == 0
5734            (
5735                crate::BinaryOperator::Modulo,
5736                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5737            ) => {
5738                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5739                    return Ok(());
5740                };
5741                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5742                else {
5743                    return Ok(());
5744                };
5745                let wrapped = WrappedFunction::BinaryOp {
5746                    op,
5747                    left_ty: left_wrapped_ty,
5748                    right_ty: (right_vector_size, right_scalar),
5749                };
5750                if !self.wrapped_functions.insert(wrapped) {
5751                    return Ok(());
5752                }
5753
5754                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5755                    return Ok(());
5756                };
5757                let mut type_name = String::new();
5758                match vector_size {
5759                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5760                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5761                };
5762                let mut rhs_type_name = String::new();
5763                match right_vector_size {
5764                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5765                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5766                };
5767
5768                writeln!(
5769                    self.out,
5770                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5771                )?;
5772                let level = back::Level(1);
5773                match scalar.kind {
5774                    crate::ScalarKind::Sint => {
5775                        let min_val = match scalar.width {
5776                            4 => crate::Literal::I32(i32::MIN),
5777                            8 => crate::Literal::I64(i64::MIN),
5778                            _ => {
5779                                return Err(Error::GenericValidation(format!(
5780                                    "Unexpected width for scalar {scalar:?}"
5781                                )));
5782                            }
5783                        };
5784                        write!(
5785                            self.out,
5786                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5787                        )?;
5788                        self.put_literal(min_val)?;
5789                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5790                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5791                    }
5792                    crate::ScalarKind::Uint => writeln!(
5793                        self.out,
5794                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5795                    )?,
5796                    _ => unreachable!(),
5797                }
5798                writeln!(self.out, "}}")?;
5799                writeln!(self.out)?;
5800            }
5801            _ => {}
5802        }
5803        Ok(())
5804    }
5805
5806    #[allow(clippy::too_many_arguments)]
5807    fn write_wrapped_math_function(
5808        &mut self,
5809        module: &crate::Module,
5810        func_ctx: &back::FunctionCtx,
5811        fun: crate::MathFunction,
5812        arg: Handle<crate::Expression>,
5813        _arg1: Option<Handle<crate::Expression>>,
5814        _arg2: Option<Handle<crate::Expression>>,
5815        _arg3: Option<Handle<crate::Expression>>,
5816    ) -> BackendResult {
5817        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5818        match fun {
5819            // Taking the absolute value of the TYPE_MIN of a two's
5820            // complement signed integer type causes overflow, which is
5821            // undefined behaviour in MSL. To avoid this, when the value is
5822            // negative we bitcast the value to unsigned and negate it, then
5823            // bitcast back to signed.
5824            // This adheres to the WGSL spec in that the absolute of the
5825            // type's minimum value should equal to the minimum value.
5826            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5827                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5828                    return Ok(());
5829                };
5830                let wrapped = WrappedFunction::Math {
5831                    fun,
5832                    arg_ty: (vector_size, scalar),
5833                };
5834                if !self.wrapped_functions.insert(wrapped) {
5835                    return Ok(());
5836                }
5837
5838                let unsigned_scalar = crate::Scalar {
5839                    kind: crate::ScalarKind::Uint,
5840                    ..scalar
5841                };
5842                let mut type_name = String::new();
5843                let mut unsigned_type_name = String::new();
5844                match vector_size {
5845                    None => {
5846                        put_numeric_type(&mut type_name, scalar, &[])?;
5847                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5848                    }
5849                    Some(size) => {
5850                        put_numeric_type(&mut type_name, scalar, &[size])?;
5851                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5852                    }
5853                };
5854
5855                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5856                let level = back::Level(1);
5857                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5858                writeln!(self.out, "}}")?;
5859                writeln!(self.out)?;
5860            }
5861            _ => {}
5862        }
5863        Ok(())
5864    }
5865
5866    fn write_wrapped_cast(
5867        &mut self,
5868        module: &crate::Module,
5869        func_ctx: &back::FunctionCtx,
5870        expr: Handle<crate::Expression>,
5871        kind: crate::ScalarKind,
5872        convert: Option<crate::Bytes>,
5873    ) -> BackendResult {
5874        // Avoid undefined behaviour when casting from a float to integer
5875        // when the value is out of range for the target type. Additionally
5876        // ensure we clamp to the correct value as per the WGSL spec.
5877        //
5878        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
5879        // * If X is exactly representable in the target type T, then the
5880        //   result is that value.
5881        // * Otherwise, the result is the value in T closest to
5882        //   truncate(X) and also exactly representable in the original
5883        //   floating point type.
5884        let src_ty = func_ctx.resolve_type(expr, &module.types);
5885        let Some(width) = convert else {
5886            return Ok(());
5887        };
5888        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5889            return Ok(());
5890        };
5891        let dst_scalar = crate::Scalar { kind, width };
5892        if src_scalar.kind != crate::ScalarKind::Float
5893            || (dst_scalar.kind != crate::ScalarKind::Sint
5894                && dst_scalar.kind != crate::ScalarKind::Uint)
5895        {
5896            return Ok(());
5897        }
5898        let wrapped = WrappedFunction::Cast {
5899            src_scalar,
5900            vector_size,
5901            dst_scalar,
5902        };
5903        if !self.wrapped_functions.insert(wrapped) {
5904            return Ok(());
5905        }
5906        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5907
5908        let mut src_type_name = String::new();
5909        match vector_size {
5910            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5911            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5912        };
5913        let mut dst_type_name = String::new();
5914        match vector_size {
5915            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5916            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5917        };
5918        let fun_name = match dst_scalar {
5919            crate::Scalar::I32 => F2I32_FUNCTION,
5920            crate::Scalar::U32 => F2U32_FUNCTION,
5921            crate::Scalar::I64 => F2I64_FUNCTION,
5922            crate::Scalar::U64 => F2U64_FUNCTION,
5923            _ => unreachable!(),
5924        };
5925
5926        writeln!(
5927            self.out,
5928            "{dst_type_name} {fun_name}({src_type_name} value) {{"
5929        )?;
5930        let level = back::Level(1);
5931        write!(
5932            self.out,
5933            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5934        )?;
5935        self.put_literal(min)?;
5936        write!(self.out, ", ")?;
5937        self.put_literal(max)?;
5938        writeln!(self.out, "));")?;
5939        writeln!(self.out, "}}")?;
5940        writeln!(self.out)?;
5941        Ok(())
5942    }
5943
5944    /// Helper function used by [`Self::write_wrapped_image_load`] and
5945    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
5946    /// conversion code for external textures. Expects the preceding code to
5947    /// declare the Y component as a `float` variable of name `y`, the UV
5948    /// components as a `float2` variable of name `uv`, and the external
5949    /// texture params as a variable of name `params`. The emitted code will
5950    /// return the result.
5951    fn write_convert_yuv_to_rgb_and_return(
5952        &mut self,
5953        level: back::Level,
5954        y: &str,
5955        uv: &str,
5956        params: &str,
5957    ) -> BackendResult {
5958        let l1 = level;
5959        let l2 = l1.next();
5960
5961        // Convert from YUV to non-linear RGB in the source color space.
5962        writeln!(
5963            self.out,
5964            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
5965        )?;
5966
5967        // Apply the inverse of the source transfer function to convert to
5968        // linear RGB in the source color space.
5969        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
5970        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
5971        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
5972        writeln!(
5973            self.out,
5974            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
5975        )?;
5976
5977        // Multiply by the gamut conversion matrix to convert to linear RGB in
5978        // the destination color space.
5979        writeln!(
5980            self.out,
5981            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
5982        )?;
5983
5984        // Finally, apply the dest transfer function to convert to non-linear
5985        // RGB in the destination color space, and return the result.
5986        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
5987        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
5988        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
5989        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
5990
5991        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
5992        Ok(())
5993    }
5994
5995    #[allow(clippy::too_many_arguments)]
5996    fn write_wrapped_image_load(
5997        &mut self,
5998        module: &crate::Module,
5999        func_ctx: &back::FunctionCtx,
6000        image: Handle<crate::Expression>,
6001        _coordinate: Handle<crate::Expression>,
6002        _array_index: Option<Handle<crate::Expression>>,
6003        _sample: Option<Handle<crate::Expression>>,
6004        _level: Option<Handle<crate::Expression>>,
6005    ) -> BackendResult {
6006        // We currently only need to wrap image loads for external textures
6007        let class = match *func_ctx.resolve_type(image, &module.types) {
6008            crate::TypeInner::Image { class, .. } => class,
6009            _ => unreachable!(),
6010        };
6011        if class != crate::ImageClass::External {
6012            return Ok(());
6013        }
6014        let wrapped = WrappedFunction::ImageLoad { class };
6015        if !self.wrapped_functions.insert(wrapped) {
6016            return Ok(());
6017        }
6018
6019        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6020        let l1 = back::Level(1);
6021        let l2 = l1.next();
6022        let l3 = l2.next();
6023        writeln!(
6024            self.out,
6025            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6026        )?;
6027        // Clamp coords to provided size of external texture to prevent OOB
6028        // read. If params.size is zero then clamp to the actual size of the
6029        // texture.
6030        writeln!(
6031            self.out,
6032            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6033        )?;
6034        writeln!(
6035            self.out,
6036            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6037        )?;
6038
6039        // Apply load transformation
6040        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6041        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6042        // For single plane, simply read from plane0
6043        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6044        writeln!(self.out, "{l1}}} else {{")?;
6045
6046        // Chroma planes may be subsampled so we must scale the coords accordingly.
6047        writeln!(
6048            self.out,
6049            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6050        )?;
6051        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6052
6053        // For multi-plane, read the Y value from plane 0
6054        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6055
6056        writeln!(self.out, "{l2}float2 uv;")?;
6057        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6058        // For 2 planes, read UV from interleaved plane 1
6059        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6060        writeln!(self.out, "{l2}}} else {{")?;
6061        // For 3 planes, read U and V from planes 1 and 2 respectively
6062        writeln!(
6063            self.out,
6064            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6065        )?;
6066        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6067        writeln!(
6068            self.out,
6069            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6070        )?;
6071        writeln!(self.out, "{l2}}}")?;
6072
6073        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6074
6075        writeln!(self.out, "{l1}}}")?;
6076        writeln!(self.out, "}}")?;
6077        writeln!(self.out)?;
6078        Ok(())
6079    }
6080
6081    #[allow(clippy::too_many_arguments)]
6082    fn write_wrapped_image_sample(
6083        &mut self,
6084        module: &crate::Module,
6085        func_ctx: &back::FunctionCtx,
6086        image: Handle<crate::Expression>,
6087        _sampler: Handle<crate::Expression>,
6088        _gather: Option<crate::SwizzleComponent>,
6089        _coordinate: Handle<crate::Expression>,
6090        _array_index: Option<Handle<crate::Expression>>,
6091        _offset: Option<Handle<crate::Expression>>,
6092        _level: crate::SampleLevel,
6093        _depth_ref: Option<Handle<crate::Expression>>,
6094        clamp_to_edge: bool,
6095    ) -> BackendResult {
6096        // We currently only need to wrap textureSampleBaseClampToEdge, for
6097        // both sampled and external textures.
6098        if !clamp_to_edge {
6099            return Ok(());
6100        }
6101        let class = match *func_ctx.resolve_type(image, &module.types) {
6102            crate::TypeInner::Image { class, .. } => class,
6103            _ => unreachable!(),
6104        };
6105        let wrapped = WrappedFunction::ImageSample {
6106            class,
6107            clamp_to_edge: true,
6108        };
6109        if !self.wrapped_functions.insert(wrapped) {
6110            return Ok(());
6111        }
6112        match class {
6113            crate::ImageClass::External => {
6114                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6115                let l1 = back::Level(1);
6116                let l2 = l1.next();
6117                let l3 = l2.next();
6118                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6119                writeln!(
6120                    self.out,
6121                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6122                )?;
6123
6124                // Calculate the sample bounds. The purported size of the texture
6125                // (params.size) is irrelevant here as we are dealing with normalized
6126                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6127                // apply the sample transformation to that, also bearing in mind that it
6128                // may contain a flip on either axis. We calculate and adjust for the
6129                // half-texel separately for each plane as it depends on the actual
6130                // texture size which may vary between planes.
6131                writeln!(
6132                    self.out,
6133                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6134                )?;
6135                writeln!(
6136                    self.out,
6137                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6138                )?;
6139                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6140                writeln!(
6141                    self.out,
6142                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6143                )?;
6144                writeln!(
6145                    self.out,
6146                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6147                )?;
6148                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6149                // For single plane, simply sample from plane0
6150                writeln!(
6151                    self.out,
6152                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6153                )?;
6154                writeln!(self.out, "{l1}}} else {{")?;
6155                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6156                writeln!(
6157                    self.out,
6158                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6159                )?;
6160                writeln!(
6161                    self.out,
6162                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6163                )?;
6164
6165                // For multi-plane, sample the Y value from plane 0
6166                writeln!(
6167                    self.out,
6168                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6169                )?;
6170                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6171                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6172                // For 2 planes, sample UV from interleaved plane 1
6173                writeln!(
6174                    self.out,
6175                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6176                )?;
6177                writeln!(self.out, "{l2}}} else {{")?;
6178                // For 3 planes, sample U and V from planes 1 and 2 respectively
6179                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6180                writeln!(
6181                    self.out,
6182                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6183                )?;
6184                writeln!(
6185                    self.out,
6186                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6187                )?;
6188                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6189                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6190                writeln!(self.out, "{l2}}}")?;
6191
6192                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6193
6194                writeln!(self.out, "{l1}}}")?;
6195                writeln!(self.out, "}}")?;
6196                writeln!(self.out)?;
6197            }
6198            _ => {
6199                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6200                let l1 = back::Level(1);
6201                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6202                writeln!(
6203                    self.out,
6204                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6205                )?;
6206                writeln!(self.out, "}}")?;
6207                writeln!(self.out)?;
6208            }
6209        }
6210        Ok(())
6211    }
6212
6213    fn write_wrapped_image_query(
6214        &mut self,
6215        module: &crate::Module,
6216        func_ctx: &back::FunctionCtx,
6217        image: Handle<crate::Expression>,
6218        query: crate::ImageQuery,
6219    ) -> BackendResult {
6220        // We currently only need to wrap size image queries for external textures
6221        if !matches!(query, crate::ImageQuery::Size { .. }) {
6222            return Ok(());
6223        }
6224        let class = match *func_ctx.resolve_type(image, &module.types) {
6225            crate::TypeInner::Image { class, .. } => class,
6226            _ => unreachable!(),
6227        };
6228        if class != crate::ImageClass::External {
6229            return Ok(());
6230        }
6231        let wrapped = WrappedFunction::ImageQuerySize { class };
6232        if !self.wrapped_functions.insert(wrapped) {
6233            return Ok(());
6234        }
6235        writeln!(
6236            self.out,
6237            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6238        )?;
6239        let l1 = back::Level(1);
6240        let l2 = l1.next();
6241        writeln!(
6242            self.out,
6243            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6244        )?;
6245        writeln!(self.out, "{l2}return tex.params.size;")?;
6246        writeln!(self.out, "{l1}}} else {{")?;
6247        // params.size == (0, 0) indicates to query and return plane 0's actual size
6248        writeln!(
6249            self.out,
6250            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6251        )?;
6252        writeln!(self.out, "{l1}}}")?;
6253        writeln!(self.out, "}}")?;
6254        writeln!(self.out)?;
6255        Ok(())
6256    }
6257
6258    pub(super) fn write_wrapped_functions(
6259        &mut self,
6260        module: &crate::Module,
6261        func_ctx: &back::FunctionCtx,
6262    ) -> BackendResult {
6263        for (expr_handle, expr) in func_ctx.expressions.iter() {
6264            match *expr {
6265                crate::Expression::Unary { op, expr: operand } => {
6266                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6267                }
6268                crate::Expression::Binary { op, left, right } => {
6269                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6270                }
6271                crate::Expression::Math {
6272                    fun,
6273                    arg,
6274                    arg1,
6275                    arg2,
6276                    arg3,
6277                } => {
6278                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6279                }
6280                crate::Expression::As {
6281                    expr,
6282                    kind,
6283                    convert,
6284                } => {
6285                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6286                }
6287                crate::Expression::ImageLoad {
6288                    image,
6289                    coordinate,
6290                    array_index,
6291                    sample,
6292                    level,
6293                } => {
6294                    self.write_wrapped_image_load(
6295                        module,
6296                        func_ctx,
6297                        image,
6298                        coordinate,
6299                        array_index,
6300                        sample,
6301                        level,
6302                    )?;
6303                }
6304                crate::Expression::ImageSample {
6305                    image,
6306                    sampler,
6307                    gather,
6308                    coordinate,
6309                    array_index,
6310                    offset,
6311                    level,
6312                    depth_ref,
6313                    clamp_to_edge,
6314                } => {
6315                    self.write_wrapped_image_sample(
6316                        module,
6317                        func_ctx,
6318                        image,
6319                        sampler,
6320                        gather,
6321                        coordinate,
6322                        array_index,
6323                        offset,
6324                        level,
6325                        depth_ref,
6326                        clamp_to_edge,
6327                    )?;
6328                }
6329                crate::Expression::ImageQuery { image, query } => {
6330                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6331                }
6332                _ => {}
6333            }
6334        }
6335
6336        Ok(())
6337    }
6338
6339    // Returns the array of mapped entry point names.
6340    fn write_functions(
6341        &mut self,
6342        module: &crate::Module,
6343        mod_info: &valid::ModuleInfo,
6344        options: &Options,
6345        pipeline_options: &PipelineOptions,
6346    ) -> Result<TranslationInfo, Error> {
6347        use back::msl::VertexFormat;
6348
6349        // Define structs to hold resolved/generated data for vertex buffers and
6350        // their attributes.
6351        struct AttributeMappingResolved {
6352            ty_name: String,
6353            dimension: u32,
6354            ty_is_int: bool,
6355            name: String,
6356        }
6357        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6358
6359        struct VertexBufferMappingResolved<'a> {
6360            id: u32,
6361            stride: u32,
6362            indexed_by_vertex: bool,
6363            ty_name: String,
6364            param_name: String,
6365            elem_name: String,
6366            attributes: &'a Vec<back::msl::AttributeMapping>,
6367        }
6368        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6369
6370        // Define a struct to hold a named reference to a byte-unpacking function.
6371        struct UnpackingFunction {
6372            name: String,
6373            byte_count: u32,
6374            dimension: u32,
6375        }
6376        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6377
6378        // Check if we are attempting vertex pulling. If we are, generate some
6379        // names we'll need, and iterate the vertex buffer mappings to output
6380        // all the conversion functions we'll need to unpack the attribute data.
6381        // We can re-use these names for all entry points that need them, since
6382        // those entry points also use self.namer.
6383        let mut needs_vertex_id = false;
6384        let v_id = self.namer.call("v_id");
6385
6386        let mut needs_instance_id = false;
6387        let i_id = self.namer.call("i_id");
6388        if pipeline_options.vertex_pulling_transform {
6389            for vbm in &pipeline_options.vertex_buffer_mappings {
6390                let buffer_id = vbm.id;
6391                let buffer_stride = vbm.stride;
6392
6393                assert!(
6394                    buffer_stride > 0,
6395                    "Vertex pulling requires a non-zero buffer stride."
6396                );
6397
6398                if vbm.indexed_by_vertex {
6399                    needs_vertex_id = true;
6400                } else {
6401                    needs_instance_id = true;
6402                }
6403
6404                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6405                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6406                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6407
6408                vbm_resolved.push(VertexBufferMappingResolved {
6409                    id: buffer_id,
6410                    stride: buffer_stride,
6411                    indexed_by_vertex: vbm.indexed_by_vertex,
6412                    ty_name: buffer_ty,
6413                    param_name: buffer_param,
6414                    elem_name: buffer_elem,
6415                    attributes: &vbm.attributes,
6416                });
6417
6418                // Iterate the attributes and generate needed unpacking functions.
6419                for attribute in &vbm.attributes {
6420                    if unpacking_functions.contains_key(&attribute.format) {
6421                        continue;
6422                    }
6423                    let (name, byte_count, dimension) =
6424                        match self.write_unpacking_function(attribute.format) {
6425                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6426                            _ => {
6427                                continue;
6428                            }
6429                        };
6430                    unpacking_functions.insert(
6431                        attribute.format,
6432                        UnpackingFunction {
6433                            name,
6434                            byte_count,
6435                            dimension,
6436                        },
6437                    );
6438                }
6439            }
6440        }
6441
6442        let mut pass_through_globals = Vec::new();
6443        for (fun_handle, fun) in module.functions.iter() {
6444            log::trace!(
6445                "function {:?}, handle {:?}",
6446                fun.name.as_deref().unwrap_or("(anonymous)"),
6447                fun_handle
6448            );
6449
6450            let ctx = back::FunctionCtx {
6451                ty: back::FunctionType::Function(fun_handle),
6452                info: &mod_info[fun_handle],
6453                expressions: &fun.expressions,
6454                named_expressions: &fun.named_expressions,
6455            };
6456
6457            writeln!(self.out)?;
6458            self.write_wrapped_functions(module, &ctx)?;
6459
6460            let fun_info = &mod_info[fun_handle];
6461            pass_through_globals.clear();
6462            let mut needs_buffer_sizes = false;
6463            for (handle, var) in module.global_variables.iter() {
6464                if !fun_info[handle].is_empty() {
6465                    if var.space.needs_pass_through() {
6466                        pass_through_globals.push(handle);
6467                    }
6468                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6469                }
6470            }
6471
6472            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6473            match fun.result {
6474                Some(ref result) => {
6475                    let ty_name = TypeContext {
6476                        handle: result.ty,
6477                        gctx: module.to_ctx(),
6478                        names: &self.names,
6479                        access: crate::StorageAccess::empty(),
6480                        first_time: false,
6481                    };
6482                    write!(self.out, "{ty_name}")?;
6483                }
6484                None => {
6485                    write!(self.out, "void")?;
6486                }
6487            }
6488            writeln!(self.out, " {fun_name}(")?;
6489
6490            for (index, arg) in fun.arguments.iter().enumerate() {
6491                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6492                let param_type_name = TypeContext {
6493                    handle: arg.ty,
6494                    gctx: module.to_ctx(),
6495                    names: &self.names,
6496                    access: crate::StorageAccess::empty(),
6497                    first_time: false,
6498                };
6499                let separator = separate(
6500                    !pass_through_globals.is_empty()
6501                        || index + 1 != fun.arguments.len()
6502                        || needs_buffer_sizes,
6503                );
6504                writeln!(
6505                    self.out,
6506                    "{}{} {}{}",
6507                    back::INDENT,
6508                    param_type_name,
6509                    name,
6510                    separator
6511                )?;
6512            }
6513            for (index, &handle) in pass_through_globals.iter().enumerate() {
6514                let tyvar = TypedGlobalVariable {
6515                    module,
6516                    names: &self.names,
6517                    handle,
6518                    usage: fun_info[handle],
6519
6520                    reference: true,
6521                };
6522                let separator =
6523                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6524                write!(self.out, "{}", back::INDENT)?;
6525                tyvar.try_fmt(&mut self.out)?;
6526                writeln!(self.out, "{separator}")?;
6527            }
6528
6529            if needs_buffer_sizes {
6530                writeln!(
6531                    self.out,
6532                    "{}constant _mslBufferSizes& _buffer_sizes",
6533                    back::INDENT
6534                )?;
6535            }
6536
6537            writeln!(self.out, ") {{")?;
6538
6539            let guarded_indices =
6540                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6541
6542            let context = StatementContext {
6543                expression: ExpressionContext {
6544                    function: fun,
6545                    origin: FunctionOrigin::Handle(fun_handle),
6546                    info: fun_info,
6547                    lang_version: options.lang_version,
6548                    policies: options.bounds_check_policies,
6549                    guarded_indices,
6550                    module,
6551                    mod_info,
6552                    pipeline_options,
6553                    force_loop_bounding: options.force_loop_bounding,
6554                },
6555                result_struct: None,
6556            };
6557
6558            self.put_locals(&context.expression)?;
6559            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6560            self.put_block(back::Level(1), &fun.body, &context)?;
6561            writeln!(self.out, "}}")?;
6562            self.named_expressions.clear();
6563        }
6564
6565        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6566            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6567
6568        let mut info = TranslationInfo {
6569            entry_point_names: Vec::with_capacity(ep_range.len()),
6570        };
6571
6572        for ep_index in ep_range {
6573            let ep = &module.entry_points[ep_index];
6574            let fun = &ep.function;
6575            let fun_info = mod_info.get_entry_point(ep_index);
6576            let mut ep_error = None;
6577
6578            // For vertex_id and instance_id arguments, presume that we'll
6579            // use our generated names, but switch to the name of an
6580            // existing @builtin param, if we find one.
6581            let mut v_existing_id = None;
6582            let mut i_existing_id = None;
6583
6584            log::trace!(
6585                "entry point {:?}, index {:?}",
6586                fun.name.as_deref().unwrap_or("(anonymous)"),
6587                ep_index
6588            );
6589
6590            let ctx = back::FunctionCtx {
6591                ty: back::FunctionType::EntryPoint(ep_index as u16),
6592                info: fun_info,
6593                expressions: &fun.expressions,
6594                named_expressions: &fun.named_expressions,
6595            };
6596
6597            self.write_wrapped_functions(module, &ctx)?;
6598
6599            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6600                crate::ShaderStage::Vertex => (
6601                    "vertex",
6602                    LocationMode::VertexInput,
6603                    LocationMode::VertexOutput,
6604                    true,
6605                ),
6606                crate::ShaderStage::Fragment => (
6607                    "fragment",
6608                    LocationMode::FragmentInput,
6609                    LocationMode::FragmentOutput,
6610                    false,
6611                ),
6612                crate::ShaderStage::Compute => (
6613                    "kernel",
6614                    LocationMode::Uniform,
6615                    LocationMode::Uniform,
6616                    false,
6617                ),
6618                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(),
6619            };
6620
6621            // Should this entry point be modified to do vertex pulling?
6622            let do_vertex_pulling = can_vertex_pull
6623                && pipeline_options.vertex_pulling_transform
6624                && !pipeline_options.vertex_buffer_mappings.is_empty();
6625
6626            // Is any global variable used by this entry point dynamically sized?
6627            let needs_buffer_sizes = do_vertex_pulling
6628                || module
6629                    .global_variables
6630                    .iter()
6631                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6632                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6633
6634            // skip this entry point if any global bindings are missing,
6635            // or their types are incompatible.
6636            if !options.fake_missing_bindings {
6637                for (var_handle, var) in module.global_variables.iter() {
6638                    if fun_info[var_handle].is_empty() {
6639                        continue;
6640                    }
6641                    match var.space {
6642                        crate::AddressSpace::Uniform
6643                        | crate::AddressSpace::Storage { .. }
6644                        | crate::AddressSpace::Handle => {
6645                            let br = match var.binding {
6646                                Some(ref br) => br,
6647                                None => {
6648                                    let var_name = var.name.clone().unwrap_or_default();
6649                                    ep_error =
6650                                        Some(super::EntryPointError::MissingBinding(var_name));
6651                                    break;
6652                                }
6653                            };
6654                            let target = options.get_resource_binding_target(ep, br);
6655                            let good = match target {
6656                                Some(target) => {
6657                                    // We intentionally don't dereference binding_arrays here,
6658                                    // so that binding arrays fall to the buffer location.
6659
6660                                    match module.types[var.ty].inner {
6661                                        crate::TypeInner::Image {
6662                                            class: crate::ImageClass::External,
6663                                            ..
6664                                        } => target.external_texture.is_some(),
6665                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6666                                        crate::TypeInner::Sampler { .. } => {
6667                                            target.sampler.is_some()
6668                                        }
6669                                        _ => target.buffer.is_some(),
6670                                    }
6671                                }
6672                                None => false,
6673                            };
6674                            if !good {
6675                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6676                                break;
6677                            }
6678                        }
6679                        crate::AddressSpace::PushConstant => {
6680                            if let Err(e) = options.resolve_push_constants(ep) {
6681                                ep_error = Some(e);
6682                                break;
6683                            }
6684                        }
6685                        crate::AddressSpace::Function
6686                        | crate::AddressSpace::Private
6687                        | crate::AddressSpace::WorkGroup => {}
6688                    }
6689                }
6690                if needs_buffer_sizes {
6691                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6692                        ep_error = Some(err);
6693                    }
6694                }
6695            }
6696
6697            if let Some(err) = ep_error {
6698                info.entry_point_names.push(Err(err));
6699                continue;
6700            }
6701            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6702            info.entry_point_names.push(Ok(fun_name.clone()));
6703
6704            writeln!(self.out)?;
6705
6706            // Since `Namer.reset` wasn't expecting struct members to be
6707            // suddenly injected into another namespace like this,
6708            // `self.names` doesn't keep them distinct from other variables.
6709            // Generate fresh names for these arguments, and remember the
6710            // mapping.
6711            let mut flattened_member_names = FastHashMap::default();
6712            // Varyings' members get their own namespace
6713            let mut varyings_namer = proc::Namer::default();
6714
6715            // List all the Naga `EntryPoint`'s `Function`'s arguments,
6716            // flattening structs into their members. In Metal, we will pass
6717            // each of these values to the entry point as a separate argument—
6718            // except for the varyings, handled next.
6719            let mut flattened_arguments = Vec::new();
6720            for (arg_index, arg) in fun.arguments.iter().enumerate() {
6721                match module.types[arg.ty].inner {
6722                    crate::TypeInner::Struct { ref members, .. } => {
6723                        for (member_index, member) in members.iter().enumerate() {
6724                            let member_index = member_index as u32;
6725                            flattened_arguments.push((
6726                                NameKey::StructMember(arg.ty, member_index),
6727                                member.ty,
6728                                member.binding.as_ref(),
6729                            ));
6730                            let name_key = NameKey::StructMember(arg.ty, member_index);
6731                            let name = match member.binding {
6732                                Some(crate::Binding::Location { .. }) => {
6733                                    if do_vertex_pulling {
6734                                        self.namer.call(&self.names[&name_key])
6735                                    } else {
6736                                        varyings_namer.call(&self.names[&name_key])
6737                                    }
6738                                }
6739                                _ => self.namer.call(&self.names[&name_key]),
6740                            };
6741                            flattened_member_names.insert(name_key, name);
6742                        }
6743                    }
6744                    _ => flattened_arguments.push((
6745                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
6746                        arg.ty,
6747                        arg.binding.as_ref(),
6748                    )),
6749                }
6750            }
6751
6752            // Identify the varyings among the argument values, and maybe emit
6753            // a struct type named `<fun>Input` to hold them. If we are doing
6754            // vertex pulling, we instead update our attribute mapping to
6755            // note the types, names, and zero values of the attributes.
6756            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
6757            let varyings_member_name = self.namer.call("varyings");
6758            let mut has_varyings = false;
6759            if !flattened_arguments.is_empty() {
6760                if !do_vertex_pulling {
6761                    writeln!(self.out, "struct {stage_in_name} {{")?;
6762                }
6763                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6764                    let (binding, location) = match binding {
6765                        Some(ref binding @ &crate::Binding::Location { location, .. }) => {
6766                            (binding, location)
6767                        }
6768                        _ => continue,
6769                    };
6770                    let name = match *name_key {
6771                        NameKey::StructMember(..) => &flattened_member_names[name_key],
6772                        _ => &self.names[name_key],
6773                    };
6774                    let ty_name = TypeContext {
6775                        handle: ty,
6776                        gctx: module.to_ctx(),
6777                        names: &self.names,
6778                        access: crate::StorageAccess::empty(),
6779                        first_time: false,
6780                    };
6781                    let resolved = options.resolve_local_binding(binding, in_mode)?;
6782                    if do_vertex_pulling {
6783                        // Update our attribute mapping.
6784                        am_resolved.insert(
6785                            location,
6786                            AttributeMappingResolved {
6787                                ty_name: ty_name.to_string(),
6788                                dimension: ty_name.vertex_input_dimension(),
6789                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
6790                                name: name.to_string(),
6791                            },
6792                        );
6793                    } else {
6794                        has_varyings = true;
6795                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6796                        resolved.try_fmt(&mut self.out)?;
6797                        writeln!(self.out, ";")?;
6798                    }
6799                }
6800                if !do_vertex_pulling {
6801                    writeln!(self.out, "}};")?;
6802                }
6803            }
6804
6805            // Define a struct type named for the return value, if any, named
6806            // `<fun>Output`.
6807            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
6808            let result_member_name = self.namer.call("member");
6809            let result_type_name = match fun.result {
6810                Some(ref result) => {
6811                    let mut result_members = Vec::new();
6812                    if let crate::TypeInner::Struct { ref members, .. } =
6813                        module.types[result.ty].inner
6814                    {
6815                        for (member_index, member) in members.iter().enumerate() {
6816                            result_members.push((
6817                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
6818                                member.ty,
6819                                member.binding.as_ref(),
6820                            ));
6821                        }
6822                    } else {
6823                        result_members.push((
6824                            &result_member_name,
6825                            result.ty,
6826                            result.binding.as_ref(),
6827                        ));
6828                    }
6829
6830                    writeln!(self.out, "struct {stage_out_name} {{")?;
6831                    let mut has_point_size = false;
6832                    for (name, ty, binding) in result_members {
6833                        let ty_name = TypeContext {
6834                            handle: ty,
6835                            gctx: module.to_ctx(),
6836                            names: &self.names,
6837                            access: crate::StorageAccess::empty(),
6838                            first_time: true,
6839                        };
6840                        let binding = binding.ok_or_else(|| {
6841                            Error::GenericValidation("Expected binding, got None".into())
6842                        })?;
6843
6844                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
6845                            has_point_size = true;
6846                            if !pipeline_options.allow_and_force_point_size {
6847                                continue;
6848                            }
6849                        }
6850
6851                        let array_len = match module.types[ty].inner {
6852                            crate::TypeInner::Array {
6853                                size: crate::ArraySize::Constant(size),
6854                                ..
6855                            } => Some(size),
6856                            _ => None,
6857                        };
6858                        let resolved = options.resolve_local_binding(binding, out_mode)?;
6859                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6860                        if let Some(array_len) = array_len {
6861                            write!(self.out, " [{array_len}]")?;
6862                        }
6863                        resolved.try_fmt(&mut self.out)?;
6864                        writeln!(self.out, ";")?;
6865                    }
6866
6867                    if pipeline_options.allow_and_force_point_size
6868                        && ep.stage == crate::ShaderStage::Vertex
6869                        && !has_point_size
6870                    {
6871                        // inject the point size output last
6872                        writeln!(
6873                            self.out,
6874                            "{}float _point_size [[point_size]];",
6875                            back::INDENT
6876                        )?;
6877                    }
6878                    writeln!(self.out, "}};")?;
6879                    &stage_out_name
6880                }
6881                None => "void",
6882            };
6883
6884            // If we're doing a vertex pulling transform, define the buffer
6885            // structure types.
6886            if do_vertex_pulling {
6887                for vbm in &vbm_resolved {
6888                    let buffer_stride = vbm.stride;
6889                    let buffer_ty = &vbm.ty_name;
6890
6891                    // Define a structure of bytes of the appropriate size.
6892                    // When we access the attributes, we'll be unpacking these
6893                    // bytes at some offset.
6894                    writeln!(
6895                        self.out,
6896                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
6897                    )?;
6898                }
6899            }
6900
6901            // Write the entry point function's name, and begin its argument list.
6902            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
6903            let mut is_first_argument = true;
6904
6905            // If we have produced a struct holding the `EntryPoint`'s
6906            // `Function`'s arguments' varyings, pass that struct first.
6907            if has_varyings {
6908                writeln!(
6909                    self.out,
6910                    "  {stage_in_name} {varyings_member_name} [[stage_in]]"
6911                )?;
6912                is_first_argument = false;
6913            }
6914
6915            let mut local_invocation_id = None;
6916
6917            // Then pass the remaining arguments not included in the varyings
6918            // struct.
6919            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6920                let binding = match binding {
6921                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
6922                    _ => continue,
6923                };
6924                let name = match *name_key {
6925                    NameKey::StructMember(..) => &flattened_member_names[name_key],
6926                    _ => &self.names[name_key],
6927                };
6928
6929                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
6930                    local_invocation_id = Some(name_key);
6931                }
6932
6933                let ty_name = TypeContext {
6934                    handle: ty,
6935                    gctx: module.to_ctx(),
6936                    names: &self.names,
6937                    access: crate::StorageAccess::empty(),
6938                    first_time: false,
6939                };
6940
6941                match *binding {
6942                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
6943                        v_existing_id = Some(name.clone());
6944                    }
6945                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
6946                        i_existing_id = Some(name.clone());
6947                    }
6948                    _ => {}
6949                };
6950
6951                let resolved = options.resolve_local_binding(binding, in_mode)?;
6952                let separator = if is_first_argument {
6953                    is_first_argument = false;
6954                    ' '
6955                } else {
6956                    ','
6957                };
6958                write!(self.out, "{separator} {ty_name} {name}")?;
6959                resolved.try_fmt(&mut self.out)?;
6960                writeln!(self.out)?;
6961            }
6962
6963            let need_workgroup_variables_initialization =
6964                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
6965
6966            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
6967                let separator = if is_first_argument {
6968                    is_first_argument = false;
6969                    ' '
6970                } else {
6971                    ','
6972                };
6973                writeln!(
6974                    self.out,
6975                    "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
6976                )?;
6977            }
6978
6979            // Those global variables used by this entry point and its callees
6980            // get passed as arguments. `Private` globals are an exception, they
6981            // don't outlive this invocation, so we declare them below as locals
6982            // within the entry point.
6983            for (handle, var) in module.global_variables.iter() {
6984                let usage = fun_info[handle];
6985                if usage.is_empty() || var.space == crate::AddressSpace::Private {
6986                    continue;
6987                }
6988
6989                if options.lang_version < (1, 2) {
6990                    match var.space {
6991                        // This restriction is not documented in the MSL spec
6992                        // but validation will fail if it is not upheld.
6993                        //
6994                        // We infer the required version from the "Function
6995                        // Buffer Read-Writes" section of [what's new], where
6996                        // the feature sets listed correspond with the ones
6997                        // supporting MSL 1.2.
6998                        //
6999                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7000                        crate::AddressSpace::Storage { access }
7001                            if access.contains(crate::StorageAccess::STORE)
7002                                && ep.stage == crate::ShaderStage::Fragment =>
7003                        {
7004                            return Err(Error::UnsupportedWriteableStorageBuffer)
7005                        }
7006                        crate::AddressSpace::Handle => {
7007                            match module.types[var.ty].inner {
7008                                crate::TypeInner::Image {
7009                                    class: crate::ImageClass::Storage { access, .. },
7010                                    ..
7011                                } => {
7012                                    // This restriction is not documented in the MSL spec
7013                                    // but validation will fail if it is not upheld.
7014                                    //
7015                                    // We infer the required version from the "Function
7016                                    // Texture Read-Writes" section of [what's new], where
7017                                    // the feature sets listed correspond with the ones
7018                                    // supporting MSL 1.2.
7019                                    //
7020                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7021                                    if access.contains(crate::StorageAccess::STORE)
7022                                        && (ep.stage == crate::ShaderStage::Vertex
7023                                            || ep.stage == crate::ShaderStage::Fragment)
7024                                    {
7025                                        return Err(Error::UnsupportedWriteableStorageTexture(
7026                                            ep.stage,
7027                                        ));
7028                                    }
7029
7030                                    if access.contains(
7031                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7032                                    ) {
7033                                        return Err(Error::UnsupportedRWStorageTexture);
7034                                    }
7035                                }
7036                                _ => {}
7037                            }
7038                        }
7039                        _ => {}
7040                    }
7041                }
7042
7043                // Check min MSL version for binding arrays
7044                match var.space {
7045                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7046                        crate::TypeInner::BindingArray { base, .. } => {
7047                            match module.types[base].inner {
7048                                crate::TypeInner::Sampler { .. } => {
7049                                    if options.lang_version < (2, 0) {
7050                                        return Err(Error::UnsupportedArrayOf(
7051                                            "samplers".to_string(),
7052                                        ));
7053                                    }
7054                                }
7055                                crate::TypeInner::Image { class, .. } => match class {
7056                                    crate::ImageClass::Sampled { .. }
7057                                    | crate::ImageClass::Depth { .. }
7058                                    | crate::ImageClass::Storage {
7059                                        access: crate::StorageAccess::LOAD,
7060                                        ..
7061                                    } => {
7062                                        // Array of textures since:
7063                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7064                                        // - macOS: Metal 2
7065
7066                                        if options.lang_version < (2, 0) {
7067                                            return Err(Error::UnsupportedArrayOf(
7068                                                "textures".to_string(),
7069                                            ));
7070                                        }
7071                                    }
7072                                    crate::ImageClass::Storage {
7073                                        access: crate::StorageAccess::STORE,
7074                                        ..
7075                                    } => {
7076                                        // Array of write-only textures since:
7077                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7078                                        // - macOS: Metal 2
7079
7080                                        if options.lang_version < (2, 0) {
7081                                            return Err(Error::UnsupportedArrayOf(
7082                                                "write-only textures".to_string(),
7083                                            ));
7084                                        }
7085                                    }
7086                                    crate::ImageClass::Storage { .. } => {
7087                                        return Err(Error::UnsupportedArrayOf(
7088                                            "read-write textures".to_string(),
7089                                        ));
7090                                    }
7091                                    crate::ImageClass::External => {
7092                                        return Err(Error::UnsupportedArrayOf(
7093                                            "external textures".to_string(),
7094                                        ));
7095                                    }
7096                                },
7097                                _ => {
7098                                    return Err(Error::UnsupportedArrayOfType(base));
7099                                }
7100                            }
7101                        }
7102                        _ => {}
7103                    },
7104                    _ => {}
7105                }
7106
7107                // the resolves have already been checked for `!fake_missing_bindings` case
7108                let resolved = match var.space {
7109                    crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
7110                    crate::AddressSpace::WorkGroup => None,
7111                    _ => options
7112                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7113                        .ok(),
7114                };
7115                if let Some(ref resolved) = resolved {
7116                    // Inline samplers are defined in the EP body, not passed as arguments
7117                    if resolved.as_inline_sampler(options).is_some() {
7118                        continue;
7119                    }
7120                }
7121
7122                let mut separator = || {
7123                    if is_first_argument {
7124                        is_first_argument = false;
7125                        ' '
7126                    } else {
7127                        ','
7128                    }
7129                };
7130
7131                match module.types[var.ty].inner {
7132                    crate::TypeInner::Image {
7133                        class: crate::ImageClass::External,
7134                        ..
7135                    } => {
7136                        // External texture global variables get lowered to 3 textures
7137                        // and a constant buffer. We must emit a separate argument for
7138                        // each of these.
7139                        let target = match resolved {
7140                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7141                                target.external_texture
7142                            }
7143                            _ => None,
7144                        };
7145
7146                        for i in 0..3 {
7147                            write!(self.out, "{} ", separator())?;
7148
7149                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7150                                handle,
7151                                ExternalTextureNameKey::Plane(i),
7152                            )];
7153                            write!(
7154                              self.out,
7155                              "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7156                            )?;
7157                            if let Some(ref target) = target {
7158                                write!(self.out, " [[texture({})]]", target.planes[i])?;
7159                            }
7160                            writeln!(self.out)?;
7161                        }
7162                        let params_ty_name = &self.names
7163                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7164                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7165                            handle,
7166                            ExternalTextureNameKey::Params,
7167                        )];
7168                        write!(self.out, "{} ", separator())?;
7169                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
7170                        if let Some(ref target) = target {
7171                            write!(self.out, " [[buffer({})]]", target.params)?;
7172                        }
7173                    }
7174                    _ => {
7175                        let tyvar = TypedGlobalVariable {
7176                            module,
7177                            names: &self.names,
7178                            handle,
7179                            usage,
7180                            reference: true,
7181                        };
7182                        write!(self.out, "{} ", separator())?;
7183                        tyvar.try_fmt(&mut self.out)?;
7184                        if let Some(resolved) = resolved {
7185                            resolved.try_fmt(&mut self.out)?;
7186                        }
7187                        if let Some(value) = var.init {
7188                            write!(self.out, " = ")?;
7189                            self.put_const_expression(
7190                                value,
7191                                module,
7192                                mod_info,
7193                                &module.global_expressions,
7194                            )?;
7195                        }
7196                    }
7197                }
7198                writeln!(self.out)?;
7199            }
7200
7201            if do_vertex_pulling {
7202                assert!(needs_vertex_id || needs_instance_id);
7203
7204                let mut separator = if is_first_argument {
7205                    is_first_argument = false;
7206                    ' '
7207                } else {
7208                    ','
7209                };
7210
7211                if needs_vertex_id && v_existing_id.is_none() {
7212                    // Write the [[vertex_id]] argument.
7213                    writeln!(self.out, "{separator} uint {v_id} [[vertex_id]]")?;
7214                    separator = ',';
7215                }
7216
7217                if needs_instance_id && i_existing_id.is_none() {
7218                    writeln!(self.out, "{separator} uint {i_id} [[instance_id]]")?;
7219                }
7220
7221                // Iterate vbm_resolved, output one argument for every vertex buffer,
7222                // using the names we generated earlier.
7223                for vbm in &vbm_resolved {
7224                    let id = &vbm.id;
7225                    let ty_name = &vbm.ty_name;
7226                    let param_name = &vbm.param_name;
7227                    writeln!(
7228                        self.out,
7229                        ", const device {ty_name}* {param_name} [[buffer({id})]]"
7230                    )?;
7231                }
7232            }
7233
7234            // If this entry uses any variable-length arrays, their sizes are
7235            // passed as a final struct-typed argument.
7236            if needs_buffer_sizes {
7237                // this is checked earlier
7238                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7239                let separator = if is_first_argument { ' ' } else { ',' };
7240                write!(
7241                    self.out,
7242                    "{separator} constant _mslBufferSizes& _buffer_sizes",
7243                )?;
7244                resolved.try_fmt(&mut self.out)?;
7245                writeln!(self.out)?;
7246            }
7247
7248            // end of the entry point argument list
7249            writeln!(self.out, ") {{")?;
7250
7251            // Starting the function body.
7252            if do_vertex_pulling {
7253                // Provide zero values for all the attributes, which we will overwrite with
7254                // real data from the vertex attribute buffers, if the indices are in-bounds.
7255                for vbm in &vbm_resolved {
7256                    for attribute in vbm.attributes {
7257                        let location = attribute.shader_location;
7258                        let am_option = am_resolved.get(&location);
7259                        if am_option.is_none() {
7260                            // This bound attribute isn't used in this entry point, so
7261                            // don't bother zero-initializing it.
7262                            continue;
7263                        }
7264                        let am = am_option.unwrap();
7265                        let attribute_ty_name = &am.ty_name;
7266                        let attribute_name = &am.name;
7267
7268                        writeln!(
7269                            self.out,
7270                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7271                            back::Level(1)
7272                        )?;
7273                    }
7274
7275                    // Output a bounds check block that will set real values for the
7276                    // attributes, if the bounds are satisfied.
7277                    write!(self.out, "{}if (", back::Level(1))?;
7278
7279                    let idx = &vbm.id;
7280                    let stride = &vbm.stride;
7281                    let index_name = if vbm.indexed_by_vertex {
7282                        if let Some(ref name) = v_existing_id {
7283                            name
7284                        } else {
7285                            &v_id
7286                        }
7287                    } else if let Some(ref name) = i_existing_id {
7288                        name
7289                    } else {
7290                        &i_id
7291                    };
7292                    write!(
7293                        self.out,
7294                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7295                    )?;
7296
7297                    writeln!(self.out, ") {{")?;
7298
7299                    // Pull the bytes out of the vertex buffer.
7300                    let ty_name = &vbm.ty_name;
7301                    let elem_name = &vbm.elem_name;
7302                    let param_name = &vbm.param_name;
7303
7304                    writeln!(
7305                        self.out,
7306                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7307                        back::Level(2),
7308                    )?;
7309
7310                    // Now set real values for each of the attributes, by unpacking the data
7311                    // from the buffer elements.
7312                    for attribute in vbm.attributes {
7313                        let location = attribute.shader_location;
7314                        let Some(am) = am_resolved.get(&location) else {
7315                            // This bound attribute isn't used in this entry point, so
7316                            // don't bother extracting the data. Too bad we emitted the
7317                            // unpacking function earlier -- it might not get used.
7318                            continue;
7319                        };
7320                        let attribute_name = &am.name;
7321                        let attribute_ty_name = &am.ty_name;
7322
7323                        let offset = attribute.offset;
7324                        let func = unpacking_functions
7325                            .get(&attribute.format)
7326                            .expect("Should have generated this unpacking function earlier.");
7327                        let func_name = &func.name;
7328
7329                        // Check dimensionality of the attribute compared to the unpacking
7330                        // function. If attribute dimension > unpack dimension, we have to
7331                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7332                        // scalar type. Otherwise, if attribute dimension is < unpack
7333                        // dimension, then we need to explicitly truncate the result.
7334
7335                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7336
7337                        if needs_padding_or_truncation != Ordering::Equal {
7338                            // Emit a comment flagging that a conversion is happening,
7339                            // since the actual logic can be at the end of a long line.
7340                            writeln!(
7341                                self.out,
7342                                "{}// {attribute_ty_name} <- {:?}",
7343                                back::Level(2),
7344                                attribute.format
7345                            )?;
7346                        }
7347
7348                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7349
7350                        if needs_padding_or_truncation == Ordering::Greater {
7351                            // Needs padding: emit constructor call for wider type
7352                            write!(self.out, "{attribute_ty_name}(")?;
7353                        }
7354
7355                        // Emit call to unpacking function
7356                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7357                        for i in (offset + 1)..(offset + func.byte_count) {
7358                            write!(self.out, ", {elem_name}.data[{i}]")?;
7359                        }
7360                        write!(self.out, ")")?;
7361
7362                        match needs_padding_or_truncation {
7363                            Ordering::Greater => {
7364                                // Padding
7365                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7366                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
7367                                for i in func.dimension..am.dimension {
7368                                    write!(
7369                                        self.out,
7370                                        ", {}",
7371                                        if i == 3 { one_value } else { zero_value }
7372                                    )?;
7373                                }
7374                                write!(self.out, ")")?;
7375                            }
7376                            Ordering::Less => {
7377                                // Truncate to the first `am.dimension` components
7378                                write!(
7379                                    self.out,
7380                                    ".{}",
7381                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7382                                )?;
7383                            }
7384                            Ordering::Equal => {}
7385                        }
7386
7387                        writeln!(self.out, ";")?;
7388                    }
7389
7390                    // End the bounds check / attribute setting block.
7391                    writeln!(self.out, "{}}}", back::Level(1))?;
7392                }
7393            }
7394
7395            if need_workgroup_variables_initialization {
7396                self.write_workgroup_variables_initialization(
7397                    module,
7398                    mod_info,
7399                    fun_info,
7400                    local_invocation_id,
7401                )?;
7402            }
7403
7404            // Metal doesn't support private mutable variables outside of functions,
7405            // so we put them here, just like the locals.
7406            for (handle, var) in module.global_variables.iter() {
7407                let usage = fun_info[handle];
7408                if usage.is_empty() {
7409                    continue;
7410                }
7411                if var.space == crate::AddressSpace::Private {
7412                    let tyvar = TypedGlobalVariable {
7413                        module,
7414                        names: &self.names,
7415                        handle,
7416                        usage,
7417
7418                        reference: false,
7419                    };
7420                    write!(self.out, "{}", back::INDENT)?;
7421                    tyvar.try_fmt(&mut self.out)?;
7422                    match var.init {
7423                        Some(value) => {
7424                            write!(self.out, " = ")?;
7425                            self.put_const_expression(
7426                                value,
7427                                module,
7428                                mod_info,
7429                                &module.global_expressions,
7430                            )?;
7431                            writeln!(self.out, ";")?;
7432                        }
7433                        None => {
7434                            writeln!(self.out, " = {{}};")?;
7435                        }
7436                    };
7437                } else if let Some(ref binding) = var.binding {
7438                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7439                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7440                        // write an inline sampler
7441                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7442                        writeln!(
7443                            self.out,
7444                            "{}constexpr {}::sampler {}(",
7445                            back::INDENT,
7446                            NAMESPACE,
7447                            name
7448                        )?;
7449                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7450                        writeln!(self.out, "{});", back::INDENT)?;
7451                    } else if let crate::TypeInner::Image {
7452                        class: crate::ImageClass::External,
7453                        ..
7454                    } = module.types[var.ty].inner
7455                    {
7456                        // Wrap the individual arguments for each external texture global
7457                        // in a struct which can be easily passed around.
7458                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7459                        let l1 = back::Level(1);
7460                        let l2 = l1.next();
7461                        writeln!(
7462                            self.out,
7463                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7464                        )?;
7465                        for i in 0..3 {
7466                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7467                                handle,
7468                                ExternalTextureNameKey::Plane(i),
7469                            )];
7470                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7471                        }
7472                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7473                            handle,
7474                            ExternalTextureNameKey::Params,
7475                        )];
7476                        writeln!(self.out, "{l2}.params = {params_name},")?;
7477                        writeln!(self.out, "{l1}}};")?;
7478                    }
7479                }
7480            }
7481
7482            // Now take the arguments that we gathered into structs, and the
7483            // structs that we flattened into arguments, and emit local
7484            // variables with initializers that put everything back the way the
7485            // body code expects.
7486            //
7487            // If we had to generate fresh names for struct members passed as
7488            // arguments, be sure to use those names when rebuilding the struct.
7489            //
7490            // "Each day, I change some zeros to ones, and some ones to zeros.
7491            // The rest, I leave alone."
7492            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7493                let arg_name =
7494                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7495                match module.types[arg.ty].inner {
7496                    crate::TypeInner::Struct { ref members, .. } => {
7497                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7498                        write!(
7499                            self.out,
7500                            "{}const {} {} = {{ ",
7501                            back::INDENT,
7502                            struct_name,
7503                            arg_name
7504                        )?;
7505                        for (member_index, member) in members.iter().enumerate() {
7506                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7507                            let name = &flattened_member_names[&key];
7508                            if member_index != 0 {
7509                                write!(self.out, ", ")?;
7510                            }
7511                            // insert padding initialization, if needed
7512                            if self
7513                                .struct_member_pads
7514                                .contains(&(arg.ty, member_index as u32))
7515                            {
7516                                write!(self.out, "{{}}, ")?;
7517                            }
7518                            if let Some(crate::Binding::Location { .. }) = member.binding {
7519                                if has_varyings {
7520                                    write!(self.out, "{varyings_member_name}.")?;
7521                                }
7522                            }
7523                            write!(self.out, "{name}")?;
7524                        }
7525                        writeln!(self.out, " }};")?;
7526                    }
7527                    _ => {
7528                        if let Some(crate::Binding::Location { .. }) = arg.binding {
7529                            if has_varyings {
7530                                writeln!(
7531                                    self.out,
7532                                    "{}const auto {} = {}.{};",
7533                                    back::INDENT,
7534                                    arg_name,
7535                                    varyings_member_name,
7536                                    arg_name
7537                                )?;
7538                            }
7539                        }
7540                    }
7541                }
7542            }
7543
7544            let guarded_indices =
7545                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7546
7547            let context = StatementContext {
7548                expression: ExpressionContext {
7549                    function: fun,
7550                    origin: FunctionOrigin::EntryPoint(ep_index as _),
7551                    info: fun_info,
7552                    lang_version: options.lang_version,
7553                    policies: options.bounds_check_policies,
7554                    guarded_indices,
7555                    module,
7556                    mod_info,
7557                    pipeline_options,
7558                    force_loop_bounding: options.force_loop_bounding,
7559                },
7560                result_struct: Some(&stage_out_name),
7561            };
7562
7563            // Finally, declare all the local variables that we need
7564            //TODO: we can postpone this till the relevant expressions are emitted
7565            self.put_locals(&context.expression)?;
7566            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7567            self.put_block(back::Level(1), &fun.body, &context)?;
7568            writeln!(self.out, "}}")?;
7569            if ep_index + 1 != module.entry_points.len() {
7570                writeln!(self.out)?;
7571            }
7572            self.named_expressions.clear();
7573        }
7574
7575        Ok(info)
7576    }
7577
7578    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7579        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
7580        // so we try to avoid it here.
7581        if flags.is_empty() {
7582            writeln!(
7583                self.out,
7584                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7585            )?;
7586        }
7587        if flags.contains(crate::Barrier::STORAGE) {
7588            writeln!(
7589                self.out,
7590                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7591            )?;
7592        }
7593        if flags.contains(crate::Barrier::WORK_GROUP) {
7594            writeln!(
7595                self.out,
7596                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7597            )?;
7598        }
7599        if flags.contains(crate::Barrier::SUB_GROUP) {
7600            writeln!(
7601                self.out,
7602                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7603            )?;
7604        }
7605        if flags.contains(crate::Barrier::TEXTURE) {
7606            writeln!(
7607                self.out,
7608                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7609            )?;
7610        }
7611        Ok(())
7612    }
7613}
7614
7615/// Initializing workgroup variables is more tricky for Metal because we have to deal
7616/// with atomics at the type-level (which don't have a copy constructor).
7617mod workgroup_mem_init {
7618    use crate::EntryPoint;
7619
7620    use super::*;
7621
7622    enum Access {
7623        GlobalVariable(Handle<crate::GlobalVariable>),
7624        StructMember(Handle<crate::Type>, u32),
7625        Array(usize),
7626    }
7627
7628    impl Access {
7629        fn write<W: Write>(
7630            &self,
7631            writer: &mut W,
7632            names: &FastHashMap<NameKey, String>,
7633        ) -> Result<(), core::fmt::Error> {
7634            match *self {
7635                Access::GlobalVariable(handle) => {
7636                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7637                }
7638                Access::StructMember(handle, index) => {
7639                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7640                }
7641                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7642            }
7643        }
7644    }
7645
7646    struct AccessStack {
7647        stack: Vec<Access>,
7648        array_depth: usize,
7649    }
7650
7651    impl AccessStack {
7652        const fn new() -> Self {
7653            Self {
7654                stack: Vec::new(),
7655                array_depth: 0,
7656            }
7657        }
7658
7659        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7660            let array_depth = self.array_depth;
7661            self.stack.push(Access::Array(array_depth));
7662            self.array_depth += 1;
7663            let res = cb(self, array_depth);
7664            self.stack.pop();
7665            self.array_depth -= 1;
7666            res
7667        }
7668
7669        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7670            self.stack.push(new);
7671            let res = cb(self);
7672            self.stack.pop();
7673            res
7674        }
7675
7676        fn write<W: Write>(
7677            &self,
7678            writer: &mut W,
7679            names: &FastHashMap<NameKey, String>,
7680        ) -> Result<(), core::fmt::Error> {
7681            for next in self.stack.iter() {
7682                next.write(writer, names)?;
7683            }
7684            Ok(())
7685        }
7686    }
7687
7688    impl<W: Write> Writer<W> {
7689        pub(super) fn need_workgroup_variables_initialization(
7690            &mut self,
7691            options: &Options,
7692            ep: &EntryPoint,
7693            module: &crate::Module,
7694            fun_info: &valid::FunctionInfo,
7695        ) -> bool {
7696            options.zero_initialize_workgroup_memory
7697                && ep.stage == crate::ShaderStage::Compute
7698                && module.global_variables.iter().any(|(handle, var)| {
7699                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7700                })
7701        }
7702
7703        pub(super) fn write_workgroup_variables_initialization(
7704            &mut self,
7705            module: &crate::Module,
7706            module_info: &valid::ModuleInfo,
7707            fun_info: &valid::FunctionInfo,
7708            local_invocation_id: Option<&NameKey>,
7709        ) -> BackendResult {
7710            let level = back::Level(1);
7711
7712            writeln!(
7713                self.out,
7714                "{}if ({}::all({} == {}::uint3(0u))) {{",
7715                level,
7716                NAMESPACE,
7717                local_invocation_id
7718                    .map(|name_key| self.names[name_key].as_str())
7719                    .unwrap_or("__local_invocation_id"),
7720                NAMESPACE,
7721            )?;
7722
7723            let mut access_stack = AccessStack::new();
7724
7725            let vars = module.global_variables.iter().filter(|&(handle, var)| {
7726                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7727            });
7728
7729            for (handle, var) in vars {
7730                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7731                    self.write_workgroup_variable_initialization(
7732                        module,
7733                        module_info,
7734                        var.ty,
7735                        access_stack,
7736                        level.next(),
7737                    )
7738                })?;
7739            }
7740
7741            writeln!(self.out, "{level}}}")?;
7742            self.write_barrier(crate::Barrier::WORK_GROUP, level)
7743        }
7744
7745        fn write_workgroup_variable_initialization(
7746            &mut self,
7747            module: &crate::Module,
7748            module_info: &valid::ModuleInfo,
7749            ty: Handle<crate::Type>,
7750            access_stack: &mut AccessStack,
7751            level: back::Level,
7752        ) -> BackendResult {
7753            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7754                write!(self.out, "{level}")?;
7755                access_stack.write(&mut self.out, &self.names)?;
7756                writeln!(self.out, " = {{}};")?;
7757            } else {
7758                match module.types[ty].inner {
7759                    crate::TypeInner::Atomic { .. } => {
7760                        write!(
7761                            self.out,
7762                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7763                        )?;
7764                        access_stack.write(&mut self.out, &self.names)?;
7765                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7766                    }
7767                    crate::TypeInner::Array { base, size, .. } => {
7768                        let count = match size.resolve(module.to_ctx())? {
7769                            proc::IndexableLength::Known(count) => count,
7770                            proc::IndexableLength::Dynamic => unreachable!(),
7771                        };
7772
7773                        access_stack.enter_array(|access_stack, array_depth| {
7774                            writeln!(
7775                                self.out,
7776                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7777                            )?;
7778                            self.write_workgroup_variable_initialization(
7779                                module,
7780                                module_info,
7781                                base,
7782                                access_stack,
7783                                level.next(),
7784                            )?;
7785                            writeln!(self.out, "{level}}}")?;
7786                            BackendResult::Ok(())
7787                        })?;
7788                    }
7789                    crate::TypeInner::Struct { ref members, .. } => {
7790                        for (index, member) in members.iter().enumerate() {
7791                            access_stack.enter(
7792                                Access::StructMember(ty, index as u32),
7793                                |access_stack| {
7794                                    self.write_workgroup_variable_initialization(
7795                                        module,
7796                                        module_info,
7797                                        member.ty,
7798                                        access_stack,
7799                                        level,
7800                                    )
7801                                },
7802                            )?;
7803                        }
7804                    }
7805                    _ => unreachable!(),
7806                }
7807            }
7808
7809            Ok(())
7810        }
7811    }
7812}
7813
7814impl crate::AtomicFunction {
7815    const fn to_msl(self) -> &'static str {
7816        match self {
7817            Self::Add => "fetch_add",
7818            Self::Subtract => "fetch_sub",
7819            Self::And => "fetch_and",
7820            Self::InclusiveOr => "fetch_or",
7821            Self::ExclusiveOr => "fetch_xor",
7822            Self::Min => "fetch_min",
7823            Self::Max => "fetch_max",
7824            Self::Exchange { compare: None } => "exchange",
7825            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7826        }
7827    }
7828
7829    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7830        Ok(match self {
7831            Self::Min => "min",
7832            Self::Max => "max",
7833            _ => Err(Error::FeatureNotImplemented(
7834                "64-bit atomic operation other than min/max".to_string(),
7835            ))?,
7836        })
7837    }
7838}