naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18    arena::{Handle, HandleSet},
19    back::{self, get_entry_points, Baked},
20    common,
21    proc::{
22        self,
23        index::{self, BoundsCheck},
24        ExternalTextureNameKey, NameKey, TypeResolution,
25    },
26    valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Shorthand result used internally by the backend: `()` on success, or the
/// backend's [`Error`] on failure.
type BackendResult = Result<(), Error>;
34
/// Namespace qualifying the MSL standard library type names written by this
/// backend, e.g. `metal::sampler`.
const NAMESPACE: &str = "metal";
// The name of the array member of the Metal struct types we generate to
// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
// for details.
const WRAPPED_ARRAY_FIELD: &str = "inner";
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// Some more general handling of pointers needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

/// Namespace qualifying the MSL ray tracing type names, e.g.
/// `metal::raytracing::instance_acceleration_structure`.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Type name written for Naga `RayQuery` types.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// NOTE(review): the `RAY_QUERY_FIELD_*` names below look like member names of
// the `_RayQuery` struct, and `RAY_QUERY_FUN_MAP_INTERSECTION` like the name
// of a generated helper function -- confirm against the code that writes them.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Identifiers of Naga-specific helper functions referenced from the generated
// MSL. NOTE(review): presumably each helper is written into the output on
// demand -- confirm against the code that emits them.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you have noticed that this should be exactly the same to the compiler, and you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81
82/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
83///
84/// The `sizes` slice determines whether this function writes a
85/// scalar, vector, or matrix type:
86///
87/// - An empty slice produces a scalar type.
88/// - A one-element slice produces a vector type.
89/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
90fn put_numeric_type(
91    out: &mut impl Write,
92    scalar: crate::Scalar,
93    sizes: &[crate::VectorSize],
94) -> Result<(), FmtError> {
95    match (scalar, sizes) {
96        (scalar, &[]) => {
97            write!(out, "{}", scalar.to_msl_name())
98        }
99        (scalar, &[rows]) => {
100            write!(
101                out,
102                "{}::{}{}",
103                NAMESPACE,
104                scalar.to_msl_name(),
105                common::vector_size_str(rows)
106            )
107        }
108        (scalar, &[rows, columns]) => {
109            write!(
110                out,
111                "{}::{}{}x{}",
112                NAMESPACE,
113                scalar.to_msl_name(),
114                common::vector_size_str(columns),
115                common::vector_size_str(rows)
116            )
117        }
118        (_, _) => Ok(()), // not meaningful
119    }
120}
121
122const fn scalar_is_int(scalar: crate::Scalar) -> bool {
123    use crate::ScalarKind::*;
124    match scalar.kind {
125        Sint | Uint | AbstractInt | Bool => true,
126        Float | AbstractFloat => false,
127    }
128}
129
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// This is the prefix [`ClampedLod`] passes to `write_prefixed` when it is
/// formatted.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
///
/// [`Reinterpreted`] writes this string first when it is formatted.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
135
136/// Wrapper for identifier names for clamped level-of-detail values
137///
138/// Values of this type implement [`core::fmt::Display`], formatting as
139/// the name of the variable used to hold the cached clamped
140/// level-of-detail value for an `ImageLoad` expression.
141struct ClampedLod(Handle<crate::Expression>);
142
143impl Display for ClampedLod {
144    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
145        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
146    }
147}
148
/// Wrapper for generating `struct _mslBufferSizes` member names for
/// runtime-sized array lengths.
///
/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
/// as an argument to the entry point. This argument's type in the MSL is
/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
/// each global variable containing a runtime-sized array.
///
/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
/// runtime-sized array, then the value `ArraySizeMember(global)` implements
/// [`core::fmt::Display`], formatting as the name of the struct member carrying
/// the number of elements in that runtime-sized array.
///
/// [`GlobalVariable`]: crate::GlobalVariable
struct ArraySizeMember(Handle<crate::GlobalVariable>);

impl Display for ArraySizeMember {
    /// Writes the member name by formatting the wrapped handle through
    /// `write_prefixed` with the prefix `size`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        self.0.write_prefixed(f, "size")
    }
}
170
171/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
172///
173/// Implements [`core::fmt::Display`], formatting as a name derived from
174/// `target_type` and the variable name of `orig`.
175#[derive(Clone, Copy)]
176struct Reinterpreted<'a> {
177    target_type: &'a str,
178    orig: Handle<crate::Expression>,
179}
180
181impl<'a> Reinterpreted<'a> {
182    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
183        Self { target_type, orig }
184    }
185}
186
187impl Display for Reinterpreted<'_> {
188    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
189        f.write_str(REINTERPRET_PREFIX)?;
190        f.write_str(self.target_type)?;
191        self.orig.write_prefixed(f, "_e")
192    }
193}
194
195struct TypeContext<'a> {
196    handle: Handle<crate::Type>,
197    gctx: proc::GlobalCtx<'a>,
198    names: &'a FastHashMap<NameKey, String>,
199    access: crate::StorageAccess,
200    first_time: bool,
201}
202
203impl TypeContext<'_> {
204    fn scalar(&self) -> Option<crate::Scalar> {
205        let ty = &self.gctx.types[self.handle];
206        ty.inner.scalar()
207    }
208
209    fn vertex_input_dimension(&self) -> u32 {
210        let ty = &self.gctx.types[self.handle];
211        match ty.inner {
212            crate::TypeInner::Scalar(_) => 1,
213            crate::TypeInner::Vector { size, .. } => size as u32,
214            _ => unreachable!(),
215        }
216    }
217}
218
impl Display for TypeContext<'_> {
    /// Writes the MSL spelling of the type `self.handle`.
    ///
    /// Types that need an alias are written as their registered name from
    /// `self.names`, unless `first_time` is set, in which case the underlying
    /// type is spelled out. All other types are always spelled out inline.
    ///
    /// # Panics
    ///
    /// Panics for struct types that reach the inline path, for storage images
    /// whose `access` has neither `LOAD` nor `STORE`, and for acceleration
    /// structures or ray queries with `vertex_return` set.
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referred to by name, except when `first_time`
        // is set.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Pointers are written as references: `<space> <base type>&`.
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // An address space without an MSL name produces no output.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Same shape as `Pointer`, but the pointee is a scalar or vector
            // described inline rather than by a type handle.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Array lengths go at the end of the type definition,
                // so just print the element type here.
                write!(out, "{sub}")
            }
            // `needs_alias` is always true for structs, so this arm can only
            // be reached when `first_time` is set.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, texel scalar
                // and access qualifier from the image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // Storage textures take their access qualifier from
                    // `self.access` rather than from the image class.
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are written as the wrapper struct name.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            // Binding arrays are written as a `constant` pointer to the
            // argument buffer wrapper instantiated with the element type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
372
373struct TypedGlobalVariable<'a> {
374    module: &'a crate::Module,
375    names: &'a FastHashMap<NameKey, String>,
376    handle: Handle<crate::GlobalVariable>,
377    usage: valid::GlobalUse,
378    reference: bool,
379}
380
381impl TypedGlobalVariable<'_> {
382    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
383        let var = &self.module.global_variables[self.handle];
384        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
385
386        let storage_access = match var.space {
387            crate::AddressSpace::Storage { access } => access,
388            _ => match self.module.types[var.ty].inner {
389                crate::TypeInner::Image {
390                    class: crate::ImageClass::Storage { access, .. },
391                    ..
392                } => access,
393                crate::TypeInner::BindingArray { base, .. } => {
394                    match self.module.types[base].inner {
395                        crate::TypeInner::Image {
396                            class: crate::ImageClass::Storage { access, .. },
397                            ..
398                        } => access,
399                        _ => crate::StorageAccess::default(),
400                    }
401                }
402                _ => crate::StorageAccess::default(),
403            },
404        };
405        let ty_name = TypeContext {
406            handle: var.ty,
407            gctx: self.module.to_ctx(),
408            names: self.names,
409            access: storage_access,
410            first_time: false,
411        };
412
413        let (space, access, reference) = match var.space.to_msl_name() {
414            Some(space) if self.reference => {
415                let access = if var.space.needs_access_qualifier()
416                    && !self.usage.intersects(valid::GlobalUse::WRITE)
417                {
418                    "const"
419                } else {
420                    ""
421                };
422                (space, access, "&")
423            }
424            _ => ("", "", ""),
425        };
426
427        Ok(write!(
428            out,
429            "{}{}{}{}{}{} {}",
430            space,
431            if space.is_empty() { "" } else { " " },
432            ty_name,
433            if access.is_empty() { "" } else { " " },
434            access,
435            reference,
436            name,
437        )?)
438    }
439}
440
/// Key describing one helper ("wrapped") function, distinguished by the
/// operation and the operand types it is specialized for.
///
/// Values of this type are stored in `Writer::wrapped_functions`.
/// NOTE(review): presumably that set records which helpers have already been
/// written -- confirm against the code that populates it.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// A unary operator applied to a scalar (`None`) or vector (`Some(size)`).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A binary operator, keyed by the types of both operands.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A math function, keyed by its argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A conversion between two scalar types, optionally vectorized.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// An image load, keyed by image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// An image sample, keyed by image class and the `clamp_to_edge` flag.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// An image size query, keyed by image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
472
/// State of the MSL backend while writing to an output sink of type `W`.
pub struct Writer<W> {
    /// Output sink; handed back to the caller by `finish`.
    out: W,
    /// Identifier assigned to each [`NameKey`].
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    namer: proc::Namer,
    /// Helper-function keys; see [`WrappedFunction`].
    wrapped_functions: FastHashSet<WrappedFunction>,
    // NOTE(review): the two pointer sets below exist only in test builds;
    // presumably they record stack addresses seen by the recursive writers --
    // confirm against the tests that read them.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
489
490impl crate::Scalar {
491    fn to_msl_name(self) -> &'static str {
492        use crate::ScalarKind as Sk;
493        match self {
494            Self {
495                kind: Sk::Float,
496                width: 4,
497            } => "float",
498            Self {
499                kind: Sk::Float,
500                width: 2,
501            } => "half",
502            Self {
503                kind: Sk::Sint,
504                width: 4,
505            } => "int",
506            Self {
507                kind: Sk::Uint,
508                width: 4,
509            } => "uint",
510            Self {
511                kind: Sk::Sint,
512                width: 8,
513            } => "long",
514            Self {
515                kind: Sk::Uint,
516                width: 8,
517            } => "ulong",
518            Self {
519                kind: Sk::Bool,
520                width: _,
521            } => "bool",
522            Self {
523                kind: Sk::AbstractInt | Sk::AbstractFloat,
524                width: _,
525            } => unreachable!("Found Abstract scalar kind"),
526            _ => unreachable!("Unsupported scalar kind: {:?}", self),
527        }
528    }
529}
530
/// Return `","` when a separator is required and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
538
539fn should_pack_struct_member(
540    members: &[crate::StructMember],
541    span: u32,
542    index: usize,
543    module: &crate::Module,
544) -> Option<crate::Scalar> {
545    let member = &members[index];
546
547    let ty_inner = &module.types[member.ty].inner;
548    let last_offset = member.offset + ty_inner.size(module.to_ctx());
549    let next_offset = match members.get(index + 1) {
550        Some(next) => next.offset,
551        None => span,
552    };
553    let is_tight = next_offset == last_offset;
554
555    match *ty_inner {
556        crate::TypeInner::Vector {
557            size: crate::VectorSize::Tri,
558            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
559        } if is_tight => Some(scalar),
560        _ => None,
561    }
562}
563
564fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
565    match arena[ty].inner {
566        crate::TypeInner::Struct { ref members, .. } => {
567            if let Some(member) = members.last() {
568                if let crate::TypeInner::Array {
569                    size: crate::ArraySize::Dynamic,
570                    ..
571                } = arena[member.ty].inner
572                {
573                    return true;
574                }
575            }
576            false
577        }
578        crate::TypeInner::Array {
579            size: crate::ArraySize::Dynamic,
580            ..
581        } => true,
582        _ => false,
583    }
584}
585
586impl crate::AddressSpace {
587    /// Returns true if global variables in this address space are
588    /// passed in function arguments. These arguments need to be
589    /// passed through any functions called from the entry point.
590    const fn needs_pass_through(&self) -> bool {
591        match *self {
592            Self::Uniform
593            | Self::Storage { .. }
594            | Self::Private
595            | Self::WorkGroup
596            | Self::PushConstant
597            | Self::Handle
598            | Self::TaskPayload => true,
599            Self::Function => false,
600        }
601    }
602
603    /// Returns true if the address space may need a "const" qualifier.
604    const fn needs_access_qualifier(&self) -> bool {
605        match *self {
606            //Note: we are ignoring the storage access here, and instead
607            // rely on the actual use of a global by functions. This means we
608            // may end up with "const" even if the binding is read-write,
609            // and that should be OK.
610            Self::Storage { .. } => true,
611            Self::TaskPayload => unimplemented!(),
612            // These should always be read-write.
613            Self::Private | Self::WorkGroup => false,
614            // These translate to `constant` address space, no need for qualifiers.
615            Self::Uniform | Self::PushConstant => false,
616            // Not applicable.
617            Self::Handle | Self::Function => false,
618        }
619    }
620
621    const fn to_msl_name(self) -> Option<&'static str> {
622        match self {
623            Self::Handle => None,
624            Self::Uniform | Self::PushConstant => Some("constant"),
625            Self::Storage { .. } => Some("device"),
626            Self::Private | Self::Function => Some("thread"),
627            Self::WorkGroup => Some("threadgroup"),
628            Self::TaskPayload => Some("object_data"),
629        }
630    }
631}
632
633impl crate::Type {
634    // Returns `true` if we need to emit an alias for this type.
635    const fn needs_alias(&self) -> bool {
636        use crate::TypeInner as Ti;
637
638        match self.inner {
639            // value types are concise enough, we only alias them if they are named
640            Ti::Scalar(_)
641            | Ti::Vector { .. }
642            | Ti::Matrix { .. }
643            | Ti::Atomic(_)
644            | Ti::Pointer { .. }
645            | Ti::ValuePointer { .. } => self.name.is_some(),
646            // composite types are better to be aliased, regardless of the name
647            Ti::Struct { .. } | Ti::Array { .. } => true,
648            // handle types may be different, depending on the global var access, so we always inline them
649            Ti::Image { .. }
650            | Ti::Sampler { .. }
651            | Ti::AccelerationStructure { .. }
652            | Ti::RayQuery { .. }
653            | Ti::BindingArray { .. } => false,
654        }
655    }
656}
657
/// Identifies a function: either a regular function, by its arena handle, or
/// an entry point, by its index.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function of the module.
    Handle(Handle<crate::Function>),
    /// An entry point of the module.
    EntryPoint(proc::EntryPointIndex),
}
663
664trait NameKeyExt {
665    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
666        match origin {
667            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
668            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
669        }
670    }
671
672    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
673    /// policy when it needs to produce a pointer-typed result for an OOB access. These
674    /// are unique per accessed type, so the second argument is a type handle. See docs
675    /// for [`crate::back::msl`].
676    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
677        match origin {
678            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
679            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
680        }
681    }
682}
683
684impl NameKeyExt for NameKey {}
685
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the level-of-detail expression as it is.
    Direct(Handle<crate::Expression>),
    /// Use the clamped value saved in a temporary variable.
    // NOTE(review): presumably the handle here is that of the `ImageLoad`
    // expression the temporary is named after -- confirm against the callers.
    Restricted(Handle<crate::Expression>),
}
700
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Expression giving the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression giving the array layer, if one is supplied.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression giving the sample index, if one is supplied.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if one is supplied.
    level: Option<LevelOfDetail>,
}
716
/// Everything needed to write the expressions of one function or entry point.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation results for `function`; indexed by expression handle to
    /// resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module that contains `function`.
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the targeted MSL version as (major, minor) --
    // confirm against where this is set.
    lang_version: (u8, u8),
    /// Bounds check policies from which a policy is chosen per access.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
734
735impl<'a> ExpressionContext<'a> {
736    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
737        self.info[handle].ty.inner_with(&self.module.types)
738    }
739
740    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
741    ///
742    /// Only mipmapped images need to specify a level of detail. Since 1D
743    /// textures cannot have mipmaps, MSL requires that the level argument to
744    /// texture1d queries and accesses must be a constexpr 0. It's easiest
745    /// just to omit the level entirely for 1D textures.
746    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
747        let image_ty = self.resolve_type(image);
748        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
749            class.is_mipmapped() && dim != crate::ImageDimension::D1
750        } else {
751            false
752        }
753    }
754
755    fn choose_bounds_check_policy(
756        &self,
757        pointer: Handle<crate::Expression>,
758    ) -> index::BoundsCheckPolicy {
759        self.policies
760            .choose_policy(pointer, &self.module.types, self.info)
761    }
762
763    /// See docs for [`proc::index::access_needs_check`].
764    fn access_needs_check(
765        &self,
766        base: Handle<crate::Expression>,
767        index: index::GuardedIndex,
768    ) -> Option<index::IndexableLength> {
769        index::access_needs_check(
770            base,
771            index,
772            self.module,
773            &self.function.expressions,
774            self.info,
775        )
776    }
777
778    /// See docs for [`proc::index::bounds_check_iter`].
779    fn bounds_check_iter(
780        &self,
781        chain: Handle<crate::Expression>,
782    ) -> impl Iterator<Item = BoundsCheck> + '_ {
783        index::bounds_check_iter(chain, self.module, self.function, self.info)
784    }
785
786    /// See docs for [`proc::index::oob_local_types`].
787    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
788        index::oob_local_types(self.module, self.function, self.info, self.policies)
789    }
790
791    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
792        match self.function.expressions[expr_handle] {
793            crate::Expression::AccessIndex { base, index } => {
794                let ty = match *self.resolve_type(base) {
795                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
796                    ref ty => ty,
797                };
798                match *ty {
799                    crate::TypeInner::Struct {
800                        ref members, span, ..
801                    } => should_pack_struct_member(members, span, index as usize, self.module),
802                    _ => None,
803                }
804            }
805            _ => None,
806        }
807    }
808}
809
/// An [`ExpressionContext`] extended with the extra state needed when writing
/// statements.
struct StatementContext<'a> {
    /// Context used for the expressions that statements refer to.
    expression: ExpressionContext<'a>,
    // NOTE(review): presumably the name of the struct wrapping an entry
    // point's result, if there is one -- confirm against the code that uses it.
    result_struct: Option<&'a str>,
}
814
815impl<W: Write> Writer<W> {
816    /// Creates a new `Writer` instance.
817    pub fn new(out: W) -> Self {
818        Writer {
819            out,
820            names: FastHashMap::default(),
821            named_expressions: Default::default(),
822            need_bake_expressions: Default::default(),
823            namer: proc::Namer::default(),
824            wrapped_functions: FastHashSet::default(),
825            #[cfg(test)]
826            put_expression_stack_pointers: Default::default(),
827            #[cfg(test)]
828            put_block_stack_pointers: Default::default(),
829            struct_member_pads: FastHashSet::default(),
830        }
831    }
832
833    /// Finishes writing and returns the output.
834    // See https://github.com/rust-lang/rust-clippy/issues/4979.
835    #[allow(clippy::missing_const_for_fn)]
836    pub fn finish(self) -> W {
837        self.out
838    }
839
840    /// Generates statements to be inserted immediately before and at the very
841    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
842    /// The 0th item of the returned tuple should be inserted immediately prior
843    /// to the loop and the 1st item should be inserted at the very start of
844    /// the loop body.
845    ///
846    /// # What is this trying to solve?
847    ///
848    /// In Metal Shading Language, an infinite loop has undefined behavior.
849    /// (This rule is inherited from C++14.) This means that, if the MSL
850    /// compiler determines that a given loop will never exit, it may assume
851    /// that it is never reached. It may thus assume that any conditions
852    /// sufficient to cause the loop to be reached must be false. Like many
853    /// optimizing compilers, MSL uses this kind of analysis to establish limits
854    /// on the range of values variables involved in those conditions might
855    /// hold.
856    ///
857    /// For example, suppose the MSL compiler sees the code:
858    ///
859    /// ```ignore
860    /// if (i >= 10) {
861    ///     while (true) { }
862    /// }
863    /// ```
864    ///
865    /// It will recognize that the `while` loop will never terminate, conclude
866    /// that it must be unreachable, and thus infer that, if this code is
867    /// reached, then `i < 10` at that point.
868    ///
869    /// Now suppose that, at some point where `i` has the same value as above,
870    /// the compiler sees the code:
871    ///
872    /// ```ignore
873    /// if (i < 10) {
874    ///     a[i] = 1;
875    /// }
876    /// ```
877    ///
878    /// Because the compiler is confident that `i < 10`, it will make the
879    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
880    ///
881    /// ```ignore
882    /// a[i] = 1;
883    /// ```
884    ///
885    /// If that `if` condition was injected by Naga to implement a bounds check,
886    /// the MSL compiler's optimizations could allow out-of-bounds array
887    /// accesses to occur.
888    ///
889    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
890    /// that a loop is infinite, so an attacker could craft a Naga module
891    /// containing an infinite loop protected by conditions that cause the Metal
892    /// compiler to remove bounds checks that Naga injected elsewhere in the
893    /// function.
894    ///
895    /// This rewrite could occur even if the conditional assignment appears
896    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
897    /// reached. This would allow the attacker to save the results of
898    /// unauthorized reads somewhere accessible before entering the infinite
899    /// loop. But even worse, the MSL compiler has been observed to simply
900    /// delete the infinite loop entirely, so that even code dominated by the
901    /// loop becomes reachable. This would make the attack even more flexible,
902    /// since shaders that would appear to never terminate would actually exit
903    /// nicely, after having stolen data from elsewhere in the GPU address
904    /// space.
905    ///
906    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
907    /// generates is infinite. One approach would be to add inline assembly to
908    /// each loop that is annotated as potentially branching out of the loop,
909    /// but which in fact generates no instructions. Unfortunately, inline
910    /// assembly is not handled correctly by some Metal device drivers.
911    ///
912    /// A previously used approach was to add the following code to the bottom
913    /// of every loop:
914    ///
915    /// ```ignore
916    /// if (volatile bool unpredictable = false; unpredictable)
917    ///     break;
918    /// ```
919    ///
920    /// Although the `if` condition will always be false in any real execution,
921    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
922    /// it must assume that the `break` might be reached, and hence that the
923    /// loop is not unbounded. This prevents the range analysis impact described
924    /// above. Unfortunately this prevented the compiler from making important,
925    /// and safe, optimizations such as loop unrolling and was observed to
926    /// significantly hurt performance.
927    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations. The
    /// counter counts down from its maximum value, rather than up from zero,
    /// to avoid a hang on certain Intel drivers:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
938    ///
939    /// This convinces the compiler that the loop is finite and therefore may
940    /// execute, whilst at the same time allowing optimizations such as loop
941    /// unrolling. Furthermore the 64-bit counter is large enough it seems
942    /// implausible that it would affect the execution of any shader.
943    ///
944    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
945    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
946    fn gen_force_bounded_loop_statements(
947        &mut self,
948        level: back::Level,
949        context: &StatementContext,
950    ) -> Option<(String, String)> {
951        if !context.expression.force_loop_bounding {
952            return None;
953        }
954
955        let loop_bound_name = self.namer.call("loop_bound");
956        // Count down from u32::MAX rather than up from 0 to avoid hang on
957        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
958        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
959        let level = level.next();
960        let break_and_inc = format!(
961            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
962{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
963        );
964
965        Some((decl, break_and_inc))
966    }
967
968    fn put_call_parameters(
969        &mut self,
970        parameters: impl Iterator<Item = Handle<crate::Expression>>,
971        context: &ExpressionContext,
972    ) -> BackendResult {
973        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
974            writer.put_expression(expr, context, true)
975        })
976    }
977
978    fn put_call_parameters_impl<C, E>(
979        &mut self,
980        parameters: impl Iterator<Item = Handle<crate::Expression>>,
981        ctx: &C,
982        put_expression: E,
983    ) -> BackendResult
984    where
985        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
986    {
987        write!(self.out, "(")?;
988        for (i, handle) in parameters.enumerate() {
989            if i != 0 {
990                write!(self.out, ", ")?;
991            }
992            put_expression(self, ctx, handle)?;
993        }
994        write!(self.out, ")")?;
995        Ok(())
996    }
997
998    /// Writes the local variables of the given function, as well as any extra
999    /// out-of-bounds locals that are needed.
1000    ///
1001    /// The names of the OOB locals are also added to `self.names` at the same
1002    /// time.
1003    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1004        let oob_local_types = context.oob_local_types();
1005        for &ty in oob_local_types.iter() {
1006            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1007            self.names.insert(name_key, self.namer.call("oob"));
1008        }
1009
1010        for (name_key, ty, init) in context
1011            .function
1012            .local_variables
1013            .iter()
1014            .map(|(local_handle, local)| {
1015                let name_key = NameKey::local(context.origin, local_handle);
1016                (name_key, local.ty, local.init)
1017            })
1018            .chain(oob_local_types.iter().map(|&ty| {
1019                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1020                (name_key, ty, None)
1021            }))
1022        {
1023            let ty_name = TypeContext {
1024                handle: ty,
1025                gctx: context.module.to_ctx(),
1026                names: &self.names,
1027                access: crate::StorageAccess::empty(),
1028                first_time: false,
1029            };
1030            write!(
1031                self.out,
1032                "{}{} {}",
1033                back::INDENT,
1034                ty_name,
1035                self.names[&name_key]
1036            )?;
1037            match init {
1038                Some(value) => {
1039                    write!(self.out, " = ")?;
1040                    self.put_expression(value, context, true)?;
1041                }
1042                None => {
1043                    write!(self.out, " = {{}}")?;
1044                }
1045            };
1046            writeln!(self.out, ";")?;
1047        }
1048        Ok(())
1049    }
1050
1051    fn put_level_of_detail(
1052        &mut self,
1053        level: LevelOfDetail,
1054        context: &ExpressionContext,
1055    ) -> BackendResult {
1056        match level {
1057            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1058            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1059        }
1060        Ok(())
1061    }
1062
1063    fn put_image_query(
1064        &mut self,
1065        image: Handle<crate::Expression>,
1066        query: &str,
1067        level: Option<LevelOfDetail>,
1068        context: &ExpressionContext,
1069    ) -> BackendResult {
1070        self.put_expression(image, context, false)?;
1071        write!(self.out, ".get_{query}(")?;
1072        if let Some(level) = level {
1073            self.put_level_of_detail(level, context)?;
1074        }
1075        write!(self.out, ")")?;
1076        Ok(())
1077    }
1078
1079    fn put_image_size_query(
1080        &mut self,
1081        image: Handle<crate::Expression>,
1082        level: Option<LevelOfDetail>,
1083        kind: crate::ScalarKind,
1084        context: &ExpressionContext,
1085    ) -> BackendResult {
1086        if let crate::TypeInner::Image {
1087            class: crate::ImageClass::External,
1088            ..
1089        } = *context.resolve_type(image)
1090        {
1091            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1092            self.put_expression(image, context, true)?;
1093            write!(self.out, ")")?;
1094            return Ok(());
1095        }
1096
1097        //Note: MSL only has separate width/height/depth queries,
1098        // so compose the result of them.
1099        let dim = match *context.resolve_type(image) {
1100            crate::TypeInner::Image { dim, .. } => dim,
1101            ref other => unreachable!("Unexpected type {:?}", other),
1102        };
1103        let scalar = crate::Scalar { kind, width: 4 };
1104        let coordinate_type = scalar.to_msl_name();
1105        match dim {
1106            crate::ImageDimension::D1 => {
1107                // Since 1D textures never have mipmaps, MSL requires that the
1108                // `level` argument be a constexpr 0. It's simplest for us just
1109                // to pass `None` and omit the level entirely.
1110                if kind == crate::ScalarKind::Uint {
1111                    // No need to construct a vector. No cast needed.
1112                    self.put_image_query(image, "width", None, context)?;
1113                } else {
1114                    // There's no definition for `int` in the `metal` namespace.
1115                    write!(self.out, "int(")?;
1116                    self.put_image_query(image, "width", None, context)?;
1117                    write!(self.out, ")")?;
1118                }
1119            }
1120            crate::ImageDimension::D2 => {
1121                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1122                self.put_image_query(image, "width", level, context)?;
1123                write!(self.out, ", ")?;
1124                self.put_image_query(image, "height", level, context)?;
1125                write!(self.out, ")")?;
1126            }
1127            crate::ImageDimension::D3 => {
1128                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1129                self.put_image_query(image, "width", level, context)?;
1130                write!(self.out, ", ")?;
1131                self.put_image_query(image, "height", level, context)?;
1132                write!(self.out, ", ")?;
1133                self.put_image_query(image, "depth", level, context)?;
1134                write!(self.out, ")")?;
1135            }
1136            crate::ImageDimension::Cube => {
1137                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1138                self.put_image_query(image, "width", level, context)?;
1139                write!(self.out, ")")?;
1140            }
1141        }
1142        Ok(())
1143    }
1144
1145    fn put_cast_to_uint_scalar_or_vector(
1146        &mut self,
1147        expr: Handle<crate::Expression>,
1148        context: &ExpressionContext,
1149    ) -> BackendResult {
1150        // coordinates in IR are int, but Metal expects uint
1151        match *context.resolve_type(expr) {
1152            crate::TypeInner::Scalar(_) => {
1153                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1154            }
1155            crate::TypeInner::Vector { size, .. } => {
1156                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1157            }
1158            _ => {
1159                return Err(Error::GenericValidation(
1160                    "Invalid type for image coordinate".into(),
1161                ))
1162            }
1163        };
1164
1165        write!(self.out, "(")?;
1166        self.put_expression(expr, context, true)?;
1167        write!(self.out, ")")?;
1168        Ok(())
1169    }
1170
1171    fn put_image_sample_level(
1172        &mut self,
1173        image: Handle<crate::Expression>,
1174        level: crate::SampleLevel,
1175        context: &ExpressionContext,
1176    ) -> BackendResult {
1177        let has_levels = context.image_needs_lod(image);
1178        match level {
1179            crate::SampleLevel::Auto => {}
1180            crate::SampleLevel::Zero => {
1181                //TODO: do we support Zero on `Sampled` image classes?
1182            }
1183            _ if !has_levels => {
1184                log::warn!("1D image can't be sampled with level {level:?}");
1185            }
1186            crate::SampleLevel::Exact(h) => {
1187                write!(self.out, ", {NAMESPACE}::level(")?;
1188                self.put_expression(h, context, true)?;
1189                write!(self.out, ")")?;
1190            }
1191            crate::SampleLevel::Bias(h) => {
1192                write!(self.out, ", {NAMESPACE}::bias(")?;
1193                self.put_expression(h, context, true)?;
1194                write!(self.out, ")")?;
1195            }
1196            crate::SampleLevel::Gradient { x, y } => {
1197                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1198                self.put_expression(x, context, true)?;
1199                write!(self.out, ", ")?;
1200                self.put_expression(y, context, true)?;
1201                write!(self.out, ")")?;
1202            }
1203        }
1204        Ok(())
1205    }
1206
1207    fn put_image_coordinate_limits(
1208        &mut self,
1209        image: Handle<crate::Expression>,
1210        level: Option<LevelOfDetail>,
1211        context: &ExpressionContext,
1212    ) -> BackendResult {
1213        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1214        write!(self.out, " - 1")?;
1215        Ok(())
1216    }
1217
1218    /// General function for writing restricted image indexes.
1219    ///
1220    /// This is used to produce restricted mip levels, array indices, and sample
1221    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1222    /// [`Restrict`] bounds check policy.
1223    ///
1224    /// This function writes an expression of the form:
1225    ///
1226    /// ```ignore
1227    ///
1228    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1229    ///
1230    /// ```
1231    ///
1232    /// [`ImageLoad`]: crate::Expression::ImageLoad
1233    /// [`ImageStore`]: crate::Statement::ImageStore
1234    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1235    fn put_restricted_scalar_image_index(
1236        &mut self,
1237        image: Handle<crate::Expression>,
1238        index: Handle<crate::Expression>,
1239        limit_method: &str,
1240        context: &ExpressionContext,
1241    ) -> BackendResult {
1242        write!(self.out, "{NAMESPACE}::min(uint(")?;
1243        self.put_expression(index, context, true)?;
1244        write!(self.out, "), ")?;
1245        self.put_expression(image, context, false)?;
1246        write!(self.out, ".{limit_method}() - 1)")?;
1247        Ok(())
1248    }
1249
1250    fn put_restricted_texel_address(
1251        &mut self,
1252        image: Handle<crate::Expression>,
1253        address: &TexelAddress,
1254        context: &ExpressionContext,
1255    ) -> BackendResult {
1256        // Write the coordinate.
1257        write!(self.out, "{NAMESPACE}::min(")?;
1258        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1259        write!(self.out, ", ")?;
1260        self.put_image_coordinate_limits(image, address.level, context)?;
1261        write!(self.out, ")")?;
1262
1263        // Write the array index, if present.
1264        if let Some(array_index) = address.array_index {
1265            write!(self.out, ", ")?;
1266            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1267        }
1268
1269        // Write the sample index, if present.
1270        if let Some(sample) = address.sample {
1271            write!(self.out, ", ")?;
1272            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1273        }
1274
1275        // The level of detail should be clamped and cached by
1276        // `put_cache_restricted_level`, so we don't need to clamp it here.
1277        if let Some(level) = address.level {
1278            write!(self.out, ", ")?;
1279            self.put_level_of_detail(level, context)?;
1280        }
1281
1282        Ok(())
1283    }
1284
1285    /// Write an expression that is true if the given image access is in bounds.
1286    fn put_image_access_bounds_check(
1287        &mut self,
1288        image: Handle<crate::Expression>,
1289        address: &TexelAddress,
1290        context: &ExpressionContext,
1291    ) -> BackendResult {
1292        let mut conjunction = "";
1293
1294        // First, check the level of detail. Only if that is in bounds can we
1295        // use it to find the appropriate bounds for the coordinates.
1296        let level = if let Some(level) = address.level {
1297            write!(self.out, "uint(")?;
1298            self.put_level_of_detail(level, context)?;
1299            write!(self.out, ") < ")?;
1300            self.put_expression(image, context, true)?;
1301            write!(self.out, ".get_num_mip_levels()")?;
1302            conjunction = " && ";
1303            Some(level)
1304        } else {
1305            None
1306        };
1307
1308        // Check sample index, if present.
1309        if let Some(sample) = address.sample {
1310            write!(self.out, "uint(")?;
1311            self.put_expression(sample, context, true)?;
1312            write!(self.out, ") < ")?;
1313            self.put_expression(image, context, true)?;
1314            write!(self.out, ".get_num_samples()")?;
1315            conjunction = " && ";
1316        }
1317
1318        // Check array index, if present.
1319        if let Some(array_index) = address.array_index {
1320            write!(self.out, "{conjunction}uint(")?;
1321            self.put_expression(array_index, context, true)?;
1322            write!(self.out, ") < ")?;
1323            self.put_expression(image, context, true)?;
1324            write!(self.out, ".get_array_size()")?;
1325            conjunction = " && ";
1326        }
1327
1328        // Finally, check if the coordinates are within bounds.
1329        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1330            crate::TypeInner::Vector { .. } => true,
1331            _ => false,
1332        };
1333        write!(self.out, "{conjunction}")?;
1334        if coord_is_vector {
1335            write!(self.out, "{NAMESPACE}::all(")?;
1336        }
1337        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1338        write!(self.out, " < ")?;
1339        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1340        if coord_is_vector {
1341            write!(self.out, ")")?;
1342        }
1343
1344        Ok(())
1345    }
1346
1347    fn put_image_load(
1348        &mut self,
1349        load: Handle<crate::Expression>,
1350        image: Handle<crate::Expression>,
1351        mut address: TexelAddress,
1352        context: &ExpressionContext,
1353    ) -> BackendResult {
1354        if let crate::TypeInner::Image {
1355            class: crate::ImageClass::External,
1356            ..
1357        } = *context.resolve_type(image)
1358        {
1359            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1360            self.put_expression(image, context, true)?;
1361            write!(self.out, ", ")?;
1362            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1363            write!(self.out, ")")?;
1364            return Ok(());
1365        }
1366
1367        match context.policies.image_load {
1368            proc::BoundsCheckPolicy::Restrict => {
1369                // Use the cached restricted level of detail, if any. Omit the
1370                // level altogether for 1D textures.
1371                if address.level.is_some() {
1372                    address.level = if context.image_needs_lod(image) {
1373                        Some(LevelOfDetail::Restricted(load))
1374                    } else {
1375                        None
1376                    }
1377                }
1378
1379                self.put_expression(image, context, false)?;
1380                write!(self.out, ".read(")?;
1381                self.put_restricted_texel_address(image, &address, context)?;
1382                write!(self.out, ")")?;
1383            }
1384            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1385                write!(self.out, "(")?;
1386                self.put_image_access_bounds_check(image, &address, context)?;
1387                write!(self.out, " ? ")?;
1388                self.put_unchecked_image_load(image, &address, context)?;
1389                write!(self.out, ": DefaultConstructible())")?;
1390            }
1391            proc::BoundsCheckPolicy::Unchecked => {
1392                self.put_unchecked_image_load(image, &address, context)?;
1393            }
1394        }
1395
1396        Ok(())
1397    }
1398
1399    fn put_unchecked_image_load(
1400        &mut self,
1401        image: Handle<crate::Expression>,
1402        address: &TexelAddress,
1403        context: &ExpressionContext,
1404    ) -> BackendResult {
1405        self.put_expression(image, context, false)?;
1406        write!(self.out, ".read(")?;
1407        // coordinates in IR are int, but Metal expects uint
1408        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1409        if let Some(expr) = address.array_index {
1410            write!(self.out, ", ")?;
1411            self.put_expression(expr, context, true)?;
1412        }
1413        if let Some(sample) = address.sample {
1414            write!(self.out, ", ")?;
1415            self.put_expression(sample, context, true)?;
1416        }
1417        if let Some(level) = address.level {
1418            if context.image_needs_lod(image) {
1419                write!(self.out, ", ")?;
1420                self.put_level_of_detail(level, context)?;
1421            }
1422        }
1423        write!(self.out, ")")?;
1424
1425        Ok(())
1426    }
1427
1428    fn put_image_atomic(
1429        &mut self,
1430        level: back::Level,
1431        image: Handle<crate::Expression>,
1432        address: &TexelAddress,
1433        fun: crate::AtomicFunction,
1434        value: Handle<crate::Expression>,
1435        context: &StatementContext,
1436    ) -> BackendResult {
1437        write!(self.out, "{level}")?;
1438        self.put_expression(image, &context.expression, false)?;
1439        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1440            fun.to_msl_64_bit()?
1441        } else {
1442            fun.to_msl()
1443        };
1444        write!(self.out, ".atomic_{op}(")?;
1445        // coordinates in IR are int, but Metal expects uint
1446        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1447        write!(self.out, ", ")?;
1448        self.put_expression(value, &context.expression, true)?;
1449        writeln!(self.out, ");")?;
1450
1451        Ok(())
1452    }
1453
1454    fn put_image_store(
1455        &mut self,
1456        level: back::Level,
1457        image: Handle<crate::Expression>,
1458        address: &TexelAddress,
1459        value: Handle<crate::Expression>,
1460        context: &StatementContext,
1461    ) -> BackendResult {
1462        write!(self.out, "{level}")?;
1463        self.put_expression(image, &context.expression, false)?;
1464        write!(self.out, ".write(")?;
1465        self.put_expression(value, &context.expression, true)?;
1466        write!(self.out, ", ")?;
1467        // coordinates in IR are int, but Metal expects uint
1468        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1469        if let Some(expr) = address.array_index {
1470            write!(self.out, ", ")?;
1471            self.put_expression(expr, &context.expression, true)?;
1472        }
1473        writeln!(self.out, ");")?;
1474
1475        Ok(())
1476    }
1477
1478    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1479    ///
1480    /// The 'maximum valid index' is simply one less than the array's length.
1481    ///
1482    /// This emits an expression of the form `a / b`, so the caller must
1483    /// parenthesize its output if it will be applying operators of higher
1484    /// precedence.
1485    ///
1486    /// `handle` must be the handle of a global variable whose final member is a
1487    /// dynamically sized array.
1488    fn put_dynamic_array_max_index(
1489        &mut self,
1490        handle: Handle<crate::GlobalVariable>,
1491        context: &ExpressionContext,
1492    ) -> BackendResult {
1493        let global = &context.module.global_variables[handle];
1494        let (offset, array_ty) = match context.module.types[global.ty].inner {
1495            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1496                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1497                None => return Err(Error::GenericValidation("Struct has no members".into())),
1498            },
1499            crate::TypeInner::Array {
1500                size: crate::ArraySize::Dynamic,
1501                ..
1502            } => (0, global.ty),
1503            ref ty => {
1504                return Err(Error::GenericValidation(format!(
1505                    "Expected type with dynamic array, got {ty:?}"
1506                )))
1507            }
1508        };
1509
1510        let (size, stride) = match context.module.types[array_ty].inner {
1511            crate::TypeInner::Array { base, stride, .. } => (
1512                context.module.types[base]
1513                    .inner
1514                    .size(context.module.to_ctx()),
1515                stride,
1516            ),
1517            ref ty => {
1518                return Err(Error::GenericValidation(format!(
1519                    "Expected array type, got {ty:?}"
1520                )))
1521            }
1522        };
1523
1524        // When the stride length is larger than the size, the final element's stride of
1525        // bytes would have padding following the value. But the buffer size in
1526        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1527        // enough to hold the actual values' bytes.
1528        //
1529        // So subtract off the size to get a byte size that falls at the start or within
1530        // the final element. Then divide by the stride size, to get one less than the
1531        // length, and then add one. This works even if the buffer size does include the
1532        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1533        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1534        // rules, together with draw-time validation when `minBindingSize` is zero,
1535        // prevent that.
1536        write!(
1537            self.out,
1538            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1539            member = ArraySizeMember(handle),
1540            offset = offset,
1541            size = size,
1542            stride = stride,
1543        )?;
1544        Ok(())
1545    }
1546
1547    /// Emit code for the arithmetic expression of the dot product.
1548    ///
1549    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1550    /// an index. It writes out the expression for the vector component at that index.
1551    fn put_dot_product<T: Copy>(
1552        &mut self,
1553        arg: T,
1554        arg1: T,
1555        size: usize,
1556        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1557    ) -> BackendResult {
1558        // Write parentheses around the dot product expression to prevent operators
1559        // with different precedences from applying earlier.
1560        write!(self.out, "(")?;
1561
1562        // Cycle through all the components of the vector
1563        for index in 0..size {
1564            // Write the addition to the previous product
1565            // This will print an extra '+' at the beginning but that is fine in msl
1566            write!(self.out, " + ")?;
1567            extractor(self, arg, index)?;
1568            write!(self.out, " * ")?;
1569            extractor(self, arg1, index)?;
1570        }
1571
1572        write!(self.out, ")")?;
1573        Ok(())
1574    }
1575
1576    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1577    fn put_pack4x8(
1578        &mut self,
1579        arg: Handle<crate::Expression>,
1580        context: &ExpressionContext<'_>,
1581        was_signed: bool,
1582        clamp_bounds: Option<(&str, &str)>,
1583    ) -> Result<(), Error> {
1584        let write_arg = |this: &mut Self| -> BackendResult {
1585            if let Some((min, max)) = clamp_bounds {
1586                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1587                write!(this.out, "{NAMESPACE}::clamp(")?;
1588                this.put_expression(arg, context, true)?;
1589                write!(this.out, ", {min}, {max})")?;
1590            } else {
1591                this.put_expression(arg, context, true)?;
1592            }
1593            Ok(())
1594        };
1595
1596        if context.lang_version >= (2, 1) {
1597            let packed_type = if was_signed {
1598                "packed_char4"
1599            } else {
1600                "packed_uchar4"
1601            };
1602            // Metal uses little endian byte order, which matches what WGSL expects here.
1603            write!(self.out, "as_type<uint>({packed_type}(")?;
1604            write_arg(self)?;
1605            write!(self.out, "))")?;
1606        } else {
1607            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1608            if was_signed {
1609                write!(self.out, "uint(")?;
1610            }
1611            write!(self.out, "(")?;
1612            write_arg(self)?;
1613            write!(self.out, "[0] & 0xFF) | ((")?;
1614            write_arg(self)?;
1615            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1616            write_arg(self)?;
1617            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1618            write_arg(self)?;
1619            write!(self.out, "[3] & 0xFF) << 24)")?;
1620            if was_signed {
1621                write!(self.out, ")")?;
1622            }
1623        }
1624
1625        Ok(())
1626    }
1627
1628    /// Emit code for the isign expression.
1629    ///
1630    fn put_isign(
1631        &mut self,
1632        arg: Handle<crate::Expression>,
1633        context: &ExpressionContext,
1634    ) -> BackendResult {
1635        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1636        let scalar = context
1637            .resolve_type(arg)
1638            .scalar()
1639            .expect("put_isign should only be called for args which have an integer scalar type")
1640            .to_msl_name();
1641        match context.resolve_type(arg) {
1642            &crate::TypeInner::Vector { size, .. } => {
1643                let size = common::vector_size_str(size);
1644                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1645            }
1646            _ => {
1647                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1648            }
1649        }
1650        write!(self.out, ", (")?;
1651        self.put_expression(arg, context, true)?;
1652        write!(self.out, " > 0)), {scalar}(0), (")?;
1653        self.put_expression(arg, context, true)?;
1654        write!(self.out, " == 0))")?;
1655        Ok(())
1656    }
1657
1658    fn put_const_expression(
1659        &mut self,
1660        expr_handle: Handle<crate::Expression>,
1661        module: &crate::Module,
1662        mod_info: &valid::ModuleInfo,
1663        arena: &crate::Arena<crate::Expression>,
1664    ) -> BackendResult {
1665        self.put_possibly_const_expression(
1666            expr_handle,
1667            arena,
1668            module,
1669            mod_info,
1670            &(module, mod_info),
1671            |&(_, mod_info), expr| &mod_info[expr],
1672            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1673        )
1674    }
1675
1676    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1677        match literal {
1678            crate::Literal::F64(_) => {
1679                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1680            }
1681            crate::Literal::F16(value) => {
1682                if value.is_infinite() {
1683                    let sign = if value.is_sign_negative() { "-" } else { "" };
1684                    write!(self.out, "{sign}INFINITY")?;
1685                } else if value.is_nan() {
1686                    write!(self.out, "NAN")?;
1687                } else {
1688                    let suffix = if value.fract() == f16::from_f32(0.0) {
1689                        ".0h"
1690                    } else {
1691                        "h"
1692                    };
1693                    write!(self.out, "{value}{suffix}")?;
1694                }
1695            }
1696            crate::Literal::F32(value) => {
1697                if value.is_infinite() {
1698                    let sign = if value.is_sign_negative() { "-" } else { "" };
1699                    write!(self.out, "{sign}INFINITY")?;
1700                } else if value.is_nan() {
1701                    write!(self.out, "NAN")?;
1702                } else {
1703                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1704                    write!(self.out, "{value}{suffix}")?;
1705                }
1706            }
1707            crate::Literal::U32(value) => {
1708                write!(self.out, "{value}u")?;
1709            }
1710            crate::Literal::I32(value) => {
1711                // `-2147483648` is parsed as unary negation of positive 2147483648.
1712                // 2147483648 is too large for int32_t meaning the expression gets
1713                // promoted to a int64_t which is not our intention. Avoid this by instead
1714                // using `-2147483647 - 1`.
1715                if value == i32::MIN {
1716                    write!(self.out, "({} - 1)", value + 1)?;
1717                } else {
1718                    write!(self.out, "{value}")?;
1719                }
1720            }
1721            crate::Literal::U64(value) => {
1722                write!(self.out, "{value}uL")?;
1723            }
1724            crate::Literal::I64(value) => {
1725                // `-9223372036854775808` is parsed as unary negation of positive
1726                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1727                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1728                // value to `-9223372036854775808`. Which would then be negated, possibly
1729                // causing undefined behaviour. Avoid this by instead using
1730                // `-9223372036854775808L - 1L`.
1731                if value == i64::MIN {
1732                    write!(self.out, "({}L - 1L)", value + 1)?;
1733                } else {
1734                    write!(self.out, "{value}L")?;
1735                }
1736            }
1737            crate::Literal::Bool(value) => {
1738                write!(self.out, "{value}")?;
1739            }
1740            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1741                return Err(Error::GenericValidation(
1742                    "Unsupported abstract literal".into(),
1743                ));
1744            }
1745        }
1746        Ok(())
1747    }
1748
1749    #[allow(clippy::too_many_arguments)]
1750    fn put_possibly_const_expression<C, I, E>(
1751        &mut self,
1752        expr_handle: Handle<crate::Expression>,
1753        expressions: &crate::Arena<crate::Expression>,
1754        module: &crate::Module,
1755        mod_info: &valid::ModuleInfo,
1756        ctx: &C,
1757        get_expr_ty: I,
1758        put_expression: E,
1759    ) -> BackendResult
1760    where
1761        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1762        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1763    {
1764        match expressions[expr_handle] {
1765            crate::Expression::Literal(literal) => {
1766                self.put_literal(literal)?;
1767            }
1768            crate::Expression::Constant(handle) => {
1769                let constant = &module.constants[handle];
1770                if constant.name.is_some() {
1771                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1772                } else {
1773                    self.put_const_expression(
1774                        constant.init,
1775                        module,
1776                        mod_info,
1777                        &module.global_expressions,
1778                    )?;
1779                }
1780            }
1781            crate::Expression::ZeroValue(ty) => {
1782                let ty_name = TypeContext {
1783                    handle: ty,
1784                    gctx: module.to_ctx(),
1785                    names: &self.names,
1786                    access: crate::StorageAccess::empty(),
1787                    first_time: false,
1788                };
1789                write!(self.out, "{ty_name} {{}}")?;
1790            }
1791            crate::Expression::Compose { ty, ref components } => {
1792                let ty_name = TypeContext {
1793                    handle: ty,
1794                    gctx: module.to_ctx(),
1795                    names: &self.names,
1796                    access: crate::StorageAccess::empty(),
1797                    first_time: false,
1798                };
1799                write!(self.out, "{ty_name}")?;
1800                match module.types[ty].inner {
1801                    crate::TypeInner::Scalar(_)
1802                    | crate::TypeInner::Vector { .. }
1803                    | crate::TypeInner::Matrix { .. } => {
1804                        self.put_call_parameters_impl(
1805                            components.iter().copied(),
1806                            ctx,
1807                            put_expression,
1808                        )?;
1809                    }
1810                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1811                        write!(self.out, " {{")?;
1812                        for (index, &component) in components.iter().enumerate() {
1813                            if index != 0 {
1814                                write!(self.out, ", ")?;
1815                            }
1816                            // insert padding initialization, if needed
1817                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1818                                write!(self.out, "{{}}, ")?;
1819                            }
1820                            put_expression(self, ctx, component)?;
1821                        }
1822                        write!(self.out, "}}")?;
1823                    }
1824                    _ => return Err(Error::UnsupportedCompose(ty)),
1825                }
1826            }
1827            crate::Expression::Splat { size, value } => {
1828                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1829                    crate::TypeInner::Scalar(scalar) => scalar,
1830                    ref ty => {
1831                        return Err(Error::GenericValidation(format!(
1832                            "Expected splat value type must be a scalar, got {ty:?}",
1833                        )))
1834                    }
1835                };
1836                put_numeric_type(&mut self.out, scalar, &[size])?;
1837                write!(self.out, "(")?;
1838                put_expression(self, ctx, value)?;
1839                write!(self.out, ")")?;
1840            }
1841            _ => {
1842                return Err(Error::Override);
1843            }
1844        }
1845
1846        Ok(())
1847    }
1848
1849    /// Emit code for the expression `expr_handle`.
1850    ///
1851    /// The `is_scoped` argument is true if the surrounding operators have the
1852    /// precedence of the comma operator, or lower. So, for example:
1853    ///
1854    /// - Pass `true` for `is_scoped` when writing function arguments, an
1855    ///   expression statement, an initializer expression, or anything already
1856    ///   wrapped in parenthesis.
1857    ///
1858    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1859    ///   almost anything else.
1860    fn put_expression(
1861        &mut self,
1862        expr_handle: Handle<crate::Expression>,
1863        context: &ExpressionContext,
1864        is_scoped: bool,
1865    ) -> BackendResult {
1866        // Add to the set in order to track the stack size.
1867        #[cfg(test)]
1868        self.put_expression_stack_pointers
1869            .insert(ptr::from_ref(&expr_handle).cast());
1870
1871        if let Some(name) = self.named_expressions.get(&expr_handle) {
1872            write!(self.out, "{name}")?;
1873            return Ok(());
1874        }
1875
1876        let expression = &context.function.expressions[expr_handle];
1877        match *expression {
1878            crate::Expression::Literal(_)
1879            | crate::Expression::Constant(_)
1880            | crate::Expression::ZeroValue(_)
1881            | crate::Expression::Compose { .. }
1882            | crate::Expression::Splat { .. } => {
1883                self.put_possibly_const_expression(
1884                    expr_handle,
1885                    &context.function.expressions,
1886                    context.module,
1887                    context.mod_info,
1888                    context,
1889                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1890                    |writer, context, expr| writer.put_expression(expr, context, true),
1891                )?;
1892            }
1893            crate::Expression::Override(_) => return Err(Error::Override),
1894            crate::Expression::Access { base, .. }
1895            | crate::Expression::AccessIndex { base, .. } => {
1896                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1897                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1898                // access chain at a time, recursing back through `put_expression` only
1899                // for index expressions and the base object, we will never see intermediate
1900                // `Access` or `AccessIndex` expressions here.
1901                let policy = context.choose_bounds_check_policy(base);
1902                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1903                    && self.put_bounds_checks(
1904                        expr_handle,
1905                        context,
1906                        back::Level(0),
1907                        if is_scoped { "" } else { "(" },
1908                    )?
1909                {
1910                    write!(self.out, " ? ")?;
1911                    self.put_access_chain(expr_handle, policy, context)?;
1912                    write!(self.out, " : ")?;
1913
1914                    if context.resolve_type(base).pointer_space().is_some() {
1915                        // We can't just use `DefaultConstructible` if this is a pointer.
1916                        // Instead, we create a dummy local variable to serve as pointer
1917                        // target if the access is out of bounds.
1918                        let result_ty = context.info[expr_handle]
1919                            .ty
1920                            .inner_with(&context.module.types)
1921                            .pointer_base_type();
1922                        let result_ty_handle = match result_ty {
1923                            Some(TypeResolution::Handle(handle)) => handle,
1924                            Some(TypeResolution::Value(_)) => {
1925                                // As long as the result of a pointer access expression is
1926                                // passed to a function or stored in a let binding, the
1927                                // type will be in the arena. If additional uses of
1928                                // pointers become valid, this assumption might no longer
1929                                // hold. Note that the LHS of a load or store doesn't
1930                                // take this path -- there is dedicated code in `put_load`
1931                                // and `put_store`.
1932                                unreachable!(
1933                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1934                                );
1935                            }
1936                            None => {
1937                                unreachable!(
1938                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1939                                )
1940                            }
1941                        };
1942                        let name_key =
1943                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
1944                        self.out.write_str(&self.names[&name_key])?;
1945                    } else {
1946                        write!(self.out, "DefaultConstructible()")?;
1947                    }
1948
1949                    if !is_scoped {
1950                        write!(self.out, ")")?;
1951                    }
1952                } else {
1953                    self.put_access_chain(expr_handle, policy, context)?;
1954                }
1955            }
1956            crate::Expression::Swizzle {
1957                size,
1958                vector,
1959                pattern,
1960            } => {
1961                self.put_wrapped_expression_for_packed_vec3_access(
1962                    vector,
1963                    context,
1964                    false,
1965                    &Self::put_expression,
1966                )?;
1967                write!(self.out, ".")?;
1968                for &sc in pattern[..size as usize].iter() {
1969                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1970                }
1971            }
1972            crate::Expression::FunctionArgument(index) => {
1973                let name_key = match context.origin {
1974                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1975                    FunctionOrigin::EntryPoint(ep_index) => {
1976                        NameKey::EntryPointArgument(ep_index, index)
1977                    }
1978                };
1979                let name = &self.names[&name_key];
1980                write!(self.out, "{name}")?;
1981            }
1982            crate::Expression::GlobalVariable(handle) => {
1983                let name = &self.names[&NameKey::GlobalVariable(handle)];
1984                write!(self.out, "{name}")?;
1985            }
1986            crate::Expression::LocalVariable(handle) => {
1987                let name_key = NameKey::local(context.origin, handle);
1988                let name = &self.names[&name_key];
1989                write!(self.out, "{name}")?;
1990            }
1991            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1992            crate::Expression::ImageSample {
1993                coordinate,
1994                image,
1995                sampler,
1996                clamp_to_edge: true,
1997                gather: None,
1998                array_index: None,
1999                offset: None,
2000                level: crate::SampleLevel::Zero,
2001                depth_ref: None,
2002            } => {
2003                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2004                self.put_expression(image, context, true)?;
2005                write!(self.out, ", ")?;
2006                self.put_expression(sampler, context, true)?;
2007                write!(self.out, ", ")?;
2008                self.put_expression(coordinate, context, true)?;
2009                write!(self.out, ")")?;
2010            }
2011            crate::Expression::ImageSample {
2012                image,
2013                sampler,
2014                gather,
2015                coordinate,
2016                array_index,
2017                offset,
2018                level,
2019                depth_ref,
2020                clamp_to_edge,
2021            } => {
2022                if clamp_to_edge {
2023                    return Err(Error::GenericValidation(
2024                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2025                    ));
2026                }
2027
2028                let main_op = match gather {
2029                    Some(_) => "gather",
2030                    None => "sample",
2031                };
2032                let comparison_op = match depth_ref {
2033                    Some(_) => "_compare",
2034                    None => "",
2035                };
2036                self.put_expression(image, context, false)?;
2037                write!(self.out, ".{main_op}{comparison_op}(")?;
2038                self.put_expression(sampler, context, true)?;
2039                write!(self.out, ", ")?;
2040                self.put_expression(coordinate, context, true)?;
2041                if let Some(expr) = array_index {
2042                    write!(self.out, ", ")?;
2043                    self.put_expression(expr, context, true)?;
2044                }
2045                if let Some(dref) = depth_ref {
2046                    write!(self.out, ", ")?;
2047                    self.put_expression(dref, context, true)?;
2048                }
2049
2050                self.put_image_sample_level(image, level, context)?;
2051
2052                if let Some(offset) = offset {
2053                    write!(self.out, ", ")?;
2054                    self.put_expression(offset, context, true)?;
2055                }
2056
2057                match gather {
2058                    None | Some(crate::SwizzleComponent::X) => {}
2059                    Some(component) => {
2060                        let is_cube_map = match *context.resolve_type(image) {
2061                            crate::TypeInner::Image {
2062                                dim: crate::ImageDimension::Cube,
2063                                ..
2064                            } => true,
2065                            _ => false,
2066                        };
2067                        // Offset always comes before the gather, except
2068                        // in cube maps where it's not applicable
2069                        if offset.is_none() && !is_cube_map {
2070                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2071                        }
2072                        let letter = back::COMPONENTS[component as usize];
2073                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2074                    }
2075                }
2076                write!(self.out, ")")?;
2077            }
2078            crate::Expression::ImageLoad {
2079                image,
2080                coordinate,
2081                array_index,
2082                sample,
2083                level,
2084            } => {
2085                let address = TexelAddress {
2086                    coordinate,
2087                    array_index,
2088                    sample,
2089                    level: level.map(LevelOfDetail::Direct),
2090                };
2091                self.put_image_load(expr_handle, image, address, context)?;
2092            }
2093            //Note: for all the queries, the signed integers are expected,
2094            // so a conversion is needed.
2095            crate::Expression::ImageQuery { image, query } => match query {
2096                crate::ImageQuery::Size { level } => {
2097                    self.put_image_size_query(
2098                        image,
2099                        level.map(LevelOfDetail::Direct),
2100                        crate::ScalarKind::Uint,
2101                        context,
2102                    )?;
2103                }
2104                crate::ImageQuery::NumLevels => {
2105                    self.put_expression(image, context, false)?;
2106                    write!(self.out, ".get_num_mip_levels()")?;
2107                }
2108                crate::ImageQuery::NumLayers => {
2109                    self.put_expression(image, context, false)?;
2110                    write!(self.out, ".get_array_size()")?;
2111                }
2112                crate::ImageQuery::NumSamples => {
2113                    self.put_expression(image, context, false)?;
2114                    write!(self.out, ".get_num_samples()")?;
2115                }
2116            },
2117            crate::Expression::Unary { op, expr } => {
2118                let op_str = match op {
2119                    crate::UnaryOperator::Negate => {
2120                        match context.resolve_type(expr).scalar_kind() {
2121                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2122                            _ => "-",
2123                        }
2124                    }
2125                    crate::UnaryOperator::LogicalNot => "!",
2126                    crate::UnaryOperator::BitwiseNot => "~",
2127                };
2128                write!(self.out, "{op_str}(")?;
2129                self.put_expression(expr, context, false)?;
2130                write!(self.out, ")")?;
2131            }
2132            crate::Expression::Binary { op, left, right } => {
2133                let kind = context
2134                    .resolve_type(left)
2135                    .scalar_kind()
2136                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2137
2138                if op == crate::BinaryOperator::Divide
2139                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2140                {
2141                    write!(self.out, "{DIV_FUNCTION}(")?;
2142                    self.put_expression(left, context, true)?;
2143                    write!(self.out, ", ")?;
2144                    self.put_expression(right, context, true)?;
2145                    write!(self.out, ")")?;
2146                } else if op == crate::BinaryOperator::Modulo
2147                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2148                {
2149                    write!(self.out, "{MOD_FUNCTION}(")?;
2150                    self.put_expression(left, context, true)?;
2151                    write!(self.out, ", ")?;
2152                    self.put_expression(right, context, true)?;
2153                    write!(self.out, ")")?;
2154                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2155                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2156                    //
2157                    // float:
2158                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2159                    write!(self.out, "{NAMESPACE}::fmod(")?;
2160                    self.put_expression(left, context, true)?;
2161                    write!(self.out, ", ")?;
2162                    self.put_expression(right, context, true)?;
2163                    write!(self.out, ")")?;
2164                } else if (op == crate::BinaryOperator::Add
2165                    || op == crate::BinaryOperator::Subtract
2166                    || op == crate::BinaryOperator::Multiply)
2167                    && kind == crate::ScalarKind::Sint
2168                {
2169                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2170                        crate::TypeInner::Scalar(scalar) => {
2171                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2172                                kind: crate::ScalarKind::Uint,
2173                                ..scalar
2174                            }))
2175                        }
2176                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2177                            size,
2178                            scalar: crate::Scalar {
2179                                kind: crate::ScalarKind::Uint,
2180                                ..scalar
2181                            },
2182                        }),
2183                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2184                    };
2185
2186                    // Avoid undefined behaviour due to overflowing signed
2187                    // integer arithmetic. Cast the operands to unsigned prior
2188                    // to performing the operation, then cast the result back
2189                    // to signed.
2190                    self.put_bitcasted_expression(
2191                        context.resolve_type(expr_handle),
2192                        context,
2193                        &|writer, context, is_scoped| {
2194                            writer.put_binop(
2195                                op,
2196                                left,
2197                                right,
2198                                context,
2199                                is_scoped,
2200                                &|writer, expr, context, _is_scoped| {
2201                                    writer.put_bitcasted_expression(
2202                                        &to_unsigned(context.resolve_type(expr))?,
2203                                        context,
2204                                        &|writer, context, is_scoped| {
2205                                            writer.put_expression(expr, context, is_scoped)
2206                                        },
2207                                    )
2208                                },
2209                            )
2210                        },
2211                    )?;
2212                } else {
2213                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2214                }
2215            }
2216            crate::Expression::Select {
2217                condition,
2218                accept,
2219                reject,
2220            } => match *context.resolve_type(condition) {
2221                crate::TypeInner::Scalar(crate::Scalar {
2222                    kind: crate::ScalarKind::Bool,
2223                    ..
2224                }) => {
2225                    if !is_scoped {
2226                        write!(self.out, "(")?;
2227                    }
2228                    self.put_expression(condition, context, false)?;
2229                    write!(self.out, " ? ")?;
2230                    self.put_expression(accept, context, false)?;
2231                    write!(self.out, " : ")?;
2232                    self.put_expression(reject, context, false)?;
2233                    if !is_scoped {
2234                        write!(self.out, ")")?;
2235                    }
2236                }
2237                crate::TypeInner::Vector {
2238                    scalar:
2239                        crate::Scalar {
2240                            kind: crate::ScalarKind::Bool,
2241                            ..
2242                        },
2243                    ..
2244                } => {
2245                    write!(self.out, "{NAMESPACE}::select(")?;
2246                    self.put_expression(reject, context, true)?;
2247                    write!(self.out, ", ")?;
2248                    self.put_expression(accept, context, true)?;
2249                    write!(self.out, ", ")?;
2250                    self.put_expression(condition, context, true)?;
2251                    write!(self.out, ")")?;
2252                }
2253                ref ty => {
2254                    return Err(Error::GenericValidation(format!(
2255                        "Expected select condition to be a non-bool type, got {ty:?}",
2256                    )))
2257                }
2258            },
2259            crate::Expression::Derivative { axis, expr, .. } => {
2260                use crate::DerivativeAxis as Axis;
2261                let op = match axis {
2262                    Axis::X => "dfdx",
2263                    Axis::Y => "dfdy",
2264                    Axis::Width => "fwidth",
2265                };
2266                write!(self.out, "{NAMESPACE}::{op}")?;
2267                self.put_call_parameters(iter::once(expr), context)?;
2268            }
2269            crate::Expression::Relational { fun, argument } => {
2270                let op = match fun {
2271                    crate::RelationalFunction::Any => "any",
2272                    crate::RelationalFunction::All => "all",
2273                    crate::RelationalFunction::IsNan => "isnan",
2274                    crate::RelationalFunction::IsInf => "isinf",
2275                };
2276                write!(self.out, "{NAMESPACE}::{op}")?;
2277                self.put_call_parameters(iter::once(argument), context)?;
2278            }
2279            crate::Expression::Math {
2280                fun,
2281                arg,
2282                arg1,
2283                arg2,
2284                arg3,
2285            } => {
2286                use crate::MathFunction as Mf;
2287
2288                let arg_type = context.resolve_type(arg);
2289                let scalar_argument = match arg_type {
2290                    &crate::TypeInner::Scalar(_) => true,
2291                    _ => false,
2292                };
2293
2294                let fun_name = match fun {
2295                    // comparison
2296                    Mf::Abs => "abs",
2297                    Mf::Min => "min",
2298                    Mf::Max => "max",
2299                    Mf::Clamp => "clamp",
2300                    Mf::Saturate => "saturate",
2301                    // trigonometry
2302                    Mf::Cos => "cos",
2303                    Mf::Cosh => "cosh",
2304                    Mf::Sin => "sin",
2305                    Mf::Sinh => "sinh",
2306                    Mf::Tan => "tan",
2307                    Mf::Tanh => "tanh",
2308                    Mf::Acos => "acos",
2309                    Mf::Asin => "asin",
2310                    Mf::Atan => "atan",
2311                    Mf::Atan2 => "atan2",
2312                    Mf::Asinh => "asinh",
2313                    Mf::Acosh => "acosh",
2314                    Mf::Atanh => "atanh",
2315                    Mf::Radians => "",
2316                    Mf::Degrees => "",
2317                    // decomposition
2318                    Mf::Ceil => "ceil",
2319                    Mf::Floor => "floor",
2320                    Mf::Round => "rint",
2321                    Mf::Fract => "fract",
2322                    Mf::Trunc => "trunc",
2323                    Mf::Modf => MODF_FUNCTION,
2324                    Mf::Frexp => FREXP_FUNCTION,
2325                    Mf::Ldexp => "ldexp",
2326                    // exponent
2327                    Mf::Exp => "exp",
2328                    Mf::Exp2 => "exp2",
2329                    Mf::Log => "log",
2330                    Mf::Log2 => "log2",
2331                    Mf::Pow => "pow",
2332                    // geometry
2333                    Mf::Dot => match *context.resolve_type(arg) {
2334                        crate::TypeInner::Vector {
2335                            scalar:
2336                                crate::Scalar {
2337                                    kind: crate::ScalarKind::Float,
2338                                    ..
2339                                },
2340                            ..
2341                        } => "dot",
2342                        crate::TypeInner::Vector { size, .. } => {
2343                            return self.put_dot_product(
2344                                arg,
2345                                arg1.unwrap(),
2346                                size as usize,
2347                                |writer, arg, index| {
2348                                    // Write the vector expression; this expression is marked to be
2349                                    // cached so unless it can't be cached (for example, it's a Constant)
2350                                    // it shouldn't produce large expressions.
2351                                    writer.put_expression(arg, context, true)?;
2352                                    // Access the current component on the vector.
2353                                    write!(writer.out, ".{}", back::COMPONENTS[index])?;
2354                                    Ok(())
2355                                },
2356                            );
2357                        }
2358                        _ => unreachable!(
2359                            "Correct TypeInner for dot product should be already validated"
2360                        ),
2361                    },
2362                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2363                        if context.lang_version >= (2, 1) {
2364                            // Write potentially optimizable code using `packed_(u?)char4`.
2365                            // The two function arguments were already reinterpreted as packed (signed
2366                            // or unsigned) chars in `Self::put_block`.
2367                            let packed_type = match fun {
2368                                Mf::Dot4I8Packed => "packed_char4",
2369                                Mf::Dot4U8Packed => "packed_uchar4",
2370                                _ => unreachable!(),
2371                            };
2372
2373                            return self.put_dot_product(
2374                                Reinterpreted::new(packed_type, arg),
2375                                Reinterpreted::new(packed_type, arg1.unwrap()),
2376                                4,
2377                                |writer, arg, index| {
2378                                    // MSL implicitly promotes these (signed or unsigned) chars to
2379                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2380                                    write!(writer.out, "{arg}[{index}]")?;
2381                                    Ok(())
2382                                },
2383                            );
2384                        } else {
2385                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2386                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2387                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2388                            let conversion = match fun {
2389                                Mf::Dot4I8Packed => "int",
2390                                Mf::Dot4U8Packed => "",
2391                                _ => unreachable!(),
2392                            };
2393
2394                            return self.put_dot_product(
2395                                arg,
2396                                arg1.unwrap(),
2397                                4,
2398                                |writer, arg, index| {
2399                                    write!(writer.out, "({conversion}(")?;
2400                                    writer.put_expression(arg, context, true)?;
2401                                    if index == 3 {
2402                                        write!(writer.out, ") >> 24)")?;
2403                                    } else {
2404                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2405                                    }
2406                                    Ok(())
2407                                },
2408                            );
2409                        }
2410                    }
2411                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2412                    Mf::Cross => "cross",
2413                    Mf::Distance => "distance",
2414                    Mf::Length if scalar_argument => "abs",
2415                    Mf::Length => "length",
2416                    Mf::Normalize => "normalize",
2417                    Mf::FaceForward => "faceforward",
2418                    Mf::Reflect => "reflect",
2419                    Mf::Refract => "refract",
2420                    // computational
2421                    Mf::Sign => match arg_type.scalar_kind() {
2422                        Some(crate::ScalarKind::Sint) => {
2423                            return self.put_isign(arg, context);
2424                        }
2425                        _ => "sign",
2426                    },
2427                    Mf::Fma => "fma",
2428                    Mf::Mix => "mix",
2429                    Mf::Step => "step",
2430                    Mf::SmoothStep => "smoothstep",
2431                    Mf::Sqrt => "sqrt",
2432                    Mf::InverseSqrt => "rsqrt",
2433                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2434                    Mf::Transpose => "transpose",
2435                    Mf::Determinant => "determinant",
2436                    Mf::QuantizeToF16 => "",
2437                    // bits
2438                    Mf::CountTrailingZeros => "ctz",
2439                    Mf::CountLeadingZeros => "clz",
2440                    Mf::CountOneBits => "popcount",
2441                    Mf::ReverseBits => "reverse_bits",
2442                    Mf::ExtractBits => "",
2443                    Mf::InsertBits => "",
2444                    Mf::FirstTrailingBit => "",
2445                    Mf::FirstLeadingBit => "",
2446                    // data packing
2447                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2448                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2449                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2450                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2451                    Mf::Pack2x16float => "",
2452                    Mf::Pack4xI8 => "",
2453                    Mf::Pack4xU8 => "",
2454                    Mf::Pack4xI8Clamp => "",
2455                    Mf::Pack4xU8Clamp => "",
2456                    // data unpacking
2457                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2458                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2459                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2460                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2461                    Mf::Unpack2x16float => "",
2462                    Mf::Unpack4xI8 => "",
2463                    Mf::Unpack4xU8 => "",
2464                };
2465
2466                match fun {
2467                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2468                        // reverse_bits is listed as requiring MSL 2.1 but that
2469                        // is a copy/paste error. Looking at previous snapshots
2470                        // on web.archive.org it's present in MSL 1.2.
2471                        //
2472                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2473                        // also talks about MSL 1.2 adding "New integer
2474                        // functions to extract, insert, and reverse bits, as
2475                        // described in Integer Functions."
2476                        if context.lang_version < (1, 2) {
2477                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2478                        }
2479                    }
2480                    _ => {}
2481                }
2482
2483                match fun {
2484                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2485                        write!(self.out, "{ABS_FUNCTION}(")?;
2486                        self.put_expression(arg, context, true)?;
2487                        write!(self.out, ")")?;
2488                    }
2489                    Mf::Distance if scalar_argument => {
2490                        write!(self.out, "{NAMESPACE}::abs(")?;
2491                        self.put_expression(arg, context, false)?;
2492                        write!(self.out, " - ")?;
2493                        self.put_expression(arg1.unwrap(), context, false)?;
2494                        write!(self.out, ")")?;
2495                    }
2496                    Mf::FirstTrailingBit => {
2497                        let scalar = context.resolve_type(arg).scalar().unwrap();
2498                        let constant = scalar.width * 8 + 1;
2499
2500                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2501                        self.put_expression(arg, context, true)?;
2502                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2503                    }
2504                    Mf::FirstLeadingBit => {
2505                        let inner = context.resolve_type(arg);
2506                        let scalar = inner.scalar().unwrap();
2507                        let constant = scalar.width * 8 - 1;
2508
2509                        write!(
2510                            self.out,
2511                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2512                        )?;
2513
2514                        if scalar.kind == crate::ScalarKind::Sint {
2515                            write!(self.out, "{NAMESPACE}::select(")?;
2516                            self.put_expression(arg, context, true)?;
2517                            write!(self.out, ", ~")?;
2518                            self.put_expression(arg, context, true)?;
2519                            write!(self.out, ", ")?;
2520                            self.put_expression(arg, context, true)?;
2521                            write!(self.out, " < 0)")?;
2522                        } else {
2523                            self.put_expression(arg, context, true)?;
2524                        }
2525
2526                        write!(self.out, "), ")?;
2527
2528                        // or metal will complain that select is ambiguous
2529                        match *inner {
2530                            crate::TypeInner::Vector { size, scalar } => {
2531                                let size = common::vector_size_str(size);
2532                                let name = scalar.to_msl_name();
2533                                write!(self.out, "{name}{size}")?;
2534                            }
2535                            crate::TypeInner::Scalar(scalar) => {
2536                                let name = scalar.to_msl_name();
2537                                write!(self.out, "{name}")?;
2538                            }
2539                            _ => (),
2540                        }
2541
2542                        write!(self.out, "(-1), ")?;
2543                        self.put_expression(arg, context, true)?;
2544                        write!(self.out, " == 0 || ")?;
2545                        self.put_expression(arg, context, true)?;
2546                        write!(self.out, " == -1)")?;
2547                    }
2548                    Mf::Unpack2x16float => {
2549                        write!(self.out, "float2(as_type<half2>(")?;
2550                        self.put_expression(arg, context, false)?;
2551                        write!(self.out, "))")?;
2552                    }
2553                    Mf::Pack2x16float => {
2554                        write!(self.out, "as_type<uint>(half2(")?;
2555                        self.put_expression(arg, context, false)?;
2556                        write!(self.out, "))")?;
2557                    }
2558                    Mf::ExtractBits => {
2559                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2560                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2561                        // will return out-of-spec values if the extracted range is not within the bit width.
2562                        //
2563                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2564                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2565                        //
2566                        // w = sizeof(x) * 8
2567                        // o = min(offset, w)
2568                        // tmp = w - o
2569                        // c = min(count, tmp)
2570                        //
2571                        // bitfieldExtract(x, o, c)
2572                        //
2573                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2574
2575                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2576
2577                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2578                        self.put_expression(arg, context, true)?;
2579                        write!(self.out, ", {NAMESPACE}::min(")?;
2580                        self.put_expression(arg1.unwrap(), context, true)?;
2581                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2582                        self.put_expression(arg2.unwrap(), context, true)?;
2583                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2584                        self.put_expression(arg1.unwrap(), context, true)?;
2585                        write!(self.out, ", {scalar_bits}u)))")?;
2586                    }
2587                    Mf::InsertBits => {
2588                        // The behavior of InsertBits has the same issue as ExtractBits.
2589                        //
2590                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2591
2592                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2593
2594                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2595                        self.put_expression(arg, context, true)?;
2596                        write!(self.out, ", ")?;
2597                        self.put_expression(arg1.unwrap(), context, true)?;
2598                        write!(self.out, ", {NAMESPACE}::min(")?;
2599                        self.put_expression(arg2.unwrap(), context, true)?;
2600                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2601                        self.put_expression(arg3.unwrap(), context, true)?;
2602                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2603                        self.put_expression(arg2.unwrap(), context, true)?;
2604                        write!(self.out, ", {scalar_bits}u)))")?;
2605                    }
2606                    Mf::Radians => {
2607                        write!(self.out, "((")?;
2608                        self.put_expression(arg, context, false)?;
2609                        write!(self.out, ") * 0.017453292519943295474)")?;
2610                    }
2611                    Mf::Degrees => {
2612                        write!(self.out, "((")?;
2613                        self.put_expression(arg, context, false)?;
2614                        write!(self.out, ") * 57.295779513082322865)")?;
2615                    }
2616                    Mf::Modf | Mf::Frexp => {
2617                        write!(self.out, "{fun_name}")?;
2618                        self.put_call_parameters(iter::once(arg), context)?;
2619                    }
2620                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2621                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2622                    Mf::Pack4xI8Clamp => {
2623                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2624                    }
2625                    Mf::Pack4xU8Clamp => {
2626                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2627                    }
2628                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2629                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2630                            "u"
2631                        } else {
2632                            ""
2633                        };
2634
2635                        if context.lang_version >= (2, 1) {
2636                            // Metal uses little endian byte order, which matches what WGSL expects here.
2637                            write!(
2638                                self.out,
2639                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2640                            )?;
2641                            self.put_expression(arg, context, true)?;
2642                            write!(self.out, "))")?;
2643                        } else {
2644                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2645                            write!(self.out, "({sign_prefix}int4(")?;
2646                            self.put_expression(arg, context, true)?;
2647                            write!(self.out, ", ")?;
2648                            self.put_expression(arg, context, true)?;
2649                            write!(self.out, " >> 8, ")?;
2650                            self.put_expression(arg, context, true)?;
2651                            write!(self.out, " >> 16, ")?;
2652                            self.put_expression(arg, context, true)?;
2653                            write!(self.out, " >> 24) << 24 >> 24)")?;
2654                        }
2655                    }
2656                    Mf::QuantizeToF16 => {
2657                        match *context.resolve_type(arg) {
2658                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2659                            crate::TypeInner::Vector { size, .. } => write!(
2660                                self.out,
2661                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2662                                size = common::vector_size_str(size),
2663                            )?,
2664                            _ => unreachable!(
2665                                "Correct TypeInner for QuantizeToF16 should be already validated"
2666                            ),
2667                        };
2668
2669                        self.put_expression(arg, context, true)?;
2670                        write!(self.out, "))")?;
2671                    }
2672                    _ => {
2673                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2674                        self.put_call_parameters(
2675                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2676                            context,
2677                        )?;
2678                    }
2679                }
2680            }
2681            crate::Expression::As {
2682                expr,
2683                kind,
2684                convert,
2685            } => match *context.resolve_type(expr) {
2686                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2687                    if src.kind == crate::ScalarKind::Float
2688                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2689                        && convert.is_some()
2690                    {
2691                        // Use helper functions for float to int casts in order to avoid
2692                        // undefined behaviour when value is out of range for the target
2693                        // type.
2694                        let fun_name = match (kind, convert) {
2695                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2696                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2697                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2698                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2699                            _ => unreachable!(),
2700                        };
2701                        write!(self.out, "{fun_name}(")?;
2702                        self.put_expression(expr, context, true)?;
2703                        write!(self.out, ")")?;
2704                    } else {
2705                        let target_scalar = crate::Scalar {
2706                            kind,
2707                            width: convert.unwrap_or(src.width),
2708                        };
2709                        let op = match convert {
2710                            Some(_) => "static_cast",
2711                            None => "as_type",
2712                        };
2713                        write!(self.out, "{op}<")?;
2714                        match *context.resolve_type(expr) {
2715                            crate::TypeInner::Vector { size, .. } => {
2716                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2717                            }
2718                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2719                        };
2720                        write!(self.out, ">(")?;
2721                        self.put_expression(expr, context, true)?;
2722                        write!(self.out, ")")?;
2723                    }
2724                }
2725                crate::TypeInner::Matrix {
2726                    columns,
2727                    rows,
2728                    scalar,
2729                } => {
2730                    let target_scalar = crate::Scalar {
2731                        kind,
2732                        width: convert.unwrap_or(scalar.width),
2733                    };
2734                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2735                    write!(self.out, "(")?;
2736                    self.put_expression(expr, context, true)?;
2737                    write!(self.out, ")")?;
2738                }
2739                ref ty => {
2740                    return Err(Error::GenericValidation(format!(
2741                        "Unsupported type for As: {ty:?}"
2742                    )))
2743                }
2744            },
2745            // has to be a named expression
2746            crate::Expression::CallResult(_)
2747            | crate::Expression::AtomicResult { .. }
2748            | crate::Expression::WorkGroupUniformLoadResult { .. }
2749            | crate::Expression::SubgroupBallotResult
2750            | crate::Expression::SubgroupOperationResult { .. }
2751            | crate::Expression::RayQueryProceedResult => {
2752                unreachable!()
2753            }
2754            crate::Expression::ArrayLength(expr) => {
2755                // Find the global to which the array belongs.
2756                let global = match context.function.expressions[expr] {
2757                    crate::Expression::AccessIndex { base, .. } => {
2758                        match context.function.expressions[base] {
2759                            crate::Expression::GlobalVariable(handle) => handle,
2760                            ref ex => {
2761                                return Err(Error::GenericValidation(format!(
2762                                    "Expected global variable in AccessIndex, got {ex:?}"
2763                                )))
2764                            }
2765                        }
2766                    }
2767                    crate::Expression::GlobalVariable(handle) => handle,
2768                    ref ex => {
2769                        return Err(Error::GenericValidation(format!(
2770                            "Unexpected expression in ArrayLength, got {ex:?}"
2771                        )))
2772                    }
2773                };
2774
2775                if !is_scoped {
2776                    write!(self.out, "(")?;
2777                }
2778                write!(self.out, "1 + ")?;
2779                self.put_dynamic_array_max_index(global, context)?;
2780                if !is_scoped {
2781                    write!(self.out, ")")?;
2782                }
2783            }
2784            crate::Expression::RayQueryVertexPositions { .. } => {
2785                unimplemented!()
2786            }
2787            crate::Expression::RayQueryGetIntersection {
2788                query,
2789                committed: _,
2790            } => {
2791                if context.lang_version < (2, 4) {
2792                    return Err(Error::UnsupportedRayTracing);
2793                }
2794
2795                let ty = context.module.special_types.ray_intersection.unwrap();
2796                let type_name = &self.names[&NameKey::Type(ty)];
2797                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2798                self.put_expression(query, context, true)?;
2799                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2800                let fields = [
2801                    "distance",
2802                    "user_instance_id", // req Metal 2.4
2803                    "instance_id",
2804                    "", // SBT offset
2805                    "geometry_id",
2806                    "primitive_id",
2807                    "triangle_barycentric_coord",
2808                    "triangle_front_facing",
2809                    "",                          // padding
2810                    "object_to_world_transform", // req Metal 2.4
2811                    "world_to_object_transform", // req Metal 2.4
2812                ];
2813                for field in fields {
2814                    write!(self.out, ", ")?;
2815                    if field.is_empty() {
2816                        write!(self.out, "{{}}")?;
2817                    } else {
2818                        self.put_expression(query, context, true)?;
2819                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2820                    }
2821                }
2822                write!(self.out, "}}")?;
2823            }
2824        }
2825        Ok(())
2826    }
2827
2828    /// Emits code for a binary operation, using the provided callback to emit
2829    /// the left and right operands.
2830    fn put_binop<F>(
2831        &mut self,
2832        op: crate::BinaryOperator,
2833        left: Handle<crate::Expression>,
2834        right: Handle<crate::Expression>,
2835        context: &ExpressionContext,
2836        is_scoped: bool,
2837        put_expression: &F,
2838    ) -> BackendResult
2839    where
2840        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2841    {
2842        let op_str = back::binary_operation_str(op);
2843
2844        if !is_scoped {
2845            write!(self.out, "(")?;
2846        }
2847
2848        // Cast packed vector if necessary
2849        // Packed vector - matrix multiplications are not supported in MSL
2850        if op == crate::BinaryOperator::Multiply
2851            && matches!(
2852                context.resolve_type(right),
2853                &crate::TypeInner::Matrix { .. }
2854            )
2855        {
2856            self.put_wrapped_expression_for_packed_vec3_access(
2857                left,
2858                context,
2859                false,
2860                put_expression,
2861            )?;
2862        } else {
2863            put_expression(self, left, context, false)?;
2864        }
2865
2866        write!(self.out, " {op_str} ")?;
2867
2868        // See comment above
2869        if op == crate::BinaryOperator::Multiply
2870            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2871        {
2872            self.put_wrapped_expression_for_packed_vec3_access(
2873                right,
2874                context,
2875                false,
2876                put_expression,
2877            )?;
2878        } else {
2879            put_expression(self, right, context, false)?;
2880        }
2881
2882        if !is_scoped {
2883            write!(self.out, ")")?;
2884        }
2885
2886        Ok(())
2887    }
2888
2889    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2890    fn put_wrapped_expression_for_packed_vec3_access<F>(
2891        &mut self,
2892        expr_handle: Handle<crate::Expression>,
2893        context: &ExpressionContext,
2894        is_scoped: bool,
2895        put_expression: &F,
2896    ) -> BackendResult
2897    where
2898        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2899    {
2900        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2901            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2902            put_expression(self, expr_handle, context, is_scoped)?;
2903            write!(self.out, ")")?;
2904        } else {
2905            put_expression(self, expr_handle, context, is_scoped)?;
2906        }
2907        Ok(())
2908    }
2909
2910    /// Emits code for an expression using the provided callback, wrapping the
2911    /// result in a bitcast to the type `cast_to`.
2912    fn put_bitcasted_expression<F>(
2913        &mut self,
2914        cast_to: &crate::TypeInner,
2915        context: &ExpressionContext,
2916        put_expression: &F,
2917    ) -> BackendResult
2918    where
2919        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2920    {
2921        write!(self.out, "as_type<")?;
2922        match *cast_to {
2923            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2924            crate::TypeInner::Vector { size, scalar } => {
2925                put_numeric_type(&mut self.out, scalar, &[size])?
2926            }
2927            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2928        };
2929        write!(self.out, ">(")?;
2930        put_expression(self, context, true)?;
2931        write!(self.out, ")")?;
2932
2933        Ok(())
2934    }
2935
2936    /// Write a `GuardedIndex` as a Metal expression.
2937    fn put_index(
2938        &mut self,
2939        index: index::GuardedIndex,
2940        context: &ExpressionContext,
2941        is_scoped: bool,
2942    ) -> BackendResult {
2943        match index {
2944            index::GuardedIndex::Expression(expr) => {
2945                self.put_expression(expr, context, is_scoped)?
2946            }
2947            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2948        }
2949        Ok(())
2950    }
2951
2952    /// Emit an index bounds check condition for `chain`, if required.
2953    ///
2954    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
2955    /// operating either on a pointer to a value, or on a value directly. If we cannot
2956    /// statically determine that all indexing operations in `chain` are within
2957    /// bounds, then write a conditional expression to check them dynamically,
2958    /// and return true. All accesses in the chain are checked by the generated
2959    /// expression.
2960    ///
2961    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
2962    ///
2963    /// The text written is of the form:
2964    ///
2965    /// ```ignore
2966    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
2967    /// ```
2968    ///
2969    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
2970    /// statements, presumably these arguments start an indented `if` statement; for
2971    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
2972    /// expression. In either case, what is written is not a complete syntactic structure
2973    /// in its own right, and the caller will have to finish it off if we return `true`.
2974    ///
2975    /// If no expression is written, return false.
2976    ///
2977    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
2978    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
2979    /// [`Store`]: crate::Statement::Store
2980    /// [`Load`]: crate::Expression::Load
2981    #[allow(unused_variables)]
2982    fn put_bounds_checks(
2983        &mut self,
2984        chain: Handle<crate::Expression>,
2985        context: &ExpressionContext,
2986        level: back::Level,
2987        prefix: &'static str,
2988    ) -> Result<bool, Error> {
2989        let mut check_written = false;
2990
2991        // Iterate over the access chain, handling each required bounds check.
2992        for item in context.bounds_check_iter(chain) {
2993            let BoundsCheck {
2994                base,
2995                index,
2996                length,
2997            } = item;
2998
2999            if check_written {
3000                write!(self.out, " && ")?;
3001            } else {
3002                write!(self.out, "{level}{prefix}")?;
3003                check_written = true;
3004            }
3005
3006            // Check that the index falls within bounds. Do this with a single
3007            // comparison, by casting the index to `uint` first, so that negative
3008            // indices become large positive values.
3009            write!(self.out, "uint(")?;
3010            self.put_index(index, context, true)?;
3011            self.out.write_str(") < ")?;
3012            match length {
3013                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3014                index::IndexableLength::Dynamic => {
3015                    let global = context.function.originating_global(base).ok_or_else(|| {
3016                        Error::GenericValidation("Could not find originating global".into())
3017                    })?;
3018                    write!(self.out, "1 + ")?;
3019                    self.put_dynamic_array_max_index(global, context)?
3020                }
3021            }
3022        }
3023
3024        Ok(check_written)
3025    }
3026
3027    /// Write the access chain `chain`.
3028    ///
3029    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3030    /// operating either on a pointer to a value, or on a value directly.
3031    ///
3032    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3033    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3034    /// that must be handled in the caller.
3035    ///
3036    /// Handle the entire chain, recursing back into `put_expression` only for index
3037    /// expressions and the base expression that originates the pointer or composite value
3038    /// being accessed. This allows `put_expression` to assume that any `Access` or
3039    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3040    /// `ReadZeroSkipWrite` checks.
3041    ///
3042    /// [`Access`]: crate::Expression::Access
3043    /// [`AccessIndex`]: crate::Expression::AccessIndex
3044    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3045    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3046    fn put_access_chain(
3047        &mut self,
3048        chain: Handle<crate::Expression>,
3049        policy: index::BoundsCheckPolicy,
3050        context: &ExpressionContext,
3051    ) -> BackendResult {
3052        match context.function.expressions[chain] {
3053            crate::Expression::Access { base, index } => {
3054                let mut base_ty = context.resolve_type(base);
3055
3056                // Look through any pointers to see what we're really indexing.
3057                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3058                    base_ty = &context.module.types[base].inner;
3059                }
3060
3061                self.put_subscripted_access_chain(
3062                    base,
3063                    base_ty,
3064                    index::GuardedIndex::Expression(index),
3065                    policy,
3066                    context,
3067                )?;
3068            }
3069            crate::Expression::AccessIndex { base, index } => {
3070                let base_resolution = &context.info[base].ty;
3071                let mut base_ty = base_resolution.inner_with(&context.module.types);
3072                let mut base_ty_handle = base_resolution.handle();
3073
3074                // Look through any pointers to see what we're really indexing.
3075                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3076                    base_ty = &context.module.types[base].inner;
3077                    base_ty_handle = Some(base);
3078                }
3079
3080                // Handle structs and anything else that can use `.x` syntax here, so
3081                // `put_subscripted_access_chain` won't have to handle the absurd case of
3082                // indexing a struct with an expression.
3083                match *base_ty {
3084                    crate::TypeInner::Struct { .. } => {
3085                        let base_ty = base_ty_handle.unwrap();
3086                        self.put_access_chain(base, policy, context)?;
3087                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3088                        write!(self.out, ".{name}")?;
3089                    }
3090                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3091                        self.put_access_chain(base, policy, context)?;
3092                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3093                        // however array indexing is
3094                        if context.get_packed_vec_kind(base).is_some() {
3095                            write!(self.out, "[{index}]")?;
3096                        } else {
3097                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3098                        }
3099                    }
3100                    _ => {
3101                        self.put_subscripted_access_chain(
3102                            base,
3103                            base_ty,
3104                            index::GuardedIndex::Known(index),
3105                            policy,
3106                            context,
3107                        )?;
3108                    }
3109                }
3110            }
3111            _ => self.put_expression(chain, context, false)?,
3112        }
3113
3114        Ok(())
3115    }
3116
3117    /// Write a `[]`-style access of `base` by `index`.
3118    ///
3119    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3120    /// values within bounds.
3121    ///
3122    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3123    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3124    /// removed. Our callers often already have this handy.
3125    ///
3126    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3127    /// referencing vector components by name.
3128    ///
3129    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3130    /// [`Array`]: crate::TypeInner::Array
3131    /// [`Vector`]: crate::TypeInner::Vector
3132    /// [`Pointer`]: crate::TypeInner::Pointer
3133    fn put_subscripted_access_chain(
3134        &mut self,
3135        base: Handle<crate::Expression>,
3136        base_ty: &crate::TypeInner,
3137        index: index::GuardedIndex,
3138        policy: index::BoundsCheckPolicy,
3139        context: &ExpressionContext,
3140    ) -> BackendResult {
3141        let accessing_wrapped_array = match *base_ty {
3142            crate::TypeInner::Array {
3143                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3144                ..
3145            } => true,
3146            _ => false,
3147        };
3148        let accessing_wrapped_binding_array =
3149            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3150
3151        self.put_access_chain(base, policy, context)?;
3152        if accessing_wrapped_array {
3153            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3154        }
3155        write!(self.out, "[")?;
3156
3157        // Decide whether this index needs to be clamped to fall within range.
3158        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3159            context.access_needs_check(base, index)
3160        } else {
3161            None
3162        };
3163        if let Some(limit) = restriction_needed {
3164            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3165            self.put_index(index, context, true)?;
3166            write!(self.out, "), ")?;
3167            match limit {
3168                index::IndexableLength::Known(limit) => {
3169                    write!(self.out, "{}u", limit - 1)?;
3170                }
3171                index::IndexableLength::Dynamic => {
3172                    let global = context.function.originating_global(base).ok_or_else(|| {
3173                        Error::GenericValidation("Could not find originating global".into())
3174                    })?;
3175                    self.put_dynamic_array_max_index(global, context)?;
3176                }
3177            }
3178            write!(self.out, ")")?;
3179        } else {
3180            self.put_index(index, context, true)?;
3181        }
3182
3183        write!(self.out, "]")?;
3184
3185        if accessing_wrapped_binding_array {
3186            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3187        }
3188
3189        Ok(())
3190    }
3191
3192    fn put_load(
3193        &mut self,
3194        pointer: Handle<crate::Expression>,
3195        context: &ExpressionContext,
3196        is_scoped: bool,
3197    ) -> BackendResult {
3198        // Since access chains never cross between address spaces, we can just
3199        // check the index bounds check policy once at the top.
3200        let policy = context.choose_bounds_check_policy(pointer);
3201        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3202            && self.put_bounds_checks(
3203                pointer,
3204                context,
3205                back::Level(0),
3206                if is_scoped { "" } else { "(" },
3207            )?
3208        {
3209            write!(self.out, " ? ")?;
3210            self.put_unchecked_load(pointer, policy, context)?;
3211            write!(self.out, " : DefaultConstructible()")?;
3212
3213            if !is_scoped {
3214                write!(self.out, ")")?;
3215            }
3216        } else {
3217            self.put_unchecked_load(pointer, policy, context)?;
3218        }
3219
3220        Ok(())
3221    }
3222
3223    fn put_unchecked_load(
3224        &mut self,
3225        pointer: Handle<crate::Expression>,
3226        policy: index::BoundsCheckPolicy,
3227        context: &ExpressionContext,
3228    ) -> BackendResult {
3229        let is_atomic_pointer = context
3230            .resolve_type(pointer)
3231            .is_atomic_pointer(&context.module.types);
3232
3233        if is_atomic_pointer {
3234            write!(
3235                self.out,
3236                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3237            )?;
3238            self.put_access_chain(pointer, policy, context)?;
3239            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3240        } else {
3241            // We don't do any dereferencing with `*` here as pointer arguments to functions
3242            // are done by `&` references and not `*` pointers. These do not need to be
3243            // dereferenced.
3244            self.put_access_chain(pointer, policy, context)?;
3245        }
3246
3247        Ok(())
3248    }
3249
3250    fn put_return_value(
3251        &mut self,
3252        level: back::Level,
3253        expr_handle: Handle<crate::Expression>,
3254        result_struct: Option<&str>,
3255        context: &ExpressionContext,
3256    ) -> BackendResult {
3257        match result_struct {
3258            Some(struct_name) => {
3259                let mut has_point_size = false;
3260                let result_ty = context.function.result.as_ref().unwrap().ty;
3261                match context.module.types[result_ty].inner {
3262                    crate::TypeInner::Struct { ref members, .. } => {
3263                        let tmp = "_tmp";
3264                        write!(self.out, "{level}const auto {tmp} = ")?;
3265                        self.put_expression(expr_handle, context, true)?;
3266                        writeln!(self.out, ";")?;
3267                        write!(self.out, "{level}return {struct_name} {{")?;
3268
3269                        let mut is_first = true;
3270
3271                        for (index, member) in members.iter().enumerate() {
3272                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3273                                member.binding
3274                            {
3275                                has_point_size = true;
3276                                if !context.pipeline_options.allow_and_force_point_size {
3277                                    continue;
3278                                }
3279                            }
3280
3281                            let comma = if is_first { "" } else { "," };
3282                            is_first = false;
3283                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3284                            // HACK: we are forcefully deduplicating the expression here
3285                            // to convert from a wrapped struct to a raw array, e.g.
3286                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3287                            if let crate::TypeInner::Array {
3288                                size: crate::ArraySize::Constant(size),
3289                                ..
3290                            } = context.module.types[member.ty].inner
3291                            {
3292                                write!(self.out, "{comma} {{")?;
3293                                for j in 0..size.get() {
3294                                    if j != 0 {
3295                                        write!(self.out, ",")?;
3296                                    }
3297                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3298                                }
3299                                write!(self.out, "}}")?;
3300                            } else {
3301                                write!(self.out, "{comma} {tmp}.{name}")?;
3302                            }
3303                        }
3304                    }
3305                    _ => {
3306                        write!(self.out, "{level}return {struct_name} {{ ")?;
3307                        self.put_expression(expr_handle, context, true)?;
3308                    }
3309                }
3310
3311                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3312                    let stage = context.module.entry_points[ep_index as usize].stage;
3313                    if context.pipeline_options.allow_and_force_point_size
3314                        && stage == crate::ShaderStage::Vertex
3315                        && !has_point_size
3316                    {
3317                        // point size was injected and comes last
3318                        write!(self.out, ", 1.0")?;
3319                    }
3320                }
3321                write!(self.out, " }}")?;
3322            }
3323            None => {
3324                write!(self.out, "{level}return ")?;
3325                self.put_expression(expr_handle, context, true)?;
3326            }
3327        }
3328        writeln!(self.out, ";")?;
3329        Ok(())
3330    }
3331
3332    /// Helper method used to find which expressions of a given function require baking
3333    ///
3334    /// # Notes
3335    /// This function overwrites the contents of `self.need_bake_expressions`
3336    fn update_expressions_to_bake(
3337        &mut self,
3338        func: &crate::Function,
3339        info: &valid::FunctionInfo,
3340        context: &ExpressionContext,
3341    ) {
3342        use crate::Expression;
3343        self.need_bake_expressions.clear();
3344
3345        for (expr_handle, expr) in func.expressions.iter() {
3346            // Expressions whose reference count is above the
3347            // threshold should always be stored in temporaries.
3348            let expr_info = &info[expr_handle];
3349            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3350            if min_ref_count <= expr_info.ref_count {
3351                self.need_bake_expressions.insert(expr_handle);
3352            } else {
3353                match expr_info.ty {
3354                    // force ray desc to be baked: it's used multiple times internally
3355                    TypeResolution::Handle(h)
3356                        if Some(h) == context.module.special_types.ray_desc =>
3357                    {
3358                        self.need_bake_expressions.insert(expr_handle);
3359                    }
3360                    _ => {}
3361                }
3362            }
3363
3364            if let Expression::Math {
3365                fun,
3366                arg,
3367                arg1,
3368                arg2,
3369                ..
3370            } = *expr
3371            {
3372                match fun {
3373                    crate::MathFunction::Dot => {
3374                        // WGSL's `dot` function works on any `vecN` type, but Metal's only
3375                        // works on floating-point vectors, so we emit inline code for
3376                        // integer vector `dot` calls. But that code uses each argument `N`
3377                        // times, once for each component (see `put_dot_product`), so to
3378                        // avoid duplicated evaluation, we must bake integer operands.
3379
3380                        // check what kind of product this is depending
3381                        // on the resolve type of the Dot function itself
3382                        let inner = context.resolve_type(expr_handle);
3383                        if let crate::TypeInner::Scalar(scalar) = *inner {
3384                            match scalar.kind {
3385                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3386                                    self.need_bake_expressions.insert(arg);
3387                                    self.need_bake_expressions.insert(arg1.unwrap());
3388                                }
3389                                _ => {}
3390                            }
3391                        }
3392                    }
3393                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3394                        self.need_bake_expressions.insert(arg);
3395                        self.need_bake_expressions.insert(arg1.unwrap());
3396                    }
3397                    crate::MathFunction::FirstLeadingBit => {
3398                        self.need_bake_expressions.insert(arg);
3399                    }
3400                    crate::MathFunction::Pack4xI8
3401                    | crate::MathFunction::Pack4xU8
3402                    | crate::MathFunction::Pack4xI8Clamp
3403                    | crate::MathFunction::Pack4xU8Clamp
3404                    | crate::MathFunction::Unpack4xI8
3405                    | crate::MathFunction::Unpack4xU8 => {
3406                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3407                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3408                        if context.lang_version < (2, 1) {
3409                            self.need_bake_expressions.insert(arg);
3410                        }
3411                    }
3412                    crate::MathFunction::ExtractBits => {
3413                        // Only argument 1 is re-used.
3414                        self.need_bake_expressions.insert(arg1.unwrap());
3415                    }
3416                    crate::MathFunction::InsertBits => {
3417                        // Only argument 2 is re-used.
3418                        self.need_bake_expressions.insert(arg2.unwrap());
3419                    }
3420                    crate::MathFunction::Sign => {
3421                        // WGSL's `sign` function works also on signed ints, but Metal's only
3422                        // works on floating points, so we emit inline code for integer `sign`
3423                        // calls. But that code uses each argument 2 times (see `put_isign`),
3424                        // so to avoid duplicated evaluation, we must bake the argument.
3425                        let inner = context.resolve_type(expr_handle);
3426                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3427                            self.need_bake_expressions.insert(arg);
3428                        }
3429                    }
3430                    _ => {}
3431                }
3432            }
3433        }
3434    }
3435
3436    fn start_baking_expression(
3437        &mut self,
3438        handle: Handle<crate::Expression>,
3439        context: &ExpressionContext,
3440        name: &str,
3441    ) -> BackendResult {
3442        match context.info[handle].ty {
3443            TypeResolution::Handle(ty_handle) => {
3444                let ty_name = TypeContext {
3445                    handle: ty_handle,
3446                    gctx: context.module.to_ctx(),
3447                    names: &self.names,
3448                    access: crate::StorageAccess::empty(),
3449                    first_time: false,
3450                };
3451                write!(self.out, "{ty_name}")?;
3452            }
3453            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3454                put_numeric_type(&mut self.out, scalar, &[])?;
3455            }
3456            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3457                put_numeric_type(&mut self.out, scalar, &[size])?;
3458            }
3459            TypeResolution::Value(crate::TypeInner::Matrix {
3460                columns,
3461                rows,
3462                scalar,
3463            }) => {
3464                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3465            }
3466            TypeResolution::Value(ref other) => {
3467                log::warn!("Type {other:?} isn't a known local"); //TEMP!
3468                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3469            }
3470        }
3471
3472        //TODO: figure out the naming scheme that wouldn't collide with user names.
3473        write!(self.out, " {name} = ")?;
3474
3475        Ok(())
3476    }
3477
3478    /// Cache a clamped level of detail value, if necessary.
3479    ///
3480    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3481    /// properly clamped level of detail value both in the access itself, and
3482    /// for fetching the size of the requested MIP level, needed to clamp the
3483    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3484    /// it in a temporary variable, as part of the [`Emit`] statement covering
3485    /// the [`ImageLoad`] expression.
3486    ///
3487    /// [`ImageLoad`]: crate::Expression::ImageLoad
3488    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3489    /// [`Emit`]: crate::Statement::Emit
3490    fn put_cache_restricted_level(
3491        &mut self,
3492        load: Handle<crate::Expression>,
3493        image: Handle<crate::Expression>,
3494        mip_level: Option<Handle<crate::Expression>>,
3495        indent: back::Level,
3496        context: &StatementContext,
3497    ) -> BackendResult {
3498        // Does this image access actually require (or even permit) a
3499        // level-of-detail, and does the policy require us to restrict it?
3500        let level_of_detail = match mip_level {
3501            Some(level) => level,
3502            None => return Ok(()),
3503        };
3504
3505        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3506            || !context.expression.image_needs_lod(image)
3507        {
3508            return Ok(());
3509        }
3510
3511        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3512        self.put_restricted_scalar_image_index(
3513            image,
3514            level_of_detail,
3515            "get_num_mip_levels",
3516            &context.expression,
3517        )?;
3518        writeln!(self.out, ";")?;
3519
3520        Ok(())
3521    }
3522
3523    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3524    ///
3525    /// Caches the results in temporary variables (whose names are derived from
3526    /// the original variable names). This caching avoids the need to redo the
3527    /// casting for each vector component when emitting the dot product.
3528    fn put_casting_to_packed_chars(
3529        &mut self,
3530        fun: crate::MathFunction,
3531        arg0: Handle<crate::Expression>,
3532        arg1: Handle<crate::Expression>,
3533        indent: back::Level,
3534        context: &StatementContext<'_>,
3535    ) -> Result<(), Error> {
3536        let packed_type = match fun {
3537            crate::MathFunction::Dot4I8Packed => "packed_char4",
3538            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3539            _ => unreachable!(),
3540        };
3541
3542        for arg in [arg0, arg1] {
3543            write!(
3544                self.out,
3545                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3546                Reinterpreted::new(packed_type, arg)
3547            )?;
3548            self.put_expression(arg, &context.expression, true)?;
3549            writeln!(self.out, ");")?;
3550        }
3551
3552        Ok(())
3553    }
3554
3555    fn put_block(
3556        &mut self,
3557        level: back::Level,
3558        statements: &[crate::Statement],
3559        context: &StatementContext,
3560    ) -> BackendResult {
3561        // Add to the set in order to track the stack size.
3562        #[cfg(test)]
3563        self.put_block_stack_pointers
3564            .insert(ptr::from_ref(&level).cast());
3565
3566        for statement in statements {
3567            log::trace!("statement[{}] {:?}", level.0, statement);
3568            match *statement {
3569                crate::Statement::Emit(ref range) => {
3570                    for handle in range.clone() {
3571                        use crate::MathFunction as Mf;
3572
3573                        match context.expression.function.expressions[handle] {
3574                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3575                            // may need to cache a clamped version of their level-of-detail argument.
3576                            crate::Expression::ImageLoad {
3577                                image,
3578                                level: mip_level,
3579                                ..
3580                            } => {
3581                                self.put_cache_restricted_level(
3582                                    handle, image, mip_level, level, context,
3583                                )?;
3584                            }
3585
3586                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3587                            // 2.1+ then we introduce two intermediate variables that recast the two
3588                            // arguments as packed (signed or unsigned) chars. The actual dot product
3589                            // is implemented in `Self::put_expression`, and it uses both of these
3590                            // intermediate variables multiple times. There's no danger that the
3591                            // original arguments get modified between the definition of these
3592                            // intermediate variables and the implementation of the actual dot
3593                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3594                            crate::Expression::Math {
3595                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3596                                arg,
3597                                arg1,
3598                                ..
3599                            } if context.expression.lang_version >= (2, 1) => {
3600                                self.put_casting_to_packed_chars(
3601                                    fun,
3602                                    arg,
3603                                    arg1.unwrap(),
3604                                    level,
3605                                    context,
3606                                )?;
3607                            }
3608
3609                            _ => (),
3610                        }
3611
3612                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3613                        let expr_name = if ptr_class.is_some() {
3614                            None // don't bake pointer expressions (just yet)
3615                        } else if let Some(name) =
3616                            context.expression.function.named_expressions.get(&handle)
3617                        {
3618                            // The `crate::Function::named_expressions` table holds
3619                            // expressions that should be saved in temporaries once they
3620                            // are `Emit`ted. We only add them to `self.named_expressions`
3621                            // when we reach the `Emit` that covers them, so that we don't
3622                            // try to use their names before we've actually initialized
3623                            // the temporary that holds them.
3624                            //
3625                            // Don't assume the names in `named_expressions` are unique,
3626                            // or even valid. Use the `Namer`.
3627                            Some(self.namer.call(name))
3628                        } else {
3629                            // If this expression is an index that we're going to first compare
3630                            // against a limit, and then actually use as an index, then we may
3631                            // want to cache it in a temporary, to avoid evaluating it twice.
3632                            let bake = if context.expression.guarded_indices.contains(handle) {
3633                                true
3634                            } else {
3635                                self.need_bake_expressions.contains(&handle)
3636                            };
3637
3638                            if bake {
3639                                Some(Baked(handle).to_string())
3640                            } else {
3641                                None
3642                            }
3643                        };
3644
3645                        if let Some(name) = expr_name {
3646                            write!(self.out, "{level}")?;
3647                            self.start_baking_expression(handle, &context.expression, &name)?;
3648                            self.put_expression(handle, &context.expression, true)?;
3649                            self.named_expressions.insert(handle, name);
3650                            writeln!(self.out, ";")?;
3651                        }
3652                    }
3653                }
3654                crate::Statement::Block(ref block) => {
3655                    if !block.is_empty() {
3656                        writeln!(self.out, "{level}{{")?;
3657                        self.put_block(level.next(), block, context)?;
3658                        writeln!(self.out, "{level}}}")?;
3659                    }
3660                }
3661                crate::Statement::If {
3662                    condition,
3663                    ref accept,
3664                    ref reject,
3665                } => {
3666                    write!(self.out, "{level}if (")?;
3667                    self.put_expression(condition, &context.expression, true)?;
3668                    writeln!(self.out, ") {{")?;
3669                    self.put_block(level.next(), accept, context)?;
3670                    if !reject.is_empty() {
3671                        writeln!(self.out, "{level}}} else {{")?;
3672                        self.put_block(level.next(), reject, context)?;
3673                    }
3674                    writeln!(self.out, "{level}}}")?;
3675                }
3676                crate::Statement::Switch {
3677                    selector,
3678                    ref cases,
3679                } => {
3680                    write!(self.out, "{level}switch(")?;
3681                    self.put_expression(selector, &context.expression, true)?;
3682                    writeln!(self.out, ") {{")?;
3683                    let lcase = level.next();
3684                    for case in cases.iter() {
3685                        match case.value {
3686                            crate::SwitchValue::I32(value) => {
3687                                write!(self.out, "{lcase}case {value}:")?;
3688                            }
3689                            crate::SwitchValue::U32(value) => {
3690                                write!(self.out, "{lcase}case {value}u:")?;
3691                            }
3692                            crate::SwitchValue::Default => {
3693                                write!(self.out, "{lcase}default:")?;
3694                            }
3695                        }
3696
3697                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3698                        if write_block_braces {
3699                            writeln!(self.out, " {{")?;
3700                        } else {
3701                            writeln!(self.out)?;
3702                        }
3703
3704                        self.put_block(lcase.next(), &case.body, context)?;
3705                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3706                        {
3707                            writeln!(self.out, "{}break;", lcase.next())?;
3708                        }
3709
3710                        if write_block_braces {
3711                            writeln!(self.out, "{lcase}}}")?;
3712                        }
3713                    }
3714                    writeln!(self.out, "{level}}}")?;
3715                }
3716                crate::Statement::Loop {
3717                    ref body,
3718                    ref continuing,
3719                    break_if,
3720                } => {
3721                    let force_loop_bound_statements =
3722                        self.gen_force_bounded_loop_statements(level, context);
3723                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3724                        .then(|| self.namer.call("loop_init"));
3725
3726                    if let Some((ref decl, _)) = force_loop_bound_statements {
3727                        writeln!(self.out, "{decl}")?;
3728                    }
3729                    if let Some(ref gate_name) = gate_name {
3730                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3731                    }
3732
3733                    writeln!(self.out, "{level}while(true) {{",)?;
3734                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3735                        writeln!(self.out, "{break_and_inc}")?;
3736                    }
3737                    if let Some(ref gate_name) = gate_name {
3738                        let lif = level.next();
3739                        let lcontinuing = lif.next();
3740                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3741                        self.put_block(lcontinuing, continuing, context)?;
3742                        if let Some(condition) = break_if {
3743                            write!(self.out, "{lcontinuing}if (")?;
3744                            self.put_expression(condition, &context.expression, true)?;
3745                            writeln!(self.out, ") {{")?;
3746                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3747                            writeln!(self.out, "{lcontinuing}}}")?;
3748                        }
3749                        writeln!(self.out, "{lif}}}")?;
3750                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3751                    }
3752                    self.put_block(level.next(), body, context)?;
3753
3754                    writeln!(self.out, "{level}}}")?;
3755                }
3756                crate::Statement::Break => {
3757                    writeln!(self.out, "{level}break;")?;
3758                }
3759                crate::Statement::Continue => {
3760                    writeln!(self.out, "{level}continue;")?;
3761                }
3762                crate::Statement::Return {
3763                    value: Some(expr_handle),
3764                } => {
3765                    self.put_return_value(
3766                        level,
3767                        expr_handle,
3768                        context.result_struct,
3769                        &context.expression,
3770                    )?;
3771                }
3772                crate::Statement::Return { value: None } => {
3773                    writeln!(self.out, "{level}return;")?;
3774                }
3775                crate::Statement::Kill => {
3776                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3777                }
3778                crate::Statement::ControlBarrier(flags)
3779                | crate::Statement::MemoryBarrier(flags) => {
3780                    self.write_barrier(flags, level)?;
3781                }
3782                crate::Statement::Store { pointer, value } => {
3783                    self.put_store(pointer, value, level, context)?
3784                }
3785                crate::Statement::ImageStore {
3786                    image,
3787                    coordinate,
3788                    array_index,
3789                    value,
3790                } => {
3791                    let address = TexelAddress {
3792                        coordinate,
3793                        array_index,
3794                        sample: None,
3795                        level: None,
3796                    };
3797                    self.put_image_store(level, image, &address, value, context)?
3798                }
3799                crate::Statement::Call {
3800                    function,
3801                    ref arguments,
3802                    result,
3803                } => {
3804                    write!(self.out, "{level}")?;
3805                    if let Some(expr) = result {
3806                        let name = Baked(expr).to_string();
3807                        self.start_baking_expression(expr, &context.expression, &name)?;
3808                        self.named_expressions.insert(expr, name);
3809                    }
3810                    let fun_name = &self.names[&NameKey::Function(function)];
3811                    write!(self.out, "{fun_name}(")?;
3812                    // first, write down the actual arguments
3813                    for (i, &handle) in arguments.iter().enumerate() {
3814                        if i != 0 {
3815                            write!(self.out, ", ")?;
3816                        }
3817                        self.put_expression(handle, &context.expression, true)?;
3818                    }
3819                    // follow-up with any global resources used
3820                    let mut separate = !arguments.is_empty();
3821                    let fun_info = &context.expression.mod_info[function];
3822                    let mut needs_buffer_sizes = false;
3823                    for (handle, var) in context.expression.module.global_variables.iter() {
3824                        if fun_info[handle].is_empty() {
3825                            continue;
3826                        }
3827                        if var.space.needs_pass_through() {
3828                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3829                            if separate {
3830                                write!(self.out, ", ")?;
3831                            } else {
3832                                separate = true;
3833                            }
3834                            write!(self.out, "{name}")?;
3835                        }
3836                        needs_buffer_sizes |=
3837                            needs_array_length(var.ty, &context.expression.module.types);
3838                    }
3839                    if needs_buffer_sizes {
3840                        if separate {
3841                            write!(self.out, ", ")?;
3842                        }
3843                        write!(self.out, "_buffer_sizes")?;
3844                    }
3845
3846                    // done
3847                    writeln!(self.out, ");")?;
3848                }
3849                crate::Statement::Atomic {
3850                    pointer,
3851                    ref fun,
3852                    value,
3853                    result,
3854                } => {
3855                    let context = &context.expression;
3856
3857                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
3858                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
3859                    // `Some`, we are not operating on a 64-bit value, and that if we are
3860                    // operating on a 64-bit value, `result` is `None`.
3861                    write!(self.out, "{level}")?;
3862                    let fun_key = if let Some(result) = result {
3863                        let res_name = Baked(result).to_string();
3864                        self.start_baking_expression(result, context, &res_name)?;
3865                        self.named_expressions.insert(result, res_name);
3866                        fun.to_msl()
3867                    } else if context.resolve_type(value).scalar_width() == Some(8) {
3868                        fun.to_msl_64_bit()?
3869                    } else {
3870                        fun.to_msl()
3871                    };
3872
3873                    // If the pointer we're passing to the atomic operation needs to be conditional
3874                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
3875                    // the pointer operand should be unchecked.
3876                    let policy = context.choose_bounds_check_policy(pointer);
3877                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3878                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3879
3880                    // If requested and successfully put bounds checks, continue the ternary expression.
3881                    if checked {
3882                        write!(self.out, " ? ")?;
3883                    }
3884
3885                    // Put the atomic function invocation.
3886                    match *fun {
3887                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3888                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3889                            self.put_access_chain(pointer, policy, context)?;
3890                            write!(self.out, ", ")?;
3891                            self.put_expression(cmp, context, true)?;
3892                            write!(self.out, ", ")?;
3893                            self.put_expression(value, context, true)?;
3894                            write!(self.out, ")")?;
3895                        }
3896                        _ => {
3897                            write!(
3898                                self.out,
3899                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3900                            )?;
3901                            self.put_access_chain(pointer, policy, context)?;
3902                            write!(self.out, ", ")?;
3903                            self.put_expression(value, context, true)?;
3904                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3905                        }
3906                    }
3907
3908                    // Finish the ternary expression.
3909                    if checked {
3910                        write!(self.out, " : DefaultConstructible()")?;
3911                    }
3912
3913                    // Done
3914                    writeln!(self.out, ";")?;
3915                }
3916                crate::Statement::ImageAtomic {
3917                    image,
3918                    coordinate,
3919                    array_index,
3920                    fun,
3921                    value,
3922                } => {
3923                    let address = TexelAddress {
3924                        coordinate,
3925                        array_index,
3926                        sample: None,
3927                        level: None,
3928                    };
3929                    self.put_image_atomic(level, image, &address, fun, value, context)?
3930                }
3931                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3932                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3933
3934                    write!(self.out, "{level}")?;
3935                    let name = self.namer.call("");
3936                    self.start_baking_expression(result, &context.expression, &name)?;
3937                    self.put_load(pointer, &context.expression, true)?;
3938                    self.named_expressions.insert(result, name);
3939
3940                    writeln!(self.out, ";")?;
3941                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3942                }
3943                crate::Statement::RayQuery { query, ref fun } => {
3944                    if context.expression.lang_version < (2, 4) {
3945                        return Err(Error::UnsupportedRayTracing);
3946                    }
3947
3948                    match *fun {
3949                        crate::RayQueryFunction::Initialize {
3950                            acceleration_structure,
3951                            descriptor,
3952                        } => {
3953                            //TODO: how to deal with winding?
3954                            write!(self.out, "{level}")?;
3955                            self.put_expression(query, &context.expression, true)?;
3956                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3957                            {
3958                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3959                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3960                                write!(self.out, "{level}")?;
3961                                self.put_expression(query, &context.expression, true)?;
3962                                write!(
3963                                    self.out,
3964                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3965                                )?;
3966                                self.put_expression(descriptor, &context.expression, true)?;
3967                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3968                                self.put_expression(descriptor, &context.expression, true)?;
3969                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3970                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3971                            }
3972                            {
3973                                let f_opaque = back::RayFlag::OPAQUE.bits();
3974                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3975                                write!(self.out, "{level}")?;
3976                                self.put_expression(query, &context.expression, true)?;
3977                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3978                                self.put_expression(descriptor, &context.expression, true)?;
3979                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3980                                self.put_expression(descriptor, &context.expression, true)?;
3981                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3982                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3983                            }
3984                            {
3985                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3986                                write!(self.out, "{level}")?;
3987                                self.put_expression(query, &context.expression, true)?;
3988                                write!(
3989                                    self.out,
3990                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3991                                )?;
3992                                self.put_expression(descriptor, &context.expression, true)?;
3993                                writeln!(self.out, ".flags & {flag}) != 0);")?;
3994                            }
3995
3996                            write!(self.out, "{level}")?;
3997                            self.put_expression(query, &context.expression, true)?;
3998                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3999                            self.put_expression(query, &context.expression, true)?;
4000                            write!(
4001                                self.out,
4002                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4003                            )?;
4004                            self.put_expression(descriptor, &context.expression, true)?;
4005                            write!(self.out, ".origin, ")?;
4006                            self.put_expression(descriptor, &context.expression, true)?;
4007                            write!(self.out, ".dir, ")?;
4008                            self.put_expression(descriptor, &context.expression, true)?;
4009                            write!(self.out, ".tmin, ")?;
4010                            self.put_expression(descriptor, &context.expression, true)?;
4011                            write!(self.out, ".tmax), ")?;
4012                            self.put_expression(acceleration_structure, &context.expression, true)?;
4013                            write!(self.out, ", ")?;
4014                            self.put_expression(descriptor, &context.expression, true)?;
4015                            write!(self.out, ".cull_mask);")?;
4016
4017                            write!(self.out, "{level}")?;
4018                            self.put_expression(query, &context.expression, true)?;
4019                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4020                        }
4021                        crate::RayQueryFunction::Proceed { result } => {
4022                            write!(self.out, "{level}")?;
4023                            let name = Baked(result).to_string();
4024                            self.start_baking_expression(result, &context.expression, &name)?;
4025                            self.named_expressions.insert(result, name);
4026                            self.put_expression(query, &context.expression, true)?;
4027                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4028                            if RAY_QUERY_MODERN_SUPPORT {
4029                                write!(self.out, "{level}")?;
4030                                self.put_expression(query, &context.expression, true)?;
4031                                writeln!(self.out, ".?.next();")?;
4032                            }
4033                        }
4034                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4035                            if RAY_QUERY_MODERN_SUPPORT {
4036                                write!(self.out, "{level}")?;
4037                                self.put_expression(query, &context.expression, true)?;
4038                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4039                                self.put_expression(hit_t, &context.expression, true)?;
4040                                writeln!(self.out, ");")?;
4041                            } else {
4042                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4043                            }
4044                        }
4045                        crate::RayQueryFunction::ConfirmIntersection => {
4046                            if RAY_QUERY_MODERN_SUPPORT {
4047                                write!(self.out, "{level}")?;
4048                                self.put_expression(query, &context.expression, true)?;
4049                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4050                            } else {
4051                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4052                            }
4053                        }
4054                        crate::RayQueryFunction::Terminate => {
4055                            if RAY_QUERY_MODERN_SUPPORT {
4056                                write!(self.out, "{level}")?;
4057                                self.put_expression(query, &context.expression, true)?;
4058                                writeln!(self.out, ".?.abort();")?;
4059                            }
4060                            write!(self.out, "{level}")?;
4061                            self.put_expression(query, &context.expression, true)?;
4062                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4063                        }
4064                    }
4065                }
4066                // TODO: write emitters for these
4067                crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => {
4068                    unimplemented!()
4069                }
4070                crate::Statement::MeshFunction(
4071                    crate::MeshFunction::SetVertex { .. }
4072                    | crate::MeshFunction::SetPrimitive { .. },
4073                ) => unimplemented!(),
4074                crate::Statement::SubgroupBallot { result, predicate } => {
4075                    write!(self.out, "{level}")?;
4076                    let name = self.namer.call("");
4077                    self.start_baking_expression(result, &context.expression, &name)?;
4078                    self.named_expressions.insert(result, name);
4079                    write!(
4080                        self.out,
4081                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4082                    )?;
4083                    if let Some(predicate) = predicate {
4084                        self.put_expression(predicate, &context.expression, true)?;
4085                    } else {
4086                        write!(self.out, "true")?;
4087                    }
4088                    writeln!(self.out, "), 0, 0, 0);")?;
4089                }
4090                crate::Statement::SubgroupCollectiveOperation {
4091                    op,
4092                    collective_op,
4093                    argument,
4094                    result,
4095                } => {
4096                    write!(self.out, "{level}")?;
4097                    let name = self.namer.call("");
4098                    self.start_baking_expression(result, &context.expression, &name)?;
4099                    self.named_expressions.insert(result, name);
4100                    match (collective_op, op) {
4101                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4102                            write!(self.out, "{NAMESPACE}::simd_all(")?
4103                        }
4104                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4105                            write!(self.out, "{NAMESPACE}::simd_any(")?
4106                        }
4107                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4108                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4109                        }
4110                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4111                            write!(self.out, "{NAMESPACE}::simd_product(")?
4112                        }
4113                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4114                            write!(self.out, "{NAMESPACE}::simd_max(")?
4115                        }
4116                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4117                            write!(self.out, "{NAMESPACE}::simd_min(")?
4118                        }
4119                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4120                            write!(self.out, "{NAMESPACE}::simd_and(")?
4121                        }
4122                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4123                            write!(self.out, "{NAMESPACE}::simd_or(")?
4124                        }
4125                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4126                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4127                        }
4128                        (
4129                            crate::CollectiveOperation::ExclusiveScan,
4130                            crate::SubgroupOperation::Add,
4131                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4132                        (
4133                            crate::CollectiveOperation::ExclusiveScan,
4134                            crate::SubgroupOperation::Mul,
4135                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4136                        (
4137                            crate::CollectiveOperation::InclusiveScan,
4138                            crate::SubgroupOperation::Add,
4139                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4140                        (
4141                            crate::CollectiveOperation::InclusiveScan,
4142                            crate::SubgroupOperation::Mul,
4143                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4144                        _ => unimplemented!(),
4145                    }
4146                    self.put_expression(argument, &context.expression, true)?;
4147                    writeln!(self.out, ");")?;
4148                }
4149                crate::Statement::SubgroupGather {
4150                    mode,
4151                    argument,
4152                    result,
4153                } => {
4154                    write!(self.out, "{level}")?;
4155                    let name = self.namer.call("");
4156                    self.start_baking_expression(result, &context.expression, &name)?;
4157                    self.named_expressions.insert(result, name);
4158                    match mode {
4159                        crate::GatherMode::BroadcastFirst => {
4160                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4161                        }
4162                        crate::GatherMode::Broadcast(_) => {
4163                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4164                        }
4165                        crate::GatherMode::Shuffle(_) => {
4166                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4167                        }
4168                        crate::GatherMode::ShuffleDown(_) => {
4169                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4170                        }
4171                        crate::GatherMode::ShuffleUp(_) => {
4172                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4173                        }
4174                        crate::GatherMode::ShuffleXor(_) => {
4175                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4176                        }
4177                        crate::GatherMode::QuadBroadcast(_) => {
4178                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4179                        }
4180                        crate::GatherMode::QuadSwap(_) => {
4181                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4182                        }
4183                    }
4184                    self.put_expression(argument, &context.expression, true)?;
4185                    match mode {
4186                        crate::GatherMode::BroadcastFirst => {}
4187                        crate::GatherMode::Broadcast(index)
4188                        | crate::GatherMode::Shuffle(index)
4189                        | crate::GatherMode::ShuffleDown(index)
4190                        | crate::GatherMode::ShuffleUp(index)
4191                        | crate::GatherMode::ShuffleXor(index)
4192                        | crate::GatherMode::QuadBroadcast(index) => {
4193                            write!(self.out, ", ")?;
4194                            self.put_expression(index, &context.expression, true)?;
4195                        }
4196                        crate::GatherMode::QuadSwap(direction) => {
4197                            write!(self.out, ", ")?;
4198                            match direction {
4199                                crate::Direction::X => {
4200                                    write!(self.out, "1u")?;
4201                                }
4202                                crate::Direction::Y => {
4203                                    write!(self.out, "2u")?;
4204                                }
4205                                crate::Direction::Diagonal => {
4206                                    write!(self.out, "3u")?;
4207                                }
4208                            }
4209                        }
4210                    }
4211                    writeln!(self.out, ");")?;
4212                }
4213            }
4214        }
4215
4216        // un-emit expressions
4217        //TODO: take care of loop/continuing?
4218        for statement in statements {
4219            if let crate::Statement::Emit(ref range) = *statement {
4220                for handle in range.clone() {
4221                    self.named_expressions.shift_remove(&handle);
4222                }
4223            }
4224        }
4225        Ok(())
4226    }
4227
4228    fn put_store(
4229        &mut self,
4230        pointer: Handle<crate::Expression>,
4231        value: Handle<crate::Expression>,
4232        level: back::Level,
4233        context: &StatementContext,
4234    ) -> BackendResult {
4235        let policy = context.expression.choose_bounds_check_policy(pointer);
4236        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4237            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4238        {
4239            writeln!(self.out, ") {{")?;
4240            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4241            writeln!(self.out, "{level}}}")?;
4242        } else {
4243            self.put_unchecked_store(pointer, value, policy, level, context)?;
4244        }
4245
4246        Ok(())
4247    }
4248
4249    fn put_unchecked_store(
4250        &mut self,
4251        pointer: Handle<crate::Expression>,
4252        value: Handle<crate::Expression>,
4253        policy: index::BoundsCheckPolicy,
4254        level: back::Level,
4255        context: &StatementContext,
4256    ) -> BackendResult {
4257        let is_atomic_pointer = context
4258            .expression
4259            .resolve_type(pointer)
4260            .is_atomic_pointer(&context.expression.module.types);
4261
4262        if is_atomic_pointer {
4263            write!(
4264                self.out,
4265                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4266            )?;
4267            self.put_access_chain(pointer, policy, &context.expression)?;
4268            write!(self.out, ", ")?;
4269            self.put_expression(value, &context.expression, true)?;
4270            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4271        } else {
4272            write!(self.out, "{level}")?;
4273            self.put_access_chain(pointer, policy, &context.expression)?;
4274            write!(self.out, " = ")?;
4275            self.put_expression(value, &context.expression, true)?;
4276            writeln!(self.out, ";")?;
4277        }
4278
4279        Ok(())
4280    }
4281
4282    pub fn write(
4283        &mut self,
4284        module: &crate::Module,
4285        info: &valid::ModuleInfo,
4286        options: &Options,
4287        pipeline_options: &PipelineOptions,
4288    ) -> Result<TranslationInfo, Error> {
4289        self.names.clear();
4290        self.namer.reset(
4291            module,
4292            &super::keywords::RESERVED_SET,
4293            proc::CaseInsensitiveKeywordSet::empty(),
4294            &[CLAMPED_LOD_LOAD_PREFIX],
4295            &mut self.names,
4296        );
4297        self.wrapped_functions.clear();
4298        self.struct_member_pads.clear();
4299
4300        writeln!(
4301            self.out,
4302            "// language: metal{}.{}",
4303            options.lang_version.0, options.lang_version.1
4304        )?;
4305        writeln!(self.out, "#include <metal_stdlib>")?;
4306        writeln!(self.out, "#include <simd/simd.h>")?;
4307        writeln!(self.out)?;
4308        // Work around Metal bug where `uint` is not available by default
4309        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4310
4311        let mut uses_ray_query = false;
4312        for (_, ty) in module.types.iter() {
4313            match ty.inner {
4314                crate::TypeInner::AccelerationStructure { .. } => {
4315                    if options.lang_version < (2, 4) {
4316                        return Err(Error::UnsupportedRayTracing);
4317                    }
4318                }
4319                crate::TypeInner::RayQuery { .. } => {
4320                    if options.lang_version < (2, 4) {
4321                        return Err(Error::UnsupportedRayTracing);
4322                    }
4323                    uses_ray_query = true;
4324                }
4325                _ => (),
4326            }
4327        }
4328
4329        if module.special_types.ray_desc.is_some()
4330            || module.special_types.ray_intersection.is_some()
4331        {
4332            if options.lang_version < (2, 4) {
4333                return Err(Error::UnsupportedRayTracing);
4334            }
4335        }
4336
4337        if uses_ray_query {
4338            self.put_ray_query_type()?;
4339        }
4340
4341        if options
4342            .bounds_check_policies
4343            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4344        {
4345            self.put_default_constructible()?;
4346        }
4347        writeln!(self.out)?;
4348
4349        {
4350            // Make a `Vec` of all the `GlobalVariable`s that contain
4351            // runtime-sized arrays.
4352            let globals: Vec<Handle<crate::GlobalVariable>> = module
4353                .global_variables
4354                .iter()
4355                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4356                .map(|(handle, _)| handle)
4357                .collect();
4358
4359            let mut buffer_indices = vec![];
4360            for vbm in &pipeline_options.vertex_buffer_mappings {
4361                buffer_indices.push(vbm.id);
4362            }
4363
4364            if !globals.is_empty() || !buffer_indices.is_empty() {
4365                writeln!(self.out, "struct _mslBufferSizes {{")?;
4366
4367                for global in globals {
4368                    writeln!(
4369                        self.out,
4370                        "{}uint {};",
4371                        back::INDENT,
4372                        ArraySizeMember(global)
4373                    )?;
4374                }
4375
4376                for idx in buffer_indices {
4377                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4378                }
4379
4380                writeln!(self.out, "}};")?;
4381                writeln!(self.out)?;
4382            }
4383        };
4384
4385        self.write_type_defs(module)?;
4386        self.write_global_constants(module, info)?;
4387        self.write_functions(module, info, options, pipeline_options)
4388    }
4389
4390    /// Write the definition for the `DefaultConstructible` class.
4391    ///
4392    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4393    /// produce 'zero' values for any type, including structs, arrays, and so
4394    /// on. We could do this by emitting default constructor applications, but
4395    /// that would entail printing the name of the type, which is more trouble
4396    /// than you'd think. Instead, we just construct this magic C++14 class that
4397    /// can be converted to any type that can be default constructed, using
4398    /// template parameter inference to detect which type is needed, so we don't
4399    /// have to figure out the name.
4400    ///
4401    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4402    fn put_default_constructible(&mut self) -> BackendResult {
4403        let tab = back::INDENT;
4404        writeln!(self.out, "struct DefaultConstructible {{")?;
4405        writeln!(self.out, "{tab}template<typename T>")?;
4406        writeln!(self.out, "{tab}operator T() && {{")?;
4407        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4408        writeln!(self.out, "{tab}}}")?;
4409        writeln!(self.out, "}};")?;
4410        Ok(())
4411    }
4412
4413    fn put_ray_query_type(&mut self) -> BackendResult {
4414        let tab = back::INDENT;
4415        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4416        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4417        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4418        writeln!(
4419            self.out,
4420            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4421        )?;
4422        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4423        writeln!(self.out, "}};")?;
4424        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4425        let v_triangle = back::RayIntersectionType::Triangle as u32;
4426        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4427        writeln!(
4428            self.out,
4429            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4430        )?;
4431        writeln!(
4432            self.out,
4433            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4434        )?;
4435        writeln!(self.out, "}}")?;
4436        Ok(())
4437    }
4438
4439    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4440        let mut generated_argument_buffer_wrapper = false;
4441        let mut generated_external_texture_wrapper = false;
4442        for (handle, ty) in module.types.iter() {
4443            match ty.inner {
4444                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4445                    writeln!(self.out, "template <typename T>")?;
4446                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4447                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4448                    writeln!(self.out, "}};")?;
4449                    generated_argument_buffer_wrapper = true;
4450                }
4451                crate::TypeInner::Image {
4452                    class: crate::ImageClass::External,
4453                    ..
4454                } if !generated_external_texture_wrapper => {
4455                    let params_ty_name = &self.names
4456                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4457                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4458                    writeln!(
4459                        self.out,
4460                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4461                        back::INDENT
4462                    )?;
4463                    writeln!(
4464                        self.out,
4465                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4466                        back::INDENT
4467                    )?;
4468                    writeln!(
4469                        self.out,
4470                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4471                        back::INDENT
4472                    )?;
4473                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4474                    writeln!(self.out, "}};")?;
4475                    generated_external_texture_wrapper = true;
4476                }
4477                _ => {}
4478            }
4479
4480            if !ty.needs_alias() {
4481                continue;
4482            }
4483            let name = &self.names[&NameKey::Type(handle)];
4484            match ty.inner {
4485                // Naga IR can pass around arrays by value, but Metal, following
4486                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4487                // on expressions of array type, so assigning the array by value
4488                // isn't possible. However, Metal *does* assign structs by
4489                // value. So in our Metal output, we wrap all array types in
4490                // synthetic struct types:
4491                //
4492                //     struct type1 {
4493                //         float inner[10]
4494                //     };
4495                //
4496                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4497                // any expression that actually wants access to the array.
4498                crate::TypeInner::Array {
4499                    base,
4500                    size,
4501                    stride: _,
4502                } => {
4503                    let base_name = TypeContext {
4504                        handle: base,
4505                        gctx: module.to_ctx(),
4506                        names: &self.names,
4507                        access: crate::StorageAccess::empty(),
4508                        first_time: false,
4509                    };
4510
4511                    match size.resolve(module.to_ctx())? {
4512                        proc::IndexableLength::Known(size) => {
4513                            writeln!(self.out, "struct {name} {{")?;
4514                            writeln!(
4515                                self.out,
4516                                "{}{} {}[{}];",
4517                                back::INDENT,
4518                                base_name,
4519                                WRAPPED_ARRAY_FIELD,
4520                                size
4521                            )?;
4522                            writeln!(self.out, "}};")?;
4523                        }
4524                        proc::IndexableLength::Dynamic => {
4525                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4526                        }
4527                    }
4528                }
4529                crate::TypeInner::Struct {
4530                    ref members, span, ..
4531                } => {
4532                    writeln!(self.out, "struct {name} {{")?;
4533                    let mut last_offset = 0;
4534                    for (index, member) in members.iter().enumerate() {
4535                        if member.offset > last_offset {
4536                            self.struct_member_pads.insert((handle, index as u32));
4537                            let pad = member.offset - last_offset;
4538                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4539                        }
4540                        let ty_inner = &module.types[member.ty].inner;
4541                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4542
4543                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4544
4545                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4546                        match should_pack_struct_member(members, span, index, module) {
4547                            Some(scalar) => {
4548                                writeln!(
4549                                    self.out,
4550                                    "{}{}::packed_{}3 {};",
4551                                    back::INDENT,
4552                                    NAMESPACE,
4553                                    scalar.to_msl_name(),
4554                                    member_name
4555                                )?;
4556                            }
4557                            None => {
4558                                let base_name = TypeContext {
4559                                    handle: member.ty,
4560                                    gctx: module.to_ctx(),
4561                                    names: &self.names,
4562                                    access: crate::StorageAccess::empty(),
4563                                    first_time: false,
4564                                };
4565                                writeln!(
4566                                    self.out,
4567                                    "{}{} {};",
4568                                    back::INDENT,
4569                                    base_name,
4570                                    member_name
4571                                )?;
4572
4573                                // for 3-component vectors, add one component
4574                                if let crate::TypeInner::Vector {
4575                                    size: crate::VectorSize::Tri,
4576                                    scalar,
4577                                } = *ty_inner
4578                                {
4579                                    last_offset += scalar.width as u32;
4580                                }
4581                            }
4582                        }
4583                    }
4584                    if last_offset < span {
4585                        let pad = span - last_offset;
4586                        writeln!(
4587                            self.out,
4588                            "{}char _pad{}[{}];",
4589                            back::INDENT,
4590                            members.len(),
4591                            pad
4592                        )?;
4593                    }
4594                    writeln!(self.out, "}};")?;
4595                }
4596                _ => {
4597                    let ty_name = TypeContext {
4598                        handle,
4599                        gctx: module.to_ctx(),
4600                        names: &self.names,
4601                        access: crate::StorageAccess::empty(),
4602                        first_time: true,
4603                    };
4604                    writeln!(self.out, "typedef {ty_name} {name};")?;
4605                }
4606            }
4607        }
4608
4609        // Write functions to create special types.
4610        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4611            match type_key {
4612                &crate::PredeclaredType::ModfResult { size, scalar }
4613                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4614                    let arg_type_name_owner;
4615                    let arg_type_name = if let Some(size) = size {
4616                        arg_type_name_owner = format!(
4617                            "{NAMESPACE}::{}{}",
4618                            if scalar.width == 8 { "double" } else { "float" },
4619                            size as u8
4620                        );
4621                        &arg_type_name_owner
4622                    } else if scalar.width == 8 {
4623                        "double"
4624                    } else {
4625                        "float"
4626                    };
4627
4628                    let other_type_name_owner;
4629                    let (defined_func_name, called_func_name, other_type_name) =
4630                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4631                            (MODF_FUNCTION, "modf", arg_type_name)
4632                        } else {
4633                            let other_type_name = if let Some(size) = size {
4634                                other_type_name_owner = format!("int{}", size as u8);
4635                                &other_type_name_owner
4636                            } else {
4637                                "int"
4638                            };
4639                            (FREXP_FUNCTION, "frexp", other_type_name)
4640                        };
4641
4642                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4643
4644                    writeln!(self.out)?;
4645                    writeln!(
4646                        self.out,
4647                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4648    {other_type_name} other;
4649    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4650    return {struct_name}{{ fract, other }};
4651}}"
4652                    )?;
4653                }
4654                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4655                    let arg_type_name = scalar.to_msl_name();
4656                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4657                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4658                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4659
4660                    writeln!(self.out)?;
4661
4662                    for address_space_name in ["device", "threadgroup"] {
4663                        writeln!(
4664                            self.out,
4665                            "\
4666template <typename A>
4667{struct_name} {defined_func_name}(
4668    {address_space_name} A *atomic_ptr,
4669    {arg_type_name} cmp,
4670    {arg_type_name} v
4671) {{
4672    bool swapped = {NAMESPACE}::{called_func_name}(
4673        atomic_ptr, &cmp, v,
4674        metal::memory_order_relaxed, metal::memory_order_relaxed
4675    );
4676    return {struct_name}{{cmp, swapped}};
4677}}"
4678                        )?;
4679                    }
4680                }
4681            }
4682        }
4683
4684        Ok(())
4685    }
4686
4687    /// Writes all named constants
4688    fn write_global_constants(
4689        &mut self,
4690        module: &crate::Module,
4691        mod_info: &valid::ModuleInfo,
4692    ) -> BackendResult {
4693        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4694
4695        for (handle, constant) in constants {
4696            let ty_name = TypeContext {
4697                handle: constant.ty,
4698                gctx: module.to_ctx(),
4699                names: &self.names,
4700                access: crate::StorageAccess::empty(),
4701                first_time: false,
4702            };
4703            let name = &self.names[&NameKey::Constant(handle)];
4704            write!(self.out, "constant {ty_name} {name} = ")?;
4705            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4706            writeln!(self.out, ";")?;
4707        }
4708
4709        Ok(())
4710    }
4711
4712    fn put_inline_sampler_properties(
4713        &mut self,
4714        level: back::Level,
4715        sampler: &sm::InlineSampler,
4716    ) -> BackendResult {
4717        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4718            writeln!(
4719                self.out,
4720                "{}{}::{}_address::{},",
4721                level,
4722                NAMESPACE,
4723                letter,
4724                address.as_str(),
4725            )?;
4726        }
4727        writeln!(
4728            self.out,
4729            "{}{}::mag_filter::{},",
4730            level,
4731            NAMESPACE,
4732            sampler.mag_filter.as_str(),
4733        )?;
4734        writeln!(
4735            self.out,
4736            "{}{}::min_filter::{},",
4737            level,
4738            NAMESPACE,
4739            sampler.min_filter.as_str(),
4740        )?;
4741        if let Some(filter) = sampler.mip_filter {
4742            writeln!(
4743                self.out,
4744                "{}{}::mip_filter::{},",
4745                level,
4746                NAMESPACE,
4747                filter.as_str(),
4748            )?;
4749        }
4750        // avoid setting it on platforms that don't support it
4751        if sampler.border_color != sm::BorderColor::TransparentBlack {
4752            writeln!(
4753                self.out,
4754                "{}{}::border_color::{},",
4755                level,
4756                NAMESPACE,
4757                sampler.border_color.as_str(),
4758            )?;
4759        }
4760        //TODO: I'm not able to feed this in a way that MSL likes:
4761        //>error: use of undeclared identifier 'lod_clamp'
4762        //>error: no member named 'max_anisotropy' in namespace 'metal'
4763        if false {
4764            if let Some(ref lod) = sampler.lod_clamp {
4765                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4766            }
4767            if let Some(aniso) = sampler.max_anisotropy {
4768                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4769            }
4770        }
4771        if sampler.compare_func != sm::CompareFunc::Never {
4772            writeln!(
4773                self.out,
4774                "{}{}::compare_func::{},",
4775                level,
4776                NAMESPACE,
4777                sampler.compare_func.as_str(),
4778            )?;
4779        }
4780        writeln!(
4781            self.out,
4782            "{}{}::coord::{}",
4783            level,
4784            NAMESPACE,
4785            sampler.coord.as_str()
4786        )?;
4787        Ok(())
4788    }
4789
4790    fn write_unpacking_function(
4791        &mut self,
4792        format: back::msl::VertexFormat,
4793    ) -> Result<(String, u32, u32), Error> {
4794        use back::msl::VertexFormat::*;
4795        match format {
4796            Uint8 => {
4797                let name = self.namer.call("unpackUint8");
4798                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4799                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4800                writeln!(self.out, "}}")?;
4801                Ok((name, 1, 1))
4802            }
4803            Uint8x2 => {
4804                let name = self.namer.call("unpackUint8x2");
4805                writeln!(
4806                    self.out,
4807                    "metal::uint2 {name}(metal::uchar b0, \
4808                                         metal::uchar b1) {{"
4809                )?;
4810                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4811                writeln!(self.out, "}}")?;
4812                Ok((name, 2, 2))
4813            }
4814            Uint8x4 => {
4815                let name = self.namer.call("unpackUint8x4");
4816                writeln!(
4817                    self.out,
4818                    "metal::uint4 {name}(metal::uchar b0, \
4819                                         metal::uchar b1, \
4820                                         metal::uchar b2, \
4821                                         metal::uchar b3) {{"
4822                )?;
4823                writeln!(
4824                    self.out,
4825                    "{}return metal::uint4(b0, b1, b2, b3);",
4826                    back::INDENT
4827                )?;
4828                writeln!(self.out, "}}")?;
4829                Ok((name, 4, 4))
4830            }
4831            Sint8 => {
4832                let name = self.namer.call("unpackSint8");
4833                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4834                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4835                writeln!(self.out, "}}")?;
4836                Ok((name, 1, 1))
4837            }
4838            Sint8x2 => {
4839                let name = self.namer.call("unpackSint8x2");
4840                writeln!(
4841                    self.out,
4842                    "metal::int2 {name}(metal::uchar b0, \
4843                                        metal::uchar b1) {{"
4844                )?;
4845                writeln!(
4846                    self.out,
4847                    "{}return metal::int2(as_type<char>(b0), \
4848                                          as_type<char>(b1));",
4849                    back::INDENT
4850                )?;
4851                writeln!(self.out, "}}")?;
4852                Ok((name, 2, 2))
4853            }
4854            Sint8x4 => {
4855                let name = self.namer.call("unpackSint8x4");
4856                writeln!(
4857                    self.out,
4858                    "metal::int4 {name}(metal::uchar b0, \
4859                                        metal::uchar b1, \
4860                                        metal::uchar b2, \
4861                                        metal::uchar b3) {{"
4862                )?;
4863                writeln!(
4864                    self.out,
4865                    "{}return metal::int4(as_type<char>(b0), \
4866                                          as_type<char>(b1), \
4867                                          as_type<char>(b2), \
4868                                          as_type<char>(b3));",
4869                    back::INDENT
4870                )?;
4871                writeln!(self.out, "}}")?;
4872                Ok((name, 4, 4))
4873            }
4874            Unorm8 => {
4875                let name = self.namer.call("unpackUnorm8");
4876                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4877                writeln!(
4878                    self.out,
4879                    "{}return float(float(b0) / 255.0f);",
4880                    back::INDENT
4881                )?;
4882                writeln!(self.out, "}}")?;
4883                Ok((name, 1, 1))
4884            }
4885            Unorm8x2 => {
4886                let name = self.namer.call("unpackUnorm8x2");
4887                writeln!(
4888                    self.out,
4889                    "metal::float2 {name}(metal::uchar b0, \
4890                                          metal::uchar b1) {{"
4891                )?;
4892                writeln!(
4893                    self.out,
4894                    "{}return metal::float2(float(b0) / 255.0f, \
4895                                            float(b1) / 255.0f);",
4896                    back::INDENT
4897                )?;
4898                writeln!(self.out, "}}")?;
4899                Ok((name, 2, 2))
4900            }
4901            Unorm8x4 => {
4902                let name = self.namer.call("unpackUnorm8x4");
4903                writeln!(
4904                    self.out,
4905                    "metal::float4 {name}(metal::uchar b0, \
4906                                          metal::uchar b1, \
4907                                          metal::uchar b2, \
4908                                          metal::uchar b3) {{"
4909                )?;
4910                writeln!(
4911                    self.out,
4912                    "{}return metal::float4(float(b0) / 255.0f, \
4913                                            float(b1) / 255.0f, \
4914                                            float(b2) / 255.0f, \
4915                                            float(b3) / 255.0f);",
4916                    back::INDENT
4917                )?;
4918                writeln!(self.out, "}}")?;
4919                Ok((name, 4, 4))
4920            }
4921            Snorm8 => {
4922                let name = self.namer.call("unpackSnorm8");
4923                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4924                writeln!(
4925                    self.out,
4926                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4927                    back::INDENT
4928                )?;
4929                writeln!(self.out, "}}")?;
4930                Ok((name, 1, 1))
4931            }
4932            Snorm8x2 => {
4933                let name = self.namer.call("unpackSnorm8x2");
4934                writeln!(
4935                    self.out,
4936                    "metal::float2 {name}(metal::uchar b0, \
4937                                          metal::uchar b1) {{"
4938                )?;
4939                writeln!(
4940                    self.out,
4941                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4942                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4943                    back::INDENT
4944                )?;
4945                writeln!(self.out, "}}")?;
4946                Ok((name, 2, 2))
4947            }
4948            Snorm8x4 => {
4949                let name = self.namer.call("unpackSnorm8x4");
4950                writeln!(
4951                    self.out,
4952                    "metal::float4 {name}(metal::uchar b0, \
4953                                          metal::uchar b1, \
4954                                          metal::uchar b2, \
4955                                          metal::uchar b3) {{"
4956                )?;
4957                writeln!(
4958                    self.out,
4959                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4960                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4961                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4962                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4963                    back::INDENT
4964                )?;
4965                writeln!(self.out, "}}")?;
4966                Ok((name, 4, 4))
4967            }
4968            Uint16 => {
4969                let name = self.namer.call("unpackUint16");
4970                writeln!(
4971                    self.out,
4972                    "metal::uint {name}(metal::uint b0, \
4973                                        metal::uint b1) {{"
4974                )?;
4975                writeln!(
4976                    self.out,
4977                    "{}return metal::uint(b1 << 8 | b0);",
4978                    back::INDENT
4979                )?;
4980                writeln!(self.out, "}}")?;
4981                Ok((name, 2, 1))
4982            }
4983            Uint16x2 => {
4984                let name = self.namer.call("unpackUint16x2");
4985                writeln!(
4986                    self.out,
4987                    "metal::uint2 {name}(metal::uint b0, \
4988                                         metal::uint b1, \
4989                                         metal::uint b2, \
4990                                         metal::uint b3) {{"
4991                )?;
4992                writeln!(
4993                    self.out,
4994                    "{}return metal::uint2(b1 << 8 | b0, \
4995                                           b3 << 8 | b2);",
4996                    back::INDENT
4997                )?;
4998                writeln!(self.out, "}}")?;
4999                Ok((name, 4, 2))
5000            }
5001            Uint16x4 => {
5002                let name = self.namer.call("unpackUint16x4");
5003                writeln!(
5004                    self.out,
5005                    "metal::uint4 {name}(metal::uint b0, \
5006                                         metal::uint b1, \
5007                                         metal::uint b2, \
5008                                         metal::uint b3, \
5009                                         metal::uint b4, \
5010                                         metal::uint b5, \
5011                                         metal::uint b6, \
5012                                         metal::uint b7) {{"
5013                )?;
5014                writeln!(
5015                    self.out,
5016                    "{}return metal::uint4(b1 << 8 | b0, \
5017                                           b3 << 8 | b2, \
5018                                           b5 << 8 | b4, \
5019                                           b7 << 8 | b6);",
5020                    back::INDENT
5021                )?;
5022                writeln!(self.out, "}}")?;
5023                Ok((name, 8, 4))
5024            }
5025            Sint16 => {
5026                let name = self.namer.call("unpackSint16");
5027                writeln!(
5028                    self.out,
5029                    "int {name}(metal::ushort b0, \
5030                                metal::ushort b1) {{"
5031                )?;
5032                writeln!(
5033                    self.out,
5034                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5035                    back::INDENT
5036                )?;
5037                writeln!(self.out, "}}")?;
5038                Ok((name, 2, 1))
5039            }
5040            Sint16x2 => {
5041                let name = self.namer.call("unpackSint16x2");
5042                writeln!(
5043                    self.out,
5044                    "metal::int2 {name}(metal::ushort b0, \
5045                                        metal::ushort b1, \
5046                                        metal::ushort b2, \
5047                                        metal::ushort b3) {{"
5048                )?;
5049                writeln!(
5050                    self.out,
5051                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5052                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5053                    back::INDENT
5054                )?;
5055                writeln!(self.out, "}}")?;
5056                Ok((name, 4, 2))
5057            }
5058            Sint16x4 => {
5059                let name = self.namer.call("unpackSint16x4");
5060                writeln!(
5061                    self.out,
5062                    "metal::int4 {name}(metal::ushort b0, \
5063                                        metal::ushort b1, \
5064                                        metal::ushort b2, \
5065                                        metal::ushort b3, \
5066                                        metal::ushort b4, \
5067                                        metal::ushort b5, \
5068                                        metal::ushort b6, \
5069                                        metal::ushort b7) {{"
5070                )?;
5071                writeln!(
5072                    self.out,
5073                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5074                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5075                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5076                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5077                    back::INDENT
5078                )?;
5079                writeln!(self.out, "}}")?;
5080                Ok((name, 8, 4))
5081            }
5082            Unorm16 => {
5083                let name = self.namer.call("unpackUnorm16");
5084                writeln!(
5085                    self.out,
5086                    "float {name}(metal::ushort b0, \
5087                                  metal::ushort b1) {{"
5088                )?;
5089                writeln!(
5090                    self.out,
5091                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5092                    back::INDENT
5093                )?;
5094                writeln!(self.out, "}}")?;
5095                Ok((name, 2, 1))
5096            }
5097            Unorm16x2 => {
5098                let name = self.namer.call("unpackUnorm16x2");
5099                writeln!(
5100                    self.out,
5101                    "metal::float2 {name}(metal::ushort b0, \
5102                                          metal::ushort b1, \
5103                                          metal::ushort b2, \
5104                                          metal::ushort b3) {{"
5105                )?;
5106                writeln!(
5107                    self.out,
5108                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5109                                            float(b3 << 8 | b2) / 65535.0f);",
5110                    back::INDENT
5111                )?;
5112                writeln!(self.out, "}}")?;
5113                Ok((name, 4, 2))
5114            }
5115            Unorm16x4 => {
5116                let name = self.namer.call("unpackUnorm16x4");
5117                writeln!(
5118                    self.out,
5119                    "metal::float4 {name}(metal::ushort b0, \
5120                                          metal::ushort b1, \
5121                                          metal::ushort b2, \
5122                                          metal::ushort b3, \
5123                                          metal::ushort b4, \
5124                                          metal::ushort b5, \
5125                                          metal::ushort b6, \
5126                                          metal::ushort b7) {{"
5127                )?;
5128                writeln!(
5129                    self.out,
5130                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5131                                            float(b3 << 8 | b2) / 65535.0f, \
5132                                            float(b5 << 8 | b4) / 65535.0f, \
5133                                            float(b7 << 8 | b6) / 65535.0f);",
5134                    back::INDENT
5135                )?;
5136                writeln!(self.out, "}}")?;
5137                Ok((name, 8, 4))
5138            }
5139            Snorm16 => {
5140                let name = self.namer.call("unpackSnorm16");
5141                writeln!(
5142                    self.out,
5143                    "float {name}(metal::ushort b0, \
5144                                  metal::ushort b1) {{"
5145                )?;
5146                writeln!(
5147                    self.out,
5148                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5149                    back::INDENT
5150                )?;
5151                writeln!(self.out, "}}")?;
5152                Ok((name, 2, 1))
5153            }
5154            Snorm16x2 => {
5155                let name = self.namer.call("unpackSnorm16x2");
5156                writeln!(
5157                    self.out,
5158                    "metal::float2 {name}(uint b0, \
5159                                          uint b1, \
5160                                          uint b2, \
5161                                          uint b3) {{"
5162                )?;
5163                writeln!(
5164                    self.out,
5165                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5166                    back::INDENT
5167                )?;
5168                writeln!(self.out, "}}")?;
5169                Ok((name, 4, 2))
5170            }
5171            Snorm16x4 => {
5172                let name = self.namer.call("unpackSnorm16x4");
5173                writeln!(
5174                    self.out,
5175                    "metal::float4 {name}(uint b0, \
5176                                          uint b1, \
5177                                          uint b2, \
5178                                          uint b3, \
5179                                          uint b4, \
5180                                          uint b5, \
5181                                          uint b6, \
5182                                          uint b7) {{"
5183                )?;
5184                writeln!(
5185                    self.out,
5186                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5187                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5188                    back::INDENT
5189                )?;
5190                writeln!(self.out, "}}")?;
5191                Ok((name, 8, 4))
5192            }
5193            Float16 => {
5194                let name = self.namer.call("unpackFloat16");
5195                writeln!(
5196                    self.out,
5197                    "float {name}(metal::ushort b0, \
5198                                  metal::ushort b1) {{"
5199                )?;
5200                writeln!(
5201                    self.out,
5202                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5203                    back::INDENT
5204                )?;
5205                writeln!(self.out, "}}")?;
5206                Ok((name, 2, 1))
5207            }
5208            Float16x2 => {
5209                let name = self.namer.call("unpackFloat16x2");
5210                writeln!(
5211                    self.out,
5212                    "metal::float2 {name}(metal::ushort b0, \
5213                                          metal::ushort b1, \
5214                                          metal::ushort b2, \
5215                                          metal::ushort b3) {{"
5216                )?;
5217                writeln!(
5218                    self.out,
5219                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5220                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5221                    back::INDENT
5222                )?;
5223                writeln!(self.out, "}}")?;
5224                Ok((name, 4, 2))
5225            }
5226            Float16x4 => {
5227                let name = self.namer.call("unpackFloat16x4");
5228                writeln!(
5229                    self.out,
5230                    "metal::float4 {name}(metal::ushort b0, \
5231                                        metal::ushort b1, \
5232                                        metal::ushort b2, \
5233                                        metal::ushort b3, \
5234                                        metal::ushort b4, \
5235                                        metal::ushort b5, \
5236                                        metal::ushort b6, \
5237                                        metal::ushort b7) {{"
5238                )?;
5239                writeln!(
5240                    self.out,
5241                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5242                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5243                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5244                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5245                    back::INDENT
5246                )?;
5247                writeln!(self.out, "}}")?;
5248                Ok((name, 8, 4))
5249            }
5250            Float32 => {
5251                let name = self.namer.call("unpackFloat32");
5252                writeln!(
5253                    self.out,
5254                    "float {name}(uint b0, \
5255                                  uint b1, \
5256                                  uint b2, \
5257                                  uint b3) {{"
5258                )?;
5259                writeln!(
5260                    self.out,
5261                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5262                    back::INDENT
5263                )?;
5264                writeln!(self.out, "}}")?;
5265                Ok((name, 4, 1))
5266            }
5267            Float32x2 => {
5268                let name = self.namer.call("unpackFloat32x2");
5269                writeln!(
5270                    self.out,
5271                    "metal::float2 {name}(uint b0, \
5272                                          uint b1, \
5273                                          uint b2, \
5274                                          uint b3, \
5275                                          uint b4, \
5276                                          uint b5, \
5277                                          uint b6, \
5278                                          uint b7) {{"
5279                )?;
5280                writeln!(
5281                    self.out,
5282                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5283                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5284                    back::INDENT
5285                )?;
5286                writeln!(self.out, "}}")?;
5287                Ok((name, 8, 2))
5288            }
5289            Float32x3 => {
5290                let name = self.namer.call("unpackFloat32x3");
5291                writeln!(
5292                    self.out,
5293                    "metal::float3 {name}(uint b0, \
5294                                          uint b1, \
5295                                          uint b2, \
5296                                          uint b3, \
5297                                          uint b4, \
5298                                          uint b5, \
5299                                          uint b6, \
5300                                          uint b7, \
5301                                          uint b8, \
5302                                          uint b9, \
5303                                          uint b10, \
5304                                          uint b11) {{"
5305                )?;
5306                writeln!(
5307                    self.out,
5308                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5309                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5310                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5311                    back::INDENT
5312                )?;
5313                writeln!(self.out, "}}")?;
5314                Ok((name, 12, 3))
5315            }
5316            Float32x4 => {
5317                let name = self.namer.call("unpackFloat32x4");
5318                writeln!(
5319                    self.out,
5320                    "metal::float4 {name}(uint b0, \
5321                                          uint b1, \
5322                                          uint b2, \
5323                                          uint b3, \
5324                                          uint b4, \
5325                                          uint b5, \
5326                                          uint b6, \
5327                                          uint b7, \
5328                                          uint b8, \
5329                                          uint b9, \
5330                                          uint b10, \
5331                                          uint b11, \
5332                                          uint b12, \
5333                                          uint b13, \
5334                                          uint b14, \
5335                                          uint b15) {{"
5336                )?;
5337                writeln!(
5338                    self.out,
5339                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5340                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5341                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5342                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5343                    back::INDENT
5344                )?;
5345                writeln!(self.out, "}}")?;
5346                Ok((name, 16, 4))
5347            }
5348            Uint32 => {
5349                let name = self.namer.call("unpackUint32");
5350                writeln!(
5351                    self.out,
5352                    "uint {name}(uint b0, \
5353                                 uint b1, \
5354                                 uint b2, \
5355                                 uint b3) {{"
5356                )?;
5357                writeln!(
5358                    self.out,
5359                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5360                    back::INDENT
5361                )?;
5362                writeln!(self.out, "}}")?;
5363                Ok((name, 4, 1))
5364            }
5365            Uint32x2 => {
5366                let name = self.namer.call("unpackUint32x2");
5367                writeln!(
5368                    self.out,
5369                    "uint2 {name}(uint b0, \
5370                                  uint b1, \
5371                                  uint b2, \
5372                                  uint b3, \
5373                                  uint b4, \
5374                                  uint b5, \
5375                                  uint b6, \
5376                                  uint b7) {{"
5377                )?;
5378                writeln!(
5379                    self.out,
5380                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5381                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5382                    back::INDENT
5383                )?;
5384                writeln!(self.out, "}}")?;
5385                Ok((name, 8, 2))
5386            }
5387            Uint32x3 => {
5388                let name = self.namer.call("unpackUint32x3");
5389                writeln!(
5390                    self.out,
5391                    "uint3 {name}(uint b0, \
5392                                  uint b1, \
5393                                  uint b2, \
5394                                  uint b3, \
5395                                  uint b4, \
5396                                  uint b5, \
5397                                  uint b6, \
5398                                  uint b7, \
5399                                  uint b8, \
5400                                  uint b9, \
5401                                  uint b10, \
5402                                  uint b11) {{"
5403                )?;
5404                writeln!(
5405                    self.out,
5406                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5407                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5408                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5409                    back::INDENT
5410                )?;
5411                writeln!(self.out, "}}")?;
5412                Ok((name, 12, 3))
5413            }
5414            Uint32x4 => {
5415                let name = self.namer.call("unpackUint32x4");
5416                writeln!(
5417                    self.out,
5418                    "{NAMESPACE}::uint4 {name}(uint b0, \
5419                                  uint b1, \
5420                                  uint b2, \
5421                                  uint b3, \
5422                                  uint b4, \
5423                                  uint b5, \
5424                                  uint b6, \
5425                                  uint b7, \
5426                                  uint b8, \
5427                                  uint b9, \
5428                                  uint b10, \
5429                                  uint b11, \
5430                                  uint b12, \
5431                                  uint b13, \
5432                                  uint b14, \
5433                                  uint b15) {{"
5434                )?;
5435                writeln!(
5436                    self.out,
5437                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5438                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5439                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5440                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5441                    back::INDENT
5442                )?;
5443                writeln!(self.out, "}}")?;
5444                Ok((name, 16, 4))
5445            }
5446            Sint32 => {
5447                let name = self.namer.call("unpackSint32");
5448                writeln!(
5449                    self.out,
5450                    "int {name}(uint b0, \
5451                                uint b1, \
5452                                uint b2, \
5453                                uint b3) {{"
5454                )?;
5455                writeln!(
5456                    self.out,
5457                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5458                    back::INDENT
5459                )?;
5460                writeln!(self.out, "}}")?;
5461                Ok((name, 4, 1))
5462            }
5463            Sint32x2 => {
5464                let name = self.namer.call("unpackSint32x2");
5465                writeln!(
5466                    self.out,
5467                    "metal::int2 {name}(uint b0, \
5468                                        uint b1, \
5469                                        uint b2, \
5470                                        uint b3, \
5471                                        uint b4, \
5472                                        uint b5, \
5473                                        uint b6, \
5474                                        uint b7) {{"
5475                )?;
5476                writeln!(
5477                    self.out,
5478                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5479                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5480                    back::INDENT
5481                )?;
5482                writeln!(self.out, "}}")?;
5483                Ok((name, 8, 2))
5484            }
5485            Sint32x3 => {
5486                let name = self.namer.call("unpackSint32x3");
5487                writeln!(
5488                    self.out,
5489                    "metal::int3 {name}(uint b0, \
5490                                        uint b1, \
5491                                        uint b2, \
5492                                        uint b3, \
5493                                        uint b4, \
5494                                        uint b5, \
5495                                        uint b6, \
5496                                        uint b7, \
5497                                        uint b8, \
5498                                        uint b9, \
5499                                        uint b10, \
5500                                        uint b11) {{"
5501                )?;
5502                writeln!(
5503                    self.out,
5504                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5505                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5506                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5507                    back::INDENT
5508                )?;
5509                writeln!(self.out, "}}")?;
5510                Ok((name, 12, 3))
5511            }
5512            Sint32x4 => {
5513                let name = self.namer.call("unpackSint32x4");
5514                writeln!(
5515                    self.out,
5516                    "metal::int4 {name}(uint b0, \
5517                                        uint b1, \
5518                                        uint b2, \
5519                                        uint b3, \
5520                                        uint b4, \
5521                                        uint b5, \
5522                                        uint b6, \
5523                                        uint b7, \
5524                                        uint b8, \
5525                                        uint b9, \
5526                                        uint b10, \
5527                                        uint b11, \
5528                                        uint b12, \
5529                                        uint b13, \
5530                                        uint b14, \
5531                                        uint b15) {{"
5532                )?;
5533                writeln!(
5534                    self.out,
5535                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5536                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5537                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5538                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5539                    back::INDENT
5540                )?;
5541                writeln!(self.out, "}}")?;
5542                Ok((name, 16, 4))
5543            }
5544            Unorm10_10_10_2 => {
5545                let name = self.namer.call("unpackUnorm10_10_10_2");
5546                writeln!(
5547                    self.out,
5548                    "metal::float4 {name}(uint b0, \
5549                                          uint b1, \
5550                                          uint b2, \
5551                                          uint b3) {{"
5552                )?;
5553                writeln!(
5554                    self.out,
5555                    // The following is correct for RGBA packing, but our format seems to
5556                    // match ABGR, which can be fed into the Metal builtin function
5557                    // unpack_unorm10a2_to_float.
5558                    /*
5559                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5560                       uint r = (v & 0xFFC00000) >> 22; \
5561                       uint g = (v & 0x003FF000) >> 12; \
5562                       uint b = (v & 0x00000FFC) >> 2; \
5563                       uint a = (v & 0x00000003); \
5564                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5565                    */
5566                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5567                    back::INDENT
5568                )?;
5569                writeln!(self.out, "}}")?;
5570                Ok((name, 4, 4))
5571            }
5572            Unorm8x4Bgra => {
5573                let name = self.namer.call("unpackUnorm8x4Bgra");
5574                writeln!(
5575                    self.out,
5576                    "metal::float4 {name}(metal::uchar b0, \
5577                                          metal::uchar b1, \
5578                                          metal::uchar b2, \
5579                                          metal::uchar b3) {{"
5580                )?;
5581                writeln!(
5582                    self.out,
5583                    "{}return metal::float4(float(b2) / 255.0f, \
5584                                            float(b1) / 255.0f, \
5585                                            float(b0) / 255.0f, \
5586                                            float(b3) / 255.0f);",
5587                    back::INDENT
5588                )?;
5589                writeln!(self.out, "}}")?;
5590                Ok((name, 4, 4))
5591            }
5592        }
5593    }
5594
5595    fn write_wrapped_unary_op(
5596        &mut self,
5597        module: &crate::Module,
5598        func_ctx: &back::FunctionCtx,
5599        op: crate::UnaryOperator,
5600        operand: Handle<crate::Expression>,
5601    ) -> BackendResult {
5602        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5603        match op {
5604            // Negating the TYPE_MIN of a two's complement signed integer
5605            // type causes overflow, which is undefined behaviour in MSL. To
5606            // avoid this we bitcast the value to unsigned and negate it,
5607            // then bitcast back to signed.
5608            // This adheres to the WGSL spec in that the negative of the
5609            // type's minimum value should equal to the minimum value.
5610            crate::UnaryOperator::Negate
5611                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5612            {
5613                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5614                    return Ok(());
5615                };
5616                let wrapped = WrappedFunction::UnaryOp {
5617                    op,
5618                    ty: (vector_size, scalar),
5619                };
5620                if !self.wrapped_functions.insert(wrapped) {
5621                    return Ok(());
5622                }
5623
5624                let unsigned_scalar = crate::Scalar {
5625                    kind: crate::ScalarKind::Uint,
5626                    ..scalar
5627                };
5628                let mut type_name = String::new();
5629                let mut unsigned_type_name = String::new();
5630                match vector_size {
5631                    None => {
5632                        put_numeric_type(&mut type_name, scalar, &[])?;
5633                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5634                    }
5635                    Some(size) => {
5636                        put_numeric_type(&mut type_name, scalar, &[size])?;
5637                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5638                    }
5639                };
5640
5641                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5642                let level = back::Level(1);
5643                writeln!(
5644                    self.out,
5645                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5646                )?;
5647                writeln!(self.out, "}}")?;
5648                writeln!(self.out)?;
5649            }
5650            _ => {}
5651        }
5652        Ok(())
5653    }
5654
5655    fn write_wrapped_binary_op(
5656        &mut self,
5657        module: &crate::Module,
5658        func_ctx: &back::FunctionCtx,
5659        expr: Handle<crate::Expression>,
5660        op: crate::BinaryOperator,
5661        left: Handle<crate::Expression>,
5662        right: Handle<crate::Expression>,
5663    ) -> BackendResult {
5664        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5665        let left_ty = func_ctx.resolve_type(left, &module.types);
5666        let right_ty = func_ctx.resolve_type(right, &module.types);
5667        match (op, expr_ty.scalar_kind()) {
5668            // Signed integer division of TYPE_MIN / -1, or signed or
5669            // unsigned division by zero, gives an unspecified value in MSL.
5670            // We override the divisor to 1 in these cases.
5671            // This adheres to the WGSL spec in that:
5672            // * TYPE_MIN / -1 == TYPE_MIN
5673            // * x / 0 == x
5674            (
5675                crate::BinaryOperator::Divide,
5676                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5677            ) => {
5678                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5679                    return Ok(());
5680                };
5681                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5682                    return Ok(());
5683                };
5684                let wrapped = WrappedFunction::BinaryOp {
5685                    op,
5686                    left_ty: left_wrapped_ty,
5687                    right_ty: right_wrapped_ty,
5688                };
5689                if !self.wrapped_functions.insert(wrapped) {
5690                    return Ok(());
5691                }
5692
5693                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5694                    return Ok(());
5695                };
5696                let mut type_name = String::new();
5697                match vector_size {
5698                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5699                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5700                };
5701                writeln!(
5702                    self.out,
5703                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5704                )?;
5705                let level = back::Level(1);
5706                match scalar.kind {
5707                    crate::ScalarKind::Sint => {
5708                        let min_val = match scalar.width {
5709                            4 => crate::Literal::I32(i32::MIN),
5710                            8 => crate::Literal::I64(i64::MIN),
5711                            _ => {
5712                                return Err(Error::GenericValidation(format!(
5713                                    "Unexpected width for scalar {scalar:?}"
5714                                )));
5715                            }
5716                        };
5717                        write!(
5718                            self.out,
5719                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5720                        )?;
5721                        self.put_literal(min_val)?;
5722                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5723                    }
5724                    crate::ScalarKind::Uint => writeln!(
5725                        self.out,
5726                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5727                    )?,
5728                    _ => unreachable!(),
5729                }
5730                writeln!(self.out, "}}")?;
5731                writeln!(self.out)?;
5732            }
5733            // Integer modulo where one or both operands are negative, or the
5734            // divisor is zero, is undefined behaviour in MSL. To avoid this
5735            // we use the following equation:
5736            //
5737            // dividend - (dividend / divisor) * divisor
5738            //
5739            // overriding the divisor to 1 if either it is 0, or it is -1
5740            // and the dividend is TYPE_MIN.
5741            //
5742            // This adheres to the WGSL spec in that:
5743            // * TYPE_MIN % -1 == 0
5744            // * x % 0 == 0
5745            (
5746                crate::BinaryOperator::Modulo,
5747                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5748            ) => {
5749                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5750                    return Ok(());
5751                };
5752                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5753                else {
5754                    return Ok(());
5755                };
5756                let wrapped = WrappedFunction::BinaryOp {
5757                    op,
5758                    left_ty: left_wrapped_ty,
5759                    right_ty: (right_vector_size, right_scalar),
5760                };
5761                if !self.wrapped_functions.insert(wrapped) {
5762                    return Ok(());
5763                }
5764
5765                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5766                    return Ok(());
5767                };
5768                let mut type_name = String::new();
5769                match vector_size {
5770                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5771                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5772                };
5773                let mut rhs_type_name = String::new();
5774                match right_vector_size {
5775                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5776                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5777                };
5778
5779                writeln!(
5780                    self.out,
5781                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5782                )?;
5783                let level = back::Level(1);
5784                match scalar.kind {
5785                    crate::ScalarKind::Sint => {
5786                        let min_val = match scalar.width {
5787                            4 => crate::Literal::I32(i32::MIN),
5788                            8 => crate::Literal::I64(i64::MIN),
5789                            _ => {
5790                                return Err(Error::GenericValidation(format!(
5791                                    "Unexpected width for scalar {scalar:?}"
5792                                )));
5793                            }
5794                        };
5795                        write!(
5796                            self.out,
5797                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5798                        )?;
5799                        self.put_literal(min_val)?;
5800                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5801                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5802                    }
5803                    crate::ScalarKind::Uint => writeln!(
5804                        self.out,
5805                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5806                    )?,
5807                    _ => unreachable!(),
5808                }
5809                writeln!(self.out, "}}")?;
5810                writeln!(self.out)?;
5811            }
5812            _ => {}
5813        }
5814        Ok(())
5815    }
5816
5817    #[allow(clippy::too_many_arguments)]
5818    fn write_wrapped_math_function(
5819        &mut self,
5820        module: &crate::Module,
5821        func_ctx: &back::FunctionCtx,
5822        fun: crate::MathFunction,
5823        arg: Handle<crate::Expression>,
5824        _arg1: Option<Handle<crate::Expression>>,
5825        _arg2: Option<Handle<crate::Expression>>,
5826        _arg3: Option<Handle<crate::Expression>>,
5827    ) -> BackendResult {
5828        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5829        match fun {
5830            // Taking the absolute value of the TYPE_MIN of a two's
5831            // complement signed integer type causes overflow, which is
5832            // undefined behaviour in MSL. To avoid this, when the value is
5833            // negative we bitcast the value to unsigned and negate it, then
5834            // bitcast back to signed.
5835            // This adheres to the WGSL spec in that the absolute of the
5836            // type's minimum value should equal to the minimum value.
5837            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5838                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5839                    return Ok(());
5840                };
5841                let wrapped = WrappedFunction::Math {
5842                    fun,
5843                    arg_ty: (vector_size, scalar),
5844                };
5845                if !self.wrapped_functions.insert(wrapped) {
5846                    return Ok(());
5847                }
5848
5849                let unsigned_scalar = crate::Scalar {
5850                    kind: crate::ScalarKind::Uint,
5851                    ..scalar
5852                };
5853                let mut type_name = String::new();
5854                let mut unsigned_type_name = String::new();
5855                match vector_size {
5856                    None => {
5857                        put_numeric_type(&mut type_name, scalar, &[])?;
5858                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5859                    }
5860                    Some(size) => {
5861                        put_numeric_type(&mut type_name, scalar, &[size])?;
5862                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5863                    }
5864                };
5865
5866                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5867                let level = back::Level(1);
5868                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5869                writeln!(self.out, "}}")?;
5870                writeln!(self.out)?;
5871            }
5872            _ => {}
5873        }
5874        Ok(())
5875    }
5876
5877    fn write_wrapped_cast(
5878        &mut self,
5879        module: &crate::Module,
5880        func_ctx: &back::FunctionCtx,
5881        expr: Handle<crate::Expression>,
5882        kind: crate::ScalarKind,
5883        convert: Option<crate::Bytes>,
5884    ) -> BackendResult {
5885        // Avoid undefined behaviour when casting from a float to integer
5886        // when the value is out of range for the target type. Additionally
5887        // ensure we clamp to the correct value as per the WGSL spec.
5888        //
5889        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
5890        // * If X is exactly representable in the target type T, then the
5891        //   result is that value.
5892        // * Otherwise, the result is the value in T closest to
5893        //   truncate(X) and also exactly representable in the original
5894        //   floating point type.
5895        let src_ty = func_ctx.resolve_type(expr, &module.types);
5896        let Some(width) = convert else {
5897            return Ok(());
5898        };
5899        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5900            return Ok(());
5901        };
5902        let dst_scalar = crate::Scalar { kind, width };
5903        if src_scalar.kind != crate::ScalarKind::Float
5904            || (dst_scalar.kind != crate::ScalarKind::Sint
5905                && dst_scalar.kind != crate::ScalarKind::Uint)
5906        {
5907            return Ok(());
5908        }
5909        let wrapped = WrappedFunction::Cast {
5910            src_scalar,
5911            vector_size,
5912            dst_scalar,
5913        };
5914        if !self.wrapped_functions.insert(wrapped) {
5915            return Ok(());
5916        }
5917        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5918
5919        let mut src_type_name = String::new();
5920        match vector_size {
5921            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5922            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5923        };
5924        let mut dst_type_name = String::new();
5925        match vector_size {
5926            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5927            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5928        };
5929        let fun_name = match dst_scalar {
5930            crate::Scalar::I32 => F2I32_FUNCTION,
5931            crate::Scalar::U32 => F2U32_FUNCTION,
5932            crate::Scalar::I64 => F2I64_FUNCTION,
5933            crate::Scalar::U64 => F2U64_FUNCTION,
5934            _ => unreachable!(),
5935        };
5936
5937        writeln!(
5938            self.out,
5939            "{dst_type_name} {fun_name}({src_type_name} value) {{"
5940        )?;
5941        let level = back::Level(1);
5942        write!(
5943            self.out,
5944            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5945        )?;
5946        self.put_literal(min)?;
5947        write!(self.out, ", ")?;
5948        self.put_literal(max)?;
5949        writeln!(self.out, "));")?;
5950        writeln!(self.out, "}}")?;
5951        writeln!(self.out)?;
5952        Ok(())
5953    }
5954
5955    /// Helper function used by [`Self::write_wrapped_image_load`] and
5956    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
5957    /// conversion code for external textures. Expects the preceding code to
5958    /// declare the Y component as a `float` variable of name `y`, the UV
5959    /// components as a `float2` variable of name `uv`, and the external
5960    /// texture params as a variable of name `params`. The emitted code will
5961    /// return the result.
5962    fn write_convert_yuv_to_rgb_and_return(
5963        &mut self,
5964        level: back::Level,
5965        y: &str,
5966        uv: &str,
5967        params: &str,
5968    ) -> BackendResult {
5969        let l1 = level;
5970        let l2 = l1.next();
5971
5972        // Convert from YUV to non-linear RGB in the source color space.
5973        writeln!(
5974            self.out,
5975            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
5976        )?;
5977
5978        // Apply the inverse of the source transfer function to convert to
5979        // linear RGB in the source color space.
5980        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
5981        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
5982        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
5983        writeln!(
5984            self.out,
5985            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
5986        )?;
5987
5988        // Multiply by the gamut conversion matrix to convert to linear RGB in
5989        // the destination color space.
5990        writeln!(
5991            self.out,
5992            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
5993        )?;
5994
5995        // Finally, apply the dest transfer function to convert to non-linear
5996        // RGB in the destination color space, and return the result.
5997        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
5998        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
5999        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6000        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6001
6002        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6003        Ok(())
6004    }
6005
6006    #[allow(clippy::too_many_arguments)]
6007    fn write_wrapped_image_load(
6008        &mut self,
6009        module: &crate::Module,
6010        func_ctx: &back::FunctionCtx,
6011        image: Handle<crate::Expression>,
6012        _coordinate: Handle<crate::Expression>,
6013        _array_index: Option<Handle<crate::Expression>>,
6014        _sample: Option<Handle<crate::Expression>>,
6015        _level: Option<Handle<crate::Expression>>,
6016    ) -> BackendResult {
6017        // We currently only need to wrap image loads for external textures
6018        let class = match *func_ctx.resolve_type(image, &module.types) {
6019            crate::TypeInner::Image { class, .. } => class,
6020            _ => unreachable!(),
6021        };
6022        if class != crate::ImageClass::External {
6023            return Ok(());
6024        }
6025        let wrapped = WrappedFunction::ImageLoad { class };
6026        if !self.wrapped_functions.insert(wrapped) {
6027            return Ok(());
6028        }
6029
6030        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6031        let l1 = back::Level(1);
6032        let l2 = l1.next();
6033        let l3 = l2.next();
6034        writeln!(
6035            self.out,
6036            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6037        )?;
6038        // Clamp coords to provided size of external texture to prevent OOB
6039        // read. If params.size is zero then clamp to the actual size of the
6040        // texture.
6041        writeln!(
6042            self.out,
6043            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6044        )?;
6045        writeln!(
6046            self.out,
6047            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6048        )?;
6049
6050        // Apply load transformation
6051        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6052        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6053        // For single plane, simply read from plane0
6054        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6055        writeln!(self.out, "{l1}}} else {{")?;
6056
6057        // Chroma planes may be subsampled so we must scale the coords accordingly.
6058        writeln!(
6059            self.out,
6060            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6061        )?;
6062        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6063
6064        // For multi-plane, read the Y value from plane 0
6065        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6066
6067        writeln!(self.out, "{l2}float2 uv;")?;
6068        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6069        // For 2 planes, read UV from interleaved plane 1
6070        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6071        writeln!(self.out, "{l2}}} else {{")?;
6072        // For 3 planes, read U and V from planes 1 and 2 respectively
6073        writeln!(
6074            self.out,
6075            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6076        )?;
6077        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6078        writeln!(
6079            self.out,
6080            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6081        )?;
6082        writeln!(self.out, "{l2}}}")?;
6083
6084        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6085
6086        writeln!(self.out, "{l1}}}")?;
6087        writeln!(self.out, "}}")?;
6088        writeln!(self.out)?;
6089        Ok(())
6090    }
6091
6092    #[allow(clippy::too_many_arguments)]
6093    fn write_wrapped_image_sample(
6094        &mut self,
6095        module: &crate::Module,
6096        func_ctx: &back::FunctionCtx,
6097        image: Handle<crate::Expression>,
6098        _sampler: Handle<crate::Expression>,
6099        _gather: Option<crate::SwizzleComponent>,
6100        _coordinate: Handle<crate::Expression>,
6101        _array_index: Option<Handle<crate::Expression>>,
6102        _offset: Option<Handle<crate::Expression>>,
6103        _level: crate::SampleLevel,
6104        _depth_ref: Option<Handle<crate::Expression>>,
6105        clamp_to_edge: bool,
6106    ) -> BackendResult {
6107        // We currently only need to wrap textureSampleBaseClampToEdge, for
6108        // both sampled and external textures.
6109        if !clamp_to_edge {
6110            return Ok(());
6111        }
6112        let class = match *func_ctx.resolve_type(image, &module.types) {
6113            crate::TypeInner::Image { class, .. } => class,
6114            _ => unreachable!(),
6115        };
6116        let wrapped = WrappedFunction::ImageSample {
6117            class,
6118            clamp_to_edge: true,
6119        };
6120        if !self.wrapped_functions.insert(wrapped) {
6121            return Ok(());
6122        }
6123        match class {
6124            crate::ImageClass::External => {
6125                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6126                let l1 = back::Level(1);
6127                let l2 = l1.next();
6128                let l3 = l2.next();
6129                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6130                writeln!(
6131                    self.out,
6132                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6133                )?;
6134
6135                // Calculate the sample bounds. The purported size of the texture
6136                // (params.size) is irrelevant here as we are dealing with normalized
6137                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6138                // apply the sample transformation to that, also bearing in mind that it
6139                // may contain a flip on either axis. We calculate and adjust for the
6140                // half-texel separately for each plane as it depends on the actual
6141                // texture size which may vary between planes.
6142                writeln!(
6143                    self.out,
6144                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6145                )?;
6146                writeln!(
6147                    self.out,
6148                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6149                )?;
6150                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6151                writeln!(
6152                    self.out,
6153                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6154                )?;
6155                writeln!(
6156                    self.out,
6157                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6158                )?;
6159                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6160                // For single plane, simply sample from plane0
6161                writeln!(
6162                    self.out,
6163                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6164                )?;
6165                writeln!(self.out, "{l1}}} else {{")?;
6166                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6167                writeln!(
6168                    self.out,
6169                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6170                )?;
6171                writeln!(
6172                    self.out,
6173                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6174                )?;
6175
6176                // For multi-plane, sample the Y value from plane 0
6177                writeln!(
6178                    self.out,
6179                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6180                )?;
6181                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6182                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6183                // For 2 planes, sample UV from interleaved plane 1
6184                writeln!(
6185                    self.out,
6186                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6187                )?;
6188                writeln!(self.out, "{l2}}} else {{")?;
6189                // For 3 planes, sample U and V from planes 1 and 2 respectively
6190                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6191                writeln!(
6192                    self.out,
6193                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6194                )?;
6195                writeln!(
6196                    self.out,
6197                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6198                )?;
6199                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6200                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6201                writeln!(self.out, "{l2}}}")?;
6202
6203                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6204
6205                writeln!(self.out, "{l1}}}")?;
6206                writeln!(self.out, "}}")?;
6207                writeln!(self.out)?;
6208            }
6209            _ => {
6210                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6211                let l1 = back::Level(1);
6212                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6213                writeln!(
6214                    self.out,
6215                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6216                )?;
6217                writeln!(self.out, "}}")?;
6218                writeln!(self.out)?;
6219            }
6220        }
6221        Ok(())
6222    }
6223
6224    fn write_wrapped_image_query(
6225        &mut self,
6226        module: &crate::Module,
6227        func_ctx: &back::FunctionCtx,
6228        image: Handle<crate::Expression>,
6229        query: crate::ImageQuery,
6230    ) -> BackendResult {
6231        // We currently only need to wrap size image queries for external textures
6232        if !matches!(query, crate::ImageQuery::Size { .. }) {
6233            return Ok(());
6234        }
6235        let class = match *func_ctx.resolve_type(image, &module.types) {
6236            crate::TypeInner::Image { class, .. } => class,
6237            _ => unreachable!(),
6238        };
6239        if class != crate::ImageClass::External {
6240            return Ok(());
6241        }
6242        let wrapped = WrappedFunction::ImageQuerySize { class };
6243        if !self.wrapped_functions.insert(wrapped) {
6244            return Ok(());
6245        }
6246        writeln!(
6247            self.out,
6248            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6249        )?;
6250        let l1 = back::Level(1);
6251        let l2 = l1.next();
6252        writeln!(
6253            self.out,
6254            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6255        )?;
6256        writeln!(self.out, "{l2}return tex.params.size;")?;
6257        writeln!(self.out, "{l1}}} else {{")?;
6258        // params.size == (0, 0) indicates to query and return plane 0's actual size
6259        writeln!(
6260            self.out,
6261            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6262        )?;
6263        writeln!(self.out, "{l1}}}")?;
6264        writeln!(self.out, "}}")?;
6265        writeln!(self.out)?;
6266        Ok(())
6267    }
6268
6269    pub(super) fn write_wrapped_functions(
6270        &mut self,
6271        module: &crate::Module,
6272        func_ctx: &back::FunctionCtx,
6273    ) -> BackendResult {
6274        for (expr_handle, expr) in func_ctx.expressions.iter() {
6275            match *expr {
6276                crate::Expression::Unary { op, expr: operand } => {
6277                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6278                }
6279                crate::Expression::Binary { op, left, right } => {
6280                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6281                }
6282                crate::Expression::Math {
6283                    fun,
6284                    arg,
6285                    arg1,
6286                    arg2,
6287                    arg3,
6288                } => {
6289                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6290                }
6291                crate::Expression::As {
6292                    expr,
6293                    kind,
6294                    convert,
6295                } => {
6296                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6297                }
6298                crate::Expression::ImageLoad {
6299                    image,
6300                    coordinate,
6301                    array_index,
6302                    sample,
6303                    level,
6304                } => {
6305                    self.write_wrapped_image_load(
6306                        module,
6307                        func_ctx,
6308                        image,
6309                        coordinate,
6310                        array_index,
6311                        sample,
6312                        level,
6313                    )?;
6314                }
6315                crate::Expression::ImageSample {
6316                    image,
6317                    sampler,
6318                    gather,
6319                    coordinate,
6320                    array_index,
6321                    offset,
6322                    level,
6323                    depth_ref,
6324                    clamp_to_edge,
6325                } => {
6326                    self.write_wrapped_image_sample(
6327                        module,
6328                        func_ctx,
6329                        image,
6330                        sampler,
6331                        gather,
6332                        coordinate,
6333                        array_index,
6334                        offset,
6335                        level,
6336                        depth_ref,
6337                        clamp_to_edge,
6338                    )?;
6339                }
6340                crate::Expression::ImageQuery { image, query } => {
6341                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6342                }
6343                _ => {}
6344            }
6345        }
6346
6347        Ok(())
6348    }
6349
    // Returns a `TranslationInfo` holding the mapped entry point names.
6351    fn write_functions(
6352        &mut self,
6353        module: &crate::Module,
6354        mod_info: &valid::ModuleInfo,
6355        options: &Options,
6356        pipeline_options: &PipelineOptions,
6357    ) -> Result<TranslationInfo, Error> {
6358        use back::msl::VertexFormat;
6359
6360        // Define structs to hold resolved/generated data for vertex buffers and
6361        // their attributes.
6362        struct AttributeMappingResolved {
6363            ty_name: String,
6364            dimension: u32,
6365            ty_is_int: bool,
6366            name: String,
6367        }
6368        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6369
6370        struct VertexBufferMappingResolved<'a> {
6371            id: u32,
6372            stride: u32,
6373            step_mode: back::msl::VertexBufferStepMode,
6374            ty_name: String,
6375            param_name: String,
6376            elem_name: String,
6377            attributes: &'a Vec<back::msl::AttributeMapping>,
6378        }
6379        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6380
6381        // Define a struct to hold a named reference to a byte-unpacking function.
6382        struct UnpackingFunction {
6383            name: String,
6384            byte_count: u32,
6385            dimension: u32,
6386        }
6387        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6388
6389        // Check if we are attempting vertex pulling. If we are, generate some
6390        // names we'll need, and iterate the vertex buffer mappings to output
6391        // all the conversion functions we'll need to unpack the attribute data.
6392        // We can re-use these names for all entry points that need them, since
6393        // those entry points also use self.namer.
6394        let mut needs_vertex_id = false;
6395        let v_id = self.namer.call("v_id");
6396
6397        let mut needs_instance_id = false;
6398        let i_id = self.namer.call("i_id");
6399        if pipeline_options.vertex_pulling_transform {
6400            for vbm in &pipeline_options.vertex_buffer_mappings {
6401                let buffer_id = vbm.id;
6402                let buffer_stride = vbm.stride;
6403
6404                assert!(
6405                    buffer_stride > 0,
6406                    "Vertex pulling requires a non-zero buffer stride."
6407                );
6408
6409                match vbm.step_mode {
6410                    back::msl::VertexBufferStepMode::Constant => {}
6411                    back::msl::VertexBufferStepMode::ByVertex => {
6412                        needs_vertex_id = true;
6413                    }
6414                    back::msl::VertexBufferStepMode::ByInstance => {
6415                        needs_instance_id = true;
6416                    }
6417                }
6418
6419                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6420                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6421                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6422
6423                vbm_resolved.push(VertexBufferMappingResolved {
6424                    id: buffer_id,
6425                    stride: buffer_stride,
6426                    step_mode: vbm.step_mode,
6427                    ty_name: buffer_ty,
6428                    param_name: buffer_param,
6429                    elem_name: buffer_elem,
6430                    attributes: &vbm.attributes,
6431                });
6432
6433                // Iterate the attributes and generate needed unpacking functions.
6434                for attribute in &vbm.attributes {
6435                    if unpacking_functions.contains_key(&attribute.format) {
6436                        continue;
6437                    }
6438                    let (name, byte_count, dimension) =
6439                        match self.write_unpacking_function(attribute.format) {
6440                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6441                            _ => {
6442                                continue;
6443                            }
6444                        };
6445                    unpacking_functions.insert(
6446                        attribute.format,
6447                        UnpackingFunction {
6448                            name,
6449                            byte_count,
6450                            dimension,
6451                        },
6452                    );
6453                }
6454            }
6455        }
6456
6457        let mut pass_through_globals = Vec::new();
6458        for (fun_handle, fun) in module.functions.iter() {
6459            log::trace!(
6460                "function {:?}, handle {:?}",
6461                fun.name.as_deref().unwrap_or("(anonymous)"),
6462                fun_handle
6463            );
6464
6465            let ctx = back::FunctionCtx {
6466                ty: back::FunctionType::Function(fun_handle),
6467                info: &mod_info[fun_handle],
6468                expressions: &fun.expressions,
6469                named_expressions: &fun.named_expressions,
6470            };
6471
6472            writeln!(self.out)?;
6473            self.write_wrapped_functions(module, &ctx)?;
6474
6475            let fun_info = &mod_info[fun_handle];
6476            pass_through_globals.clear();
6477            let mut needs_buffer_sizes = false;
6478            for (handle, var) in module.global_variables.iter() {
6479                if !fun_info[handle].is_empty() {
6480                    if var.space.needs_pass_through() {
6481                        pass_through_globals.push(handle);
6482                    }
6483                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6484                }
6485            }
6486
6487            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6488            match fun.result {
6489                Some(ref result) => {
6490                    let ty_name = TypeContext {
6491                        handle: result.ty,
6492                        gctx: module.to_ctx(),
6493                        names: &self.names,
6494                        access: crate::StorageAccess::empty(),
6495                        first_time: false,
6496                    };
6497                    write!(self.out, "{ty_name}")?;
6498                }
6499                None => {
6500                    write!(self.out, "void")?;
6501                }
6502            }
6503            writeln!(self.out, " {fun_name}(")?;
6504
6505            for (index, arg) in fun.arguments.iter().enumerate() {
6506                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6507                let param_type_name = TypeContext {
6508                    handle: arg.ty,
6509                    gctx: module.to_ctx(),
6510                    names: &self.names,
6511                    access: crate::StorageAccess::empty(),
6512                    first_time: false,
6513                };
6514                let separator = separate(
6515                    !pass_through_globals.is_empty()
6516                        || index + 1 != fun.arguments.len()
6517                        || needs_buffer_sizes,
6518                );
6519                writeln!(
6520                    self.out,
6521                    "{}{} {}{}",
6522                    back::INDENT,
6523                    param_type_name,
6524                    name,
6525                    separator
6526                )?;
6527            }
6528            for (index, &handle) in pass_through_globals.iter().enumerate() {
6529                let tyvar = TypedGlobalVariable {
6530                    module,
6531                    names: &self.names,
6532                    handle,
6533                    usage: fun_info[handle],
6534
6535                    reference: true,
6536                };
6537                let separator =
6538                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6539                write!(self.out, "{}", back::INDENT)?;
6540                tyvar.try_fmt(&mut self.out)?;
6541                writeln!(self.out, "{separator}")?;
6542            }
6543
6544            if needs_buffer_sizes {
6545                writeln!(
6546                    self.out,
6547                    "{}constant _mslBufferSizes& _buffer_sizes",
6548                    back::INDENT
6549                )?;
6550            }
6551
6552            writeln!(self.out, ") {{")?;
6553
6554            let guarded_indices =
6555                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6556
6557            let context = StatementContext {
6558                expression: ExpressionContext {
6559                    function: fun,
6560                    origin: FunctionOrigin::Handle(fun_handle),
6561                    info: fun_info,
6562                    lang_version: options.lang_version,
6563                    policies: options.bounds_check_policies,
6564                    guarded_indices,
6565                    module,
6566                    mod_info,
6567                    pipeline_options,
6568                    force_loop_bounding: options.force_loop_bounding,
6569                },
6570                result_struct: None,
6571            };
6572
6573            self.put_locals(&context.expression)?;
6574            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6575            self.put_block(back::Level(1), &fun.body, &context)?;
6576            writeln!(self.out, "}}")?;
6577            self.named_expressions.clear();
6578        }
6579
6580        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6581            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6582
6583        let mut info = TranslationInfo {
6584            entry_point_names: Vec::with_capacity(ep_range.len()),
6585        };
6586
6587        for ep_index in ep_range {
6588            let ep = &module.entry_points[ep_index];
6589            let fun = &ep.function;
6590            let fun_info = mod_info.get_entry_point(ep_index);
6591            let mut ep_error = None;
6592
6593            // For vertex_id and instance_id arguments, presume that we'll
6594            // use our generated names, but switch to the name of an
6595            // existing @builtin param, if we find one.
6596            let mut v_existing_id = None;
6597            let mut i_existing_id = None;
6598
6599            log::trace!(
6600                "entry point {:?}, index {:?}",
6601                fun.name.as_deref().unwrap_or("(anonymous)"),
6602                ep_index
6603            );
6604
6605            let ctx = back::FunctionCtx {
6606                ty: back::FunctionType::EntryPoint(ep_index as u16),
6607                info: fun_info,
6608                expressions: &fun.expressions,
6609                named_expressions: &fun.named_expressions,
6610            };
6611
6612            self.write_wrapped_functions(module, &ctx)?;
6613
6614            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6615                crate::ShaderStage::Vertex => (
6616                    "vertex",
6617                    LocationMode::VertexInput,
6618                    LocationMode::VertexOutput,
6619                    true,
6620                ),
6621                crate::ShaderStage::Fragment => (
6622                    "fragment",
6623                    LocationMode::FragmentInput,
6624                    LocationMode::FragmentOutput,
6625                    false,
6626                ),
6627                crate::ShaderStage::Compute => (
6628                    "kernel",
6629                    LocationMode::Uniform,
6630                    LocationMode::Uniform,
6631                    false,
6632                ),
6633                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6634            };
6635
6636            // Should this entry point be modified to do vertex pulling?
6637            let do_vertex_pulling = can_vertex_pull
6638                && pipeline_options.vertex_pulling_transform
6639                && !pipeline_options.vertex_buffer_mappings.is_empty();
6640
6641            // Is any global variable used by this entry point dynamically sized?
6642            let needs_buffer_sizes = do_vertex_pulling
6643                || module
6644                    .global_variables
6645                    .iter()
6646                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6647                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6648
6649            // skip this entry point if any global bindings are missing,
6650            // or their types are incompatible.
6651            if !options.fake_missing_bindings {
6652                for (var_handle, var) in module.global_variables.iter() {
6653                    if fun_info[var_handle].is_empty() {
6654                        continue;
6655                    }
6656                    match var.space {
6657                        crate::AddressSpace::Uniform
6658                        | crate::AddressSpace::Storage { .. }
6659                        | crate::AddressSpace::Handle => {
6660                            let br = match var.binding {
6661                                Some(ref br) => br,
6662                                None => {
6663                                    let var_name = var.name.clone().unwrap_or_default();
6664                                    ep_error =
6665                                        Some(super::EntryPointError::MissingBinding(var_name));
6666                                    break;
6667                                }
6668                            };
6669                            let target = options.get_resource_binding_target(ep, br);
6670                            let good = match target {
6671                                Some(target) => {
6672                                    // We intentionally don't dereference binding_arrays here,
6673                                    // so that binding arrays fall to the buffer location.
6674
6675                                    match module.types[var.ty].inner {
6676                                        crate::TypeInner::Image {
6677                                            class: crate::ImageClass::External,
6678                                            ..
6679                                        } => target.external_texture.is_some(),
6680                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6681                                        crate::TypeInner::Sampler { .. } => {
6682                                            target.sampler.is_some()
6683                                        }
6684                                        _ => target.buffer.is_some(),
6685                                    }
6686                                }
6687                                None => false,
6688                            };
6689                            if !good {
6690                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6691                                break;
6692                            }
6693                        }
6694                        crate::AddressSpace::PushConstant => {
6695                            if let Err(e) = options.resolve_push_constants(ep) {
6696                                ep_error = Some(e);
6697                                break;
6698                            }
6699                        }
6700                        crate::AddressSpace::TaskPayload => {
6701                            unimplemented!()
6702                        }
6703                        crate::AddressSpace::Function
6704                        | crate::AddressSpace::Private
6705                        | crate::AddressSpace::WorkGroup => {}
6706                    }
6707                }
6708                if needs_buffer_sizes {
6709                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6710                        ep_error = Some(err);
6711                    }
6712                }
6713            }
6714
6715            if let Some(err) = ep_error {
6716                info.entry_point_names.push(Err(err));
6717                continue;
6718            }
6719            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6720            info.entry_point_names.push(Ok(fun_name.clone()));
6721
6722            writeln!(self.out)?;
6723
6724            // Since `Namer.reset` wasn't expecting struct members to be
6725            // suddenly injected into another namespace like this,
6726            // `self.names` doesn't keep them distinct from other variables.
6727            // Generate fresh names for these arguments, and remember the
6728            // mapping.
6729            let mut flattened_member_names = FastHashMap::default();
6730            // Varyings' members get their own namespace
6731            let mut varyings_namer = proc::Namer::default();
6732
6733            // List all the Naga `EntryPoint`'s `Function`'s arguments,
6734            // flattening structs into their members. In Metal, we will pass
6735            // each of these values to the entry point as a separate argument—
6736            // except for the varyings, handled next.
6737            let mut flattened_arguments = Vec::new();
6738            for (arg_index, arg) in fun.arguments.iter().enumerate() {
6739                match module.types[arg.ty].inner {
6740                    crate::TypeInner::Struct { ref members, .. } => {
6741                        for (member_index, member) in members.iter().enumerate() {
6742                            let member_index = member_index as u32;
6743                            flattened_arguments.push((
6744                                NameKey::StructMember(arg.ty, member_index),
6745                                member.ty,
6746                                member.binding.as_ref(),
6747                            ));
6748                            let name_key = NameKey::StructMember(arg.ty, member_index);
6749                            let name = match member.binding {
6750                                Some(crate::Binding::Location { .. }) => {
6751                                    if do_vertex_pulling {
6752                                        self.namer.call(&self.names[&name_key])
6753                                    } else {
6754                                        varyings_namer.call(&self.names[&name_key])
6755                                    }
6756                                }
6757                                _ => self.namer.call(&self.names[&name_key]),
6758                            };
6759                            flattened_member_names.insert(name_key, name);
6760                        }
6761                    }
6762                    _ => flattened_arguments.push((
6763                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
6764                        arg.ty,
6765                        arg.binding.as_ref(),
6766                    )),
6767                }
6768            }
6769
6770            // Identify the varyings among the argument values, and maybe emit
6771            // a struct type named `<fun>Input` to hold them. If we are doing
6772            // vertex pulling, we instead update our attribute mapping to
6773            // note the types, names, and zero values of the attributes.
6774            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
6775            let varyings_member_name = self.namer.call("varyings");
6776            let mut has_varyings = false;
6777            if !flattened_arguments.is_empty() {
6778                if !do_vertex_pulling {
6779                    writeln!(self.out, "struct {stage_in_name} {{")?;
6780                }
6781                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6782                    let (binding, location) = match binding {
6783                        Some(ref binding @ &crate::Binding::Location { location, .. }) => {
6784                            (binding, location)
6785                        }
6786                        _ => continue,
6787                    };
6788                    let name = match *name_key {
6789                        NameKey::StructMember(..) => &flattened_member_names[name_key],
6790                        _ => &self.names[name_key],
6791                    };
6792                    let ty_name = TypeContext {
6793                        handle: ty,
6794                        gctx: module.to_ctx(),
6795                        names: &self.names,
6796                        access: crate::StorageAccess::empty(),
6797                        first_time: false,
6798                    };
6799                    let resolved = options.resolve_local_binding(binding, in_mode)?;
6800                    if do_vertex_pulling {
6801                        // Update our attribute mapping.
6802                        am_resolved.insert(
6803                            location,
6804                            AttributeMappingResolved {
6805                                ty_name: ty_name.to_string(),
6806                                dimension: ty_name.vertex_input_dimension(),
6807                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
6808                                name: name.to_string(),
6809                            },
6810                        );
6811                    } else {
6812                        has_varyings = true;
6813                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6814                        resolved.try_fmt(&mut self.out)?;
6815                        writeln!(self.out, ";")?;
6816                    }
6817                }
6818                if !do_vertex_pulling {
6819                    writeln!(self.out, "}};")?;
6820                }
6821            }
6822
6823            // Define a struct type named for the return value, if any, named
6824            // `<fun>Output`.
6825            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
6826            let result_member_name = self.namer.call("member");
6827            let result_type_name = match fun.result {
6828                Some(ref result) => {
6829                    let mut result_members = Vec::new();
6830                    if let crate::TypeInner::Struct { ref members, .. } =
6831                        module.types[result.ty].inner
6832                    {
6833                        for (member_index, member) in members.iter().enumerate() {
6834                            result_members.push((
6835                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
6836                                member.ty,
6837                                member.binding.as_ref(),
6838                            ));
6839                        }
6840                    } else {
6841                        result_members.push((
6842                            &result_member_name,
6843                            result.ty,
6844                            result.binding.as_ref(),
6845                        ));
6846                    }
6847
6848                    writeln!(self.out, "struct {stage_out_name} {{")?;
6849                    let mut has_point_size = false;
6850                    for (name, ty, binding) in result_members {
6851                        let ty_name = TypeContext {
6852                            handle: ty,
6853                            gctx: module.to_ctx(),
6854                            names: &self.names,
6855                            access: crate::StorageAccess::empty(),
6856                            first_time: true,
6857                        };
6858                        let binding = binding.ok_or_else(|| {
6859                            Error::GenericValidation("Expected binding, got None".into())
6860                        })?;
6861
6862                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
6863                            has_point_size = true;
6864                            if !pipeline_options.allow_and_force_point_size {
6865                                continue;
6866                            }
6867                        }
6868
6869                        let array_len = match module.types[ty].inner {
6870                            crate::TypeInner::Array {
6871                                size: crate::ArraySize::Constant(size),
6872                                ..
6873                            } => Some(size),
6874                            _ => None,
6875                        };
6876                        let resolved = options.resolve_local_binding(binding, out_mode)?;
6877                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6878                        if let Some(array_len) = array_len {
6879                            write!(self.out, " [{array_len}]")?;
6880                        }
6881                        resolved.try_fmt(&mut self.out)?;
6882                        writeln!(self.out, ";")?;
6883                    }
6884
6885                    if pipeline_options.allow_and_force_point_size
6886                        && ep.stage == crate::ShaderStage::Vertex
6887                        && !has_point_size
6888                    {
6889                        // inject the point size output last
6890                        writeln!(
6891                            self.out,
6892                            "{}float _point_size [[point_size]];",
6893                            back::INDENT
6894                        )?;
6895                    }
6896                    writeln!(self.out, "}};")?;
6897                    &stage_out_name
6898                }
6899                None => "void",
6900            };
6901
6902            // If we're doing a vertex pulling transform, define the buffer
6903            // structure types.
6904            if do_vertex_pulling {
6905                for vbm in &vbm_resolved {
6906                    let buffer_stride = vbm.stride;
6907                    let buffer_ty = &vbm.ty_name;
6908
6909                    // Define a structure of bytes of the appropriate size.
6910                    // When we access the attributes, we'll be unpacking these
6911                    // bytes at some offset.
6912                    writeln!(
6913                        self.out,
6914                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
6915                    )?;
6916                }
6917            }
6918
6919            // Write the entry point function's name, and begin its argument list.
6920            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
6921
6922            let mut is_first_argument = true;
6923            let mut separator = || {
6924                if is_first_argument {
6925                    is_first_argument = false;
6926                    ' '
6927                } else {
6928                    ','
6929                }
6930            };
6931
6932            // If we have produced a struct holding the `EntryPoint`'s
6933            // `Function`'s arguments' varyings, pass that struct first.
6934            if has_varyings {
6935                writeln!(
6936                    self.out,
6937                    "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
6938                    separator()
6939                )?;
6940            }
6941
6942            let mut local_invocation_id = None;
6943
6944            // Then pass the remaining arguments not included in the varyings
6945            // struct.
6946            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6947                let binding = match binding {
6948                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
6949                    _ => continue,
6950                };
6951                let name = match *name_key {
6952                    NameKey::StructMember(..) => &flattened_member_names[name_key],
6953                    _ => &self.names[name_key],
6954                };
6955
6956                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
6957                    local_invocation_id = Some(name_key);
6958                }
6959
6960                let ty_name = TypeContext {
6961                    handle: ty,
6962                    gctx: module.to_ctx(),
6963                    names: &self.names,
6964                    access: crate::StorageAccess::empty(),
6965                    first_time: false,
6966                };
6967
6968                match *binding {
6969                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
6970                        v_existing_id = Some(name.clone());
6971                    }
6972                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
6973                        i_existing_id = Some(name.clone());
6974                    }
6975                    _ => {}
6976                };
6977
6978                let resolved = options.resolve_local_binding(binding, in_mode)?;
6979                write!(self.out, "{} {ty_name} {name}", separator())?;
6980                resolved.try_fmt(&mut self.out)?;
6981                writeln!(self.out)?;
6982            }
6983
6984            let need_workgroup_variables_initialization =
6985                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
6986
6987            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
6988                writeln!(
6989                    self.out,
6990                    "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
6991                )?;
6992            }
6993
6994            // Those global variables used by this entry point and its callees
6995            // get passed as arguments. `Private` globals are an exception, they
6996            // don't outlive this invocation, so we declare them below as locals
6997            // within the entry point.
6998            for (handle, var) in module.global_variables.iter() {
6999                let usage = fun_info[handle];
7000                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7001                    continue;
7002                }
7003
7004                if options.lang_version < (1, 2) {
7005                    match var.space {
7006                        // This restriction is not documented in the MSL spec
7007                        // but validation will fail if it is not upheld.
7008                        //
7009                        // We infer the required version from the "Function
7010                        // Buffer Read-Writes" section of [what's new], where
7011                        // the feature sets listed correspond with the ones
7012                        // supporting MSL 1.2.
7013                        //
7014                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7015                        crate::AddressSpace::Storage { access }
7016                            if access.contains(crate::StorageAccess::STORE)
7017                                && ep.stage == crate::ShaderStage::Fragment =>
7018                        {
7019                            return Err(Error::UnsupportedWriteableStorageBuffer)
7020                        }
7021                        crate::AddressSpace::Handle => {
7022                            match module.types[var.ty].inner {
7023                                crate::TypeInner::Image {
7024                                    class: crate::ImageClass::Storage { access, .. },
7025                                    ..
7026                                } => {
7027                                    // This restriction is not documented in the MSL spec
7028                                    // but validation will fail if it is not upheld.
7029                                    //
7030                                    // We infer the required version from the "Function
7031                                    // Texture Read-Writes" section of [what's new], where
7032                                    // the feature sets listed correspond with the ones
7033                                    // supporting MSL 1.2.
7034                                    //
7035                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7036                                    if access.contains(crate::StorageAccess::STORE)
7037                                        && (ep.stage == crate::ShaderStage::Vertex
7038                                            || ep.stage == crate::ShaderStage::Fragment)
7039                                    {
7040                                        return Err(Error::UnsupportedWriteableStorageTexture(
7041                                            ep.stage,
7042                                        ));
7043                                    }
7044
7045                                    if access.contains(
7046                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7047                                    ) {
7048                                        return Err(Error::UnsupportedRWStorageTexture);
7049                                    }
7050                                }
7051                                _ => {}
7052                            }
7053                        }
7054                        _ => {}
7055                    }
7056                }
7057
7058                // Check min MSL version for binding arrays
7059                match var.space {
7060                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7061                        crate::TypeInner::BindingArray { base, .. } => {
7062                            match module.types[base].inner {
7063                                crate::TypeInner::Sampler { .. } => {
7064                                    if options.lang_version < (2, 0) {
7065                                        return Err(Error::UnsupportedArrayOf(
7066                                            "samplers".to_string(),
7067                                        ));
7068                                    }
7069                                }
7070                                crate::TypeInner::Image { class, .. } => match class {
7071                                    crate::ImageClass::Sampled { .. }
7072                                    | crate::ImageClass::Depth { .. }
7073                                    | crate::ImageClass::Storage {
7074                                        access: crate::StorageAccess::LOAD,
7075                                        ..
7076                                    } => {
7077                                        // Array of textures since:
7078                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7079                                        // - macOS: Metal 2
7080
7081                                        if options.lang_version < (2, 0) {
7082                                            return Err(Error::UnsupportedArrayOf(
7083                                                "textures".to_string(),
7084                                            ));
7085                                        }
7086                                    }
7087                                    crate::ImageClass::Storage {
7088                                        access: crate::StorageAccess::STORE,
7089                                        ..
7090                                    } => {
7091                                        // Array of write-only textures since:
7092                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7093                                        // - macOS: Metal 2
7094
7095                                        if options.lang_version < (2, 0) {
7096                                            return Err(Error::UnsupportedArrayOf(
7097                                                "write-only textures".to_string(),
7098                                            ));
7099                                        }
7100                                    }
7101                                    crate::ImageClass::Storage { .. } => {
7102                                        return Err(Error::UnsupportedArrayOf(
7103                                            "read-write textures".to_string(),
7104                                        ));
7105                                    }
7106                                    crate::ImageClass::External => {
7107                                        return Err(Error::UnsupportedArrayOf(
7108                                            "external textures".to_string(),
7109                                        ));
7110                                    }
7111                                },
7112                                _ => {
7113                                    return Err(Error::UnsupportedArrayOfType(base));
7114                                }
7115                            }
7116                        }
7117                        _ => {}
7118                    },
7119                    _ => {}
7120                }
7121
7122                // the resolves have already been checked for `!fake_missing_bindings` case
7123                let resolved = match var.space {
7124                    crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
7125                    crate::AddressSpace::WorkGroup => None,
7126                    _ => options
7127                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7128                        .ok(),
7129                };
7130                if let Some(ref resolved) = resolved {
7131                    // Inline samplers are defined in the EP body
7132                    if resolved.as_inline_sampler(options).is_some() {
7133                        continue;
7134                    }
7135                }
7136
7137                match module.types[var.ty].inner {
7138                    crate::TypeInner::Image {
7139                        class: crate::ImageClass::External,
7140                        ..
7141                    } => {
7142                        // External texture global variables get lowered to 3 textures
7143                        // and a constant buffer. We must emit a separate argument for
7144                        // each of these.
7145                        let target = match resolved {
7146                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7147                                target.external_texture
7148                            }
7149                            _ => None,
7150                        };
7151
7152                        for i in 0..3 {
7153                            write!(self.out, "{} ", separator())?;
7154
7155                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7156                                handle,
7157                                ExternalTextureNameKey::Plane(i),
7158                            )];
7159                            write!(
7160                              self.out,
7161                              "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7162                            )?;
7163                            if let Some(ref target) = target {
7164                                write!(self.out, " [[texture({})]]", target.planes[i])?;
7165                            }
7166                            writeln!(self.out)?;
7167                        }
7168                        let params_ty_name = &self.names
7169                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7170                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7171                            handle,
7172                            ExternalTextureNameKey::Params,
7173                        )];
7174                        write!(self.out, "{} ", separator())?;
7175                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
7176                        if let Some(ref target) = target {
7177                            write!(self.out, " [[buffer({})]]", target.params)?;
7178                        }
7179                    }
7180                    _ => {
7181                        let tyvar = TypedGlobalVariable {
7182                            module,
7183                            names: &self.names,
7184                            handle,
7185                            usage,
7186                            reference: true,
7187                        };
7188                        write!(self.out, "{} ", separator())?;
7189                        tyvar.try_fmt(&mut self.out)?;
7190                        if let Some(resolved) = resolved {
7191                            resolved.try_fmt(&mut self.out)?;
7192                        }
7193                        if let Some(value) = var.init {
7194                            write!(self.out, " = ")?;
7195                            self.put_const_expression(
7196                                value,
7197                                module,
7198                                mod_info,
7199                                &module.global_expressions,
7200                            )?;
7201                        }
7202                    }
7203                }
7204                writeln!(self.out)?;
7205            }
7206
7207            if do_vertex_pulling {
7208                if needs_vertex_id && v_existing_id.is_none() {
7209                    // Write the [[vertex_id]] argument.
7210                    writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7211                }
7212
7213                if needs_instance_id && i_existing_id.is_none() {
7214                    writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7215                }
7216
7217                // Iterate vbm_resolved, output one argument for every vertex buffer,
7218                // using the names we generated earlier.
7219                for vbm in &vbm_resolved {
7220                    let id = &vbm.id;
7221                    let ty_name = &vbm.ty_name;
7222                    let param_name = &vbm.param_name;
7223                    writeln!(
7224                        self.out,
7225                        "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7226                        separator()
7227                    )?;
7228                }
7229            }
7230
7231            // If this entry uses any variable-length arrays, their sizes are
7232            // passed as a final struct-typed argument.
7233            if needs_buffer_sizes {
7234                // this is checked earlier
7235                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7236                write!(
7237                    self.out,
7238                    "{} constant _mslBufferSizes& _buffer_sizes",
7239                    separator()
7240                )?;
7241                resolved.try_fmt(&mut self.out)?;
7242                writeln!(self.out)?;
7243            }
7244
7245            // end of the entry point argument list
7246            writeln!(self.out, ") {{")?;
7247
7248            // Starting the function body.
7249            if do_vertex_pulling {
7250                // Provide zero values for all the attributes, which we will overwrite with
7251                // real data from the vertex attribute buffers, if the indices are in-bounds.
7252                for vbm in &vbm_resolved {
7253                    for attribute in vbm.attributes {
7254                        let location = attribute.shader_location;
7255                        let am_option = am_resolved.get(&location);
7256                        if am_option.is_none() {
7257                            // This bound attribute isn't used in this entry point, so
7258                            // don't bother zero-initializing it.
7259                            continue;
7260                        }
7261                        let am = am_option.unwrap();
7262                        let attribute_ty_name = &am.ty_name;
7263                        let attribute_name = &am.name;
7264
7265                        writeln!(
7266                            self.out,
7267                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7268                            back::Level(1)
7269                        )?;
7270                    }
7271
7272                    // Output a bounds check block that will set real values for the
7273                    // attributes, if the bounds are satisfied.
7274                    write!(self.out, "{}if (", back::Level(1))?;
7275
7276                    let idx = &vbm.id;
7277                    let stride = &vbm.stride;
7278                    let index_name = match vbm.step_mode {
7279                        back::msl::VertexBufferStepMode::Constant => "0",
7280                        back::msl::VertexBufferStepMode::ByVertex => {
7281                            if let Some(ref name) = v_existing_id {
7282                                name
7283                            } else {
7284                                &v_id
7285                            }
7286                        }
7287                        back::msl::VertexBufferStepMode::ByInstance => {
7288                            if let Some(ref name) = i_existing_id {
7289                                name
7290                            } else {
7291                                &i_id
7292                            }
7293                        }
7294                    };
7295                    write!(
7296                        self.out,
7297                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7298                    )?;
7299
7300                    writeln!(self.out, ") {{")?;
7301
7302                    // Pull the bytes out of the vertex buffer.
7303                    let ty_name = &vbm.ty_name;
7304                    let elem_name = &vbm.elem_name;
7305                    let param_name = &vbm.param_name;
7306
7307                    writeln!(
7308                        self.out,
7309                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7310                        back::Level(2),
7311                    )?;
7312
7313                    // Now set real values for each of the attributes, by unpacking the data
7314                    // from the buffer elements.
7315                    for attribute in vbm.attributes {
7316                        let location = attribute.shader_location;
7317                        let Some(am) = am_resolved.get(&location) else {
7318                            // This bound attribute isn't used in this entry point, so
7319                            // don't bother extracting the data. Too bad we emitted the
7320                            // unpacking function earlier -- it might not get used.
7321                            continue;
7322                        };
7323                        let attribute_name = &am.name;
7324                        let attribute_ty_name = &am.ty_name;
7325
7326                        let offset = attribute.offset;
7327                        let func = unpacking_functions
7328                            .get(&attribute.format)
7329                            .expect("Should have generated this unpacking function earlier.");
7330                        let func_name = &func.name;
7331
7332                        // Check dimensionality of the attribute compared to the unpacking
7333                        // function. If attribute dimension > unpack dimension, we have to
7334                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7335                        // scalar type. Otherwise, if attribute dimension is < unpack
7336                        // dimension, then we need to explicitly truncate the result.
7337
7338                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7339
7340                        if needs_padding_or_truncation != Ordering::Equal {
7341                            // Emit a comment flagging that a conversion is happening,
7342                            // since the actual logic can be at the end of a long line.
7343                            writeln!(
7344                                self.out,
7345                                "{}// {attribute_ty_name} <- {:?}",
7346                                back::Level(2),
7347                                attribute.format
7348                            )?;
7349                        }
7350
7351                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7352
7353                        if needs_padding_or_truncation == Ordering::Greater {
7354                            // Needs padding: emit constructor call for wider type
7355                            write!(self.out, "{attribute_ty_name}(")?;
7356                        }
7357
7358                        // Emit call to unpacking function
7359                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7360                        for i in (offset + 1)..(offset + func.byte_count) {
7361                            write!(self.out, ", {elem_name}.data[{i}]")?;
7362                        }
7363                        write!(self.out, ")")?;
7364
7365                        match needs_padding_or_truncation {
7366                            Ordering::Greater => {
7367                                // Padding
7368                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7369                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
7370                                for i in func.dimension..am.dimension {
7371                                    write!(
7372                                        self.out,
7373                                        ", {}",
7374                                        if i == 3 { one_value } else { zero_value }
7375                                    )?;
7376                                }
7377                                write!(self.out, ")")?;
7378                            }
7379                            Ordering::Less => {
7380                                // Truncate to the first `am.dimension` components
7381                                write!(
7382                                    self.out,
7383                                    ".{}",
7384                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7385                                )?;
7386                            }
7387                            Ordering::Equal => {}
7388                        }
7389
7390                        writeln!(self.out, ";")?;
7391                    }
7392
7393                    // End the bounds check / attribute setting block.
7394                    writeln!(self.out, "{}}}", back::Level(1))?;
7395                }
7396            }
7397
7398            if need_workgroup_variables_initialization {
7399                self.write_workgroup_variables_initialization(
7400                    module,
7401                    mod_info,
7402                    fun_info,
7403                    local_invocation_id,
7404                )?;
7405            }
7406
7407            // Metal doesn't support private mutable variables outside of functions,
7408            // so we put them here, just like the locals.
7409            for (handle, var) in module.global_variables.iter() {
7410                let usage = fun_info[handle];
7411                if usage.is_empty() {
7412                    continue;
7413                }
7414                if var.space == crate::AddressSpace::Private {
7415                    let tyvar = TypedGlobalVariable {
7416                        module,
7417                        names: &self.names,
7418                        handle,
7419                        usage,
7420
7421                        reference: false,
7422                    };
7423                    write!(self.out, "{}", back::INDENT)?;
7424                    tyvar.try_fmt(&mut self.out)?;
7425                    match var.init {
7426                        Some(value) => {
7427                            write!(self.out, " = ")?;
7428                            self.put_const_expression(
7429                                value,
7430                                module,
7431                                mod_info,
7432                                &module.global_expressions,
7433                            )?;
7434                            writeln!(self.out, ";")?;
7435                        }
7436                        None => {
7437                            writeln!(self.out, " = {{}};")?;
7438                        }
7439                    };
7440                } else if let Some(ref binding) = var.binding {
7441                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7442                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7443                        // write an inline sampler
7444                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7445                        writeln!(
7446                            self.out,
7447                            "{}constexpr {}::sampler {}(",
7448                            back::INDENT,
7449                            NAMESPACE,
7450                            name
7451                        )?;
7452                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7453                        writeln!(self.out, "{});", back::INDENT)?;
7454                    } else if let crate::TypeInner::Image {
7455                        class: crate::ImageClass::External,
7456                        ..
7457                    } = module.types[var.ty].inner
7458                    {
7459                        // Wrap the individual arguments for each external texture global
7460                        // in a struct which can be easily passed around.
7461                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7462                        let l1 = back::Level(1);
7463                        let l2 = l1.next();
7464                        writeln!(
7465                            self.out,
7466                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7467                        )?;
7468                        for i in 0..3 {
7469                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7470                                handle,
7471                                ExternalTextureNameKey::Plane(i),
7472                            )];
7473                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7474                        }
7475                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7476                            handle,
7477                            ExternalTextureNameKey::Params,
7478                        )];
7479                        writeln!(self.out, "{l2}.params = {params_name},")?;
7480                        writeln!(self.out, "{l1}}};")?;
7481                    }
7482                }
7483            }
7484
7485            // Now take the arguments that we gathered into structs, and the
7486            // structs that we flattened into arguments, and emit local
7487            // variables with initializers that put everything back the way the
7488            // body code expects.
7489            //
7490            // If we had to generate fresh names for struct members passed as
7491            // arguments, be sure to use those names when rebuilding the struct.
7492            //
7493            // "Each day, I change some zeros to ones, and some ones to zeros.
7494            // The rest, I leave alone."
7495            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7496                let arg_name =
7497                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7498                match module.types[arg.ty].inner {
7499                    crate::TypeInner::Struct { ref members, .. } => {
7500                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7501                        write!(
7502                            self.out,
7503                            "{}const {} {} = {{ ",
7504                            back::INDENT,
7505                            struct_name,
7506                            arg_name
7507                        )?;
7508                        for (member_index, member) in members.iter().enumerate() {
7509                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7510                            let name = &flattened_member_names[&key];
7511                            if member_index != 0 {
7512                                write!(self.out, ", ")?;
7513                            }
7514                            // insert padding initialization, if needed
7515                            if self
7516                                .struct_member_pads
7517                                .contains(&(arg.ty, member_index as u32))
7518                            {
7519                                write!(self.out, "{{}}, ")?;
7520                            }
7521                            if let Some(crate::Binding::Location { .. }) = member.binding {
7522                                if has_varyings {
7523                                    write!(self.out, "{varyings_member_name}.")?;
7524                                }
7525                            }
7526                            write!(self.out, "{name}")?;
7527                        }
7528                        writeln!(self.out, " }};")?;
7529                    }
7530                    _ => {
7531                        if let Some(crate::Binding::Location { .. }) = arg.binding {
7532                            if has_varyings {
7533                                writeln!(
7534                                    self.out,
7535                                    "{}const auto {} = {}.{};",
7536                                    back::INDENT,
7537                                    arg_name,
7538                                    varyings_member_name,
7539                                    arg_name
7540                                )?;
7541                            }
7542                        }
7543                    }
7544                }
7545            }
7546
7547            let guarded_indices =
7548                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7549
7550            let context = StatementContext {
7551                expression: ExpressionContext {
7552                    function: fun,
7553                    origin: FunctionOrigin::EntryPoint(ep_index as _),
7554                    info: fun_info,
7555                    lang_version: options.lang_version,
7556                    policies: options.bounds_check_policies,
7557                    guarded_indices,
7558                    module,
7559                    mod_info,
7560                    pipeline_options,
7561                    force_loop_bounding: options.force_loop_bounding,
7562                },
7563                result_struct: Some(&stage_out_name),
7564            };
7565
7566            // Finally, declare all the local variables that we need
7567            //TODO: we can postpone this till the relevant expressions are emitted
7568            self.put_locals(&context.expression)?;
7569            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7570            self.put_block(back::Level(1), &fun.body, &context)?;
7571            writeln!(self.out, "}}")?;
7572            if ep_index + 1 != module.entry_points.len() {
7573                writeln!(self.out)?;
7574            }
7575            self.named_expressions.clear();
7576        }
7577
7578        Ok(info)
7579    }
7580
7581    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7582        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
7583        // so we try to avoid it here.
7584        if flags.is_empty() {
7585            writeln!(
7586                self.out,
7587                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7588            )?;
7589        }
7590        if flags.contains(crate::Barrier::STORAGE) {
7591            writeln!(
7592                self.out,
7593                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7594            )?;
7595        }
7596        if flags.contains(crate::Barrier::WORK_GROUP) {
7597            writeln!(
7598                self.out,
7599                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7600            )?;
7601        }
7602        if flags.contains(crate::Barrier::SUB_GROUP) {
7603            writeln!(
7604                self.out,
7605                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7606            )?;
7607        }
7608        if flags.contains(crate::Barrier::TEXTURE) {
7609            writeln!(
7610                self.out,
7611                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7612            )?;
7613        }
7614        Ok(())
7615    }
7616}
7617
7618/// Initializing workgroup variables is more tricky for Metal because we have to deal
7619/// with atomics at the type-level (which don't have a copy constructor).
7620mod workgroup_mem_init {
7621    use crate::EntryPoint;
7622
7623    use super::*;
7624
7625    enum Access {
7626        GlobalVariable(Handle<crate::GlobalVariable>),
7627        StructMember(Handle<crate::Type>, u32),
7628        Array(usize),
7629    }
7630
7631    impl Access {
7632        fn write<W: Write>(
7633            &self,
7634            writer: &mut W,
7635            names: &FastHashMap<NameKey, String>,
7636        ) -> Result<(), core::fmt::Error> {
7637            match *self {
7638                Access::GlobalVariable(handle) => {
7639                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7640                }
7641                Access::StructMember(handle, index) => {
7642                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7643                }
7644                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7645            }
7646        }
7647    }
7648
7649    struct AccessStack {
7650        stack: Vec<Access>,
7651        array_depth: usize,
7652    }
7653
7654    impl AccessStack {
7655        const fn new() -> Self {
7656            Self {
7657                stack: Vec::new(),
7658                array_depth: 0,
7659            }
7660        }
7661
7662        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7663            let array_depth = self.array_depth;
7664            self.stack.push(Access::Array(array_depth));
7665            self.array_depth += 1;
7666            let res = cb(self, array_depth);
7667            self.stack.pop();
7668            self.array_depth -= 1;
7669            res
7670        }
7671
7672        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7673            self.stack.push(new);
7674            let res = cb(self);
7675            self.stack.pop();
7676            res
7677        }
7678
7679        fn write<W: Write>(
7680            &self,
7681            writer: &mut W,
7682            names: &FastHashMap<NameKey, String>,
7683        ) -> Result<(), core::fmt::Error> {
7684            for next in self.stack.iter() {
7685                next.write(writer, names)?;
7686            }
7687            Ok(())
7688        }
7689    }
7690
7691    impl<W: Write> Writer<W> {
7692        pub(super) fn need_workgroup_variables_initialization(
7693            &mut self,
7694            options: &Options,
7695            ep: &EntryPoint,
7696            module: &crate::Module,
7697            fun_info: &valid::FunctionInfo,
7698        ) -> bool {
7699            options.zero_initialize_workgroup_memory
7700                && ep.stage.compute_like()
7701                && module.global_variables.iter().any(|(handle, var)| {
7702                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7703                })
7704        }
7705
7706        pub(super) fn write_workgroup_variables_initialization(
7707            &mut self,
7708            module: &crate::Module,
7709            module_info: &valid::ModuleInfo,
7710            fun_info: &valid::FunctionInfo,
7711            local_invocation_id: Option<&NameKey>,
7712        ) -> BackendResult {
7713            let level = back::Level(1);
7714
7715            writeln!(
7716                self.out,
7717                "{}if ({}::all({} == {}::uint3(0u))) {{",
7718                level,
7719                NAMESPACE,
7720                local_invocation_id
7721                    .map(|name_key| self.names[name_key].as_str())
7722                    .unwrap_or("__local_invocation_id"),
7723                NAMESPACE,
7724            )?;
7725
7726            let mut access_stack = AccessStack::new();
7727
7728            let vars = module.global_variables.iter().filter(|&(handle, var)| {
7729                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7730            });
7731
7732            for (handle, var) in vars {
7733                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7734                    self.write_workgroup_variable_initialization(
7735                        module,
7736                        module_info,
7737                        var.ty,
7738                        access_stack,
7739                        level.next(),
7740                    )
7741                })?;
7742            }
7743
7744            writeln!(self.out, "{level}}}")?;
7745            self.write_barrier(crate::Barrier::WORK_GROUP, level)
7746        }
7747
7748        fn write_workgroup_variable_initialization(
7749            &mut self,
7750            module: &crate::Module,
7751            module_info: &valid::ModuleInfo,
7752            ty: Handle<crate::Type>,
7753            access_stack: &mut AccessStack,
7754            level: back::Level,
7755        ) -> BackendResult {
7756            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7757                write!(self.out, "{level}")?;
7758                access_stack.write(&mut self.out, &self.names)?;
7759                writeln!(self.out, " = {{}};")?;
7760            } else {
7761                match module.types[ty].inner {
7762                    crate::TypeInner::Atomic { .. } => {
7763                        write!(
7764                            self.out,
7765                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7766                        )?;
7767                        access_stack.write(&mut self.out, &self.names)?;
7768                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7769                    }
7770                    crate::TypeInner::Array { base, size, .. } => {
7771                        let count = match size.resolve(module.to_ctx())? {
7772                            proc::IndexableLength::Known(count) => count,
7773                            proc::IndexableLength::Dynamic => unreachable!(),
7774                        };
7775
7776                        access_stack.enter_array(|access_stack, array_depth| {
7777                            writeln!(
7778                                self.out,
7779                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7780                            )?;
7781                            self.write_workgroup_variable_initialization(
7782                                module,
7783                                module_info,
7784                                base,
7785                                access_stack,
7786                                level.next(),
7787                            )?;
7788                            writeln!(self.out, "{level}}}")?;
7789                            BackendResult::Ok(())
7790                        })?;
7791                    }
7792                    crate::TypeInner::Struct { ref members, .. } => {
7793                        for (index, member) in members.iter().enumerate() {
7794                            access_stack.enter(
7795                                Access::StructMember(ty, index as u32),
7796                                |access_stack| {
7797                                    self.write_workgroup_variable_initialization(
7798                                        module,
7799                                        module_info,
7800                                        member.ty,
7801                                        access_stack,
7802                                        level,
7803                                    )
7804                                },
7805                            )?;
7806                        }
7807                    }
7808                    _ => unreachable!(),
7809                }
7810            }
7811
7812            Ok(())
7813        }
7814    }
7815}
7816
7817impl crate::AtomicFunction {
7818    const fn to_msl(self) -> &'static str {
7819        match self {
7820            Self::Add => "fetch_add",
7821            Self::Subtract => "fetch_sub",
7822            Self::And => "fetch_and",
7823            Self::InclusiveOr => "fetch_or",
7824            Self::ExclusiveOr => "fetch_xor",
7825            Self::Min => "fetch_min",
7826            Self::Max => "fetch_max",
7827            Self::Exchange { compare: None } => "exchange",
7828            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7829        }
7830    }
7831
7832    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7833        Ok(match self {
7834            Self::Min => "min",
7835            Self::Max => "max",
7836            _ => Err(Error::FeatureNotImplemented(
7837                "64-bit atomic operation other than min/max".to_string(),
7838            ))?,
7839        })
7840    }
7841}