naga/proc/
constant_evaluator.rs

// Code in this file intentionally uses `for` loops and `.push()` rather than
// `ArrayVec::from_iter`, because the latter is monomorphized by all three of
// the item type, the capacity, and the iterator type, which can easily bloat
// the compiled executable (removing it saved ~260 KiB).
5
6use alloc::{
7    format,
8    string::{String, ToString},
9    vec,
10    vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19    arena::{Arena, Handle, HandleVec, UniqueArena},
20    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21    ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Lets one macro generate another `macro_rules!` item that declares its own
/// metavariables.
///
/// The input is used verbatim as the rules of a throwaway helper macro, which
/// is then invoked with a bare `$` token. Callers typically pass a single arm
/// `($d:tt) => { ... }`; inside that arm `$d` stands for `$`.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ( $( $tokens:tt )* ) => {
        // Step 1: the caller's tokens become the helper macro's rules.
        macro_rules! __with_dollar_sign { $( $tokens )* }
        // Step 2: call the helper with `$`, binding the caller's `$d:tt` to it.
        __with_dollar_sign!($);
    };
}
38
/// Generates a component-wise "extractor" for a chosen family of [`Literal`] variants.
///
/// Each invocation emits:
/// - an enum `$target<N>` with one variant per entry of `literals`, each holding `[$ty; N]`;
/// - `impl From<$target<1>> for Expression`, turning a one-element result back into a
///   literal expression;
/// - a function `$ident` that gathers `N` argument expressions into a `$target<N>`, passes
///   it to `handler`, and registers the result. For vector arguments it recurses once per
///   component and registers a vector of the results;
/// - a same-named helper macro `$ident!` that applies one closure-like body to every variant.
///
/// `scalar_kinds` lists the [`ScalarKind`]s whose vector `Compose` arguments are accepted.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A one-element result converts straight back into a scalar literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            // The first argument is unwrapped unconditionally below.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs `eval_zero_value_and_splat` on an argument, then borrows the
            // resulting expression from the arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides between the scalar and vector paths.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar path: every argument must be a literal of this same variant.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        // Exactly N values were pushed, so the array is full.
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector path: only vectors whose scalar kind is listed in `scalar_kinds`.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            // Every later argument must be a `Compose` of the same vector type.
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, recurse on the `idx`-th component of
                            // every argument; a missing component is an `InvalidMathArg` error.
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            // Reassemble the per-component results with the first argument's type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // `$d` below stands for a literal `$`, so the generated helper macro can
        // declare its own metavariables. The `$eval`, `$span` tokens are not
        // metavariables of this outer macro and are emitted as-is.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
219gen_component_wise_extractor! {
220    component_wise_scalar -> Scalar,
221    literals: [
222        AbstractFloat => AbstractFloat: f64,
223        F32 => F32: f32,
224        F16 => F16: f16,
225        AbstractInt => AbstractInt: i64,
226        U32 => U32: u32,
227        I32 => I32: i32,
228        U64 => U64: u64,
229        I64 => I64: i64,
230    ],
231    scalar_kinds: [
232        Float,
233        AbstractFloat,
234        Sint,
235        Uint,
236        AbstractInt,
237    ],
238}
239
240gen_component_wise_extractor! {
241    component_wise_float -> Float,
242    literals: [
243        AbstractFloat => Abstract: f64,
244        F32 => F32: f32,
245        F16 => F16: f16,
246    ],
247    scalar_kinds: [
248        Float,
249        AbstractFloat,
250    ],
251}
252
253gen_component_wise_extractor! {
254    component_wise_concrete_int -> ConcreteInt,
255    literals: [
256        U32 => U32: u32,
257        I32 => I32: i32,
258    ],
259    scalar_kinds: [
260        Sint,
261        Uint,
262    ],
263}
264
265gen_component_wise_extractor! {
266    component_wise_signed -> Signed,
267    literals: [
268        AbstractFloat => AbstractFloat: f64,
269        AbstractInt => AbstractInt: i64,
270        F32 => F32: f32,
271        F16 => F16: f16,
272        I32 => I32: i32,
273    ],
274    scalar_kinds: [
275        Sint,
276        AbstractInt,
277        Float,
278        AbstractFloat,
279    ],
280}
281
282/// Vectors with a concrete element type.
283#[derive(Debug)]
284enum LiteralVector {
285    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
286    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
287    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
288    U16(ArrayVec<u16, { crate::VectorSize::MAX }>),
289    I16(ArrayVec<i16, { crate::VectorSize::MAX }>),
290    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
291    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
292    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
293    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
294    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
295    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
296    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
297}
298
299impl LiteralVector {
300    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
301    fn len(&self) -> usize {
302        match *self {
303            LiteralVector::F64(ref v) => v.len(),
304            LiteralVector::F32(ref v) => v.len(),
305            LiteralVector::F16(ref v) => v.len(),
306            LiteralVector::U16(ref v) => v.len(),
307            LiteralVector::I16(ref v) => v.len(),
308            LiteralVector::U32(ref v) => v.len(),
309            LiteralVector::I32(ref v) => v.len(),
310            LiteralVector::U64(ref v) => v.len(),
311            LiteralVector::I64(ref v) => v.len(),
312            LiteralVector::Bool(ref v) => v.len(),
313            LiteralVector::AbstractInt(ref v) => v.len(),
314            LiteralVector::AbstractFloat(ref v) => v.len(),
315        }
316    }
317
318    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
319    fn from_literal(literal: Literal) -> Self {
320        fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
321            let mut v = ArrayVec::new();
322            v.push(val);
323            v
324        }
325        match literal {
326            Literal::F64(e) => Self::F64(arrayvec_of(e)),
327            Literal::F32(e) => Self::F32(arrayvec_of(e)),
328            Literal::U16(e) => Self::U16(arrayvec_of(e)),
329            Literal::I16(e) => Self::I16(arrayvec_of(e)),
330            Literal::U32(e) => Self::U32(arrayvec_of(e)),
331            Literal::I32(e) => Self::I32(arrayvec_of(e)),
332            Literal::U64(e) => Self::U64(arrayvec_of(e)),
333            Literal::I64(e) => Self::I64(arrayvec_of(e)),
334            Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
335            Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
336            Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
337            Literal::F16(e) => Self::F16(arrayvec_of(e)),
338        }
339    }
340
341    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
342    /// Returns error if components types do not match.
343    /// # Panics
344    /// Panics if vector is empty
345    fn from_literal_vec(
346        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
347    ) -> Result<Self, ConstantEvaluatorError> {
348        assert!(!components.is_empty());
349        // TODO: should a vector of i32 be constructible from abstract int?
350        macro_rules! compose_literals {
351            ($components:expr, $variant:ident, $self_variant:ident) => {{
352                let mut out = ArrayVec::new();
353                for l in &$components {
354                    match l {
355                        &Literal::$variant(v) => out.push(v),
356                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
357                    }
358                }
359                Self::$self_variant(out)
360            }};
361        }
362        Ok(match components[0] {
363            Literal::I16(_) => compose_literals!(components, I16, I16),
364            Literal::U16(_) => compose_literals!(components, U16, U16),
365            Literal::I32(_) => compose_literals!(components, I32, I32),
366            Literal::U32(_) => compose_literals!(components, U32, U32),
367            Literal::I64(_) => compose_literals!(components, I64, I64),
368            Literal::U64(_) => compose_literals!(components, U64, U64),
369            Literal::F32(_) => compose_literals!(components, F32, F32),
370            Literal::F64(_) => compose_literals!(components, F64, F64),
371            Literal::Bool(_) => compose_literals!(components, Bool, Bool),
372            Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
373            Literal::AbstractFloat(_) => {
374                compose_literals!(components, AbstractFloat, AbstractFloat)
375            }
376            Literal::F16(_) => compose_literals!(components, F16, F16),
377        })
378    }
379
380    #[allow(dead_code)]
381    /// Returns [`ArrayVec`] of [`Literal`]s
382    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
383        macro_rules! decompose_literals {
384            ($v:expr, $variant:ident) => {{
385                let mut out = ArrayVec::new();
386                for e in $v {
387                    out.push(Literal::$variant(*e));
388                }
389                out
390            }};
391        }
392        match *self {
393            LiteralVector::F64(ref v) => decompose_literals!(v, F64),
394            LiteralVector::F32(ref v) => decompose_literals!(v, F32),
395            LiteralVector::F16(ref v) => decompose_literals!(v, F16),
396            LiteralVector::U16(ref v) => decompose_literals!(v, U16),
397            LiteralVector::I16(ref v) => decompose_literals!(v, I16),
398            LiteralVector::U32(ref v) => decompose_literals!(v, U32),
399            LiteralVector::I32(ref v) => decompose_literals!(v, I32),
400            LiteralVector::U64(ref v) => decompose_literals!(v, U64),
401            LiteralVector::I64(ref v) => decompose_literals!(v, I64),
402            LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
403            LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
404            LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
405        }
406    }
407
408    #[allow(dead_code)]
409    /// Puts self into eval's expressions arena and returns handle to it
410    fn register_as_evaluated_expr(
411        &self,
412        eval: &mut ConstantEvaluator<'_>,
413        span: Span,
414    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
415        let lit_vec = self.to_literal_vec();
416        assert!(!lit_vec.is_empty());
417        let expr = if lit_vec.len() == 1 {
418            Expression::Literal(lit_vec[0])
419        } else {
420            Expression::Compose {
421                ty: eval.types.insert(
422                    Type {
423                        name: None,
424                        inner: TypeInner::Vector {
425                            size: match lit_vec.len() {
426                                2 => crate::VectorSize::Bi,
427                                3 => crate::VectorSize::Tri,
428                                4 => crate::VectorSize::Quad,
429                                _ => unreachable!(),
430                            },
431                            scalar: lit_vec[0].scalar(),
432                        },
433                    },
434                    Span::UNDEFINED,
435                ),
436                components: lit_vec
437                    .iter()
438                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
439                    .collect::<Result<_, _>>()?,
440            }
441        };
442        eval.register_evaluated_expr(expr, span)
443    }
444}
445
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// The expansion is a `match` that evaluates to `Result<$out, ConstantEvaluatorError>`:
/// the first arm whose variant matches yields `Ok($out::Variant(body))`, where `Variant`
/// is the arm's own type or its `-> Variant` override. Any other input yields
/// `Err(ConstantEvaluatorError::InvalidMathArg)`. When matching a tuple, every element
/// must be the arm's variant. Variants that are not listed fall into the error case;
/// this includes `U16`, `I16` and `Bool`, which neither shorthand covers.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Public entry point: split the user's arms into a list of type names and a
    // parallel list of `{ vars ; ret? ; body }` groups.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start the muncher: nothing processed yet (left of `<>`), every arm still
    // pending (right of `<>`).
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next type name off the list into head position, so the rules
    // below can inspect it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` shorthand: replace it with the five integer type names and
    // duplicate the current arm once per name.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` shorthand: same as `Integer`, with the four float type names.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // A concrete type name with more names left: tag the current arm with it,
    // move the arm to the processed list, and continue with the next name.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // The last concrete type name: tag the final arm and emit the `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit one match arm per tagged arm; anything else is `InvalidMathArg`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Output variant: the arm's own type when no `-> $ret` override was given...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...otherwise the override.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
621
622fn float_length<F>(e: &[F]) -> Option<F>
623where
624    F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
625{
626    if e.len() == 1 {
627        // Avoids possible overflow in squaring
628        Some(e[0].abs())
629    } else {
630        let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
631        result.is_finite().then_some(result)
632    }
633}
634
635#[derive(Debug)]
636enum Behavior<'a> {
637    Wgsl(WgslRestrictions<'a>),
638    Glsl(GlslRestrictions<'a>),
639}
640
641impl Behavior<'_> {
642    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
643    const fn has_runtime_restrictions(&self) -> bool {
644        matches!(
645            self,
646            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
647                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
648        )
649    }
650}
651
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow, and at which
    /// restriction level (see [`Behavior`]).
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`].
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// The module's type layouter.
    ///
    /// NOTE(review): presumably consulted when an evaluation depends on type
    /// sizes or alignments — confirm against its uses.
    layouter: &'a mut crate::proc::Layouter,
}
696
/// The restriction level under which WGSL expressions are evaluated.
///
/// Each variant lists how const-, override-, and runtime-expressions are handled
/// at that level.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// NOTE(review): the optional [`FunctionLocalData`] is presumably `Some` when
    /// the const-expression appears inside a function body — confirm.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
709
/// The restriction level under which GLSL expressions are evaluated.
///
/// Each variant lists how const-, override-, and runtime-expressions are handled
/// at that level.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
719
/// Function-local state carried by the `Runtime` restriction levels (and
/// optionally by WGSL's `Const` level).
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably tracks the range of newly appended expressions
    /// that still need an `Emit` statement — confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the function block receiving those statements —
    /// confirm.
    block: &'a mut crate::Block,
}
727
/// How constant an expression is.
///
/// Variants are declared from most to least constant, so the derived `Ord` gives
/// `Const < Override < Runtime`. [`ExpressionKindTracker`] relies on this when it
/// combines the kinds of an expression's operands with `max`; do not reorder.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression.
    Const,
    /// Depends on an override, but on nothing less constant.
    Override,
    /// Neither of the above.
    Runtime,
}
734
735#[derive(Debug)]
736pub struct ExpressionKindTracker {
737    inner: HandleVec<Expression, ExpressionKind>,
738}
739
740impl ExpressionKindTracker {
741    pub const fn new() -> Self {
742        Self {
743            inner: HandleVec::new(),
744        }
745    }
746
747    /// Forces the the expression to not be const
748    pub fn force_non_const(&mut self, value: Handle<Expression>) {
749        self.inner[value] = ExpressionKind::Runtime;
750    }
751
752    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
753        self.inner.insert(value, expr_type);
754    }
755
756    pub fn is_const(&self, h: Handle<Expression>) -> bool {
757        matches!(self.type_of(h), ExpressionKind::Const)
758    }
759
760    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
761        matches!(
762            self.type_of(h),
763            ExpressionKind::Const | ExpressionKind::Override
764        )
765    }
766
767    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
768        self.inner[value]
769    }
770
771    pub fn from_arena(arena: &Arena<Expression>) -> Self {
772        let mut tracker = Self {
773            inner: HandleVec::with_capacity(arena.len()),
774        };
775        for (handle, expr) in arena.iter() {
776            tracker
777                .inner
778                .insert(handle, tracker.type_of_with_expr(expr));
779        }
780        tracker
781    }
782
783    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
784        match *expr {
785            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
786                ExpressionKind::Const
787            }
788            Expression::Override(_) => ExpressionKind::Override,
789            Expression::Compose { ref components, .. } => {
790                let mut expr_type = ExpressionKind::Const;
791                for component in components {
792                    expr_type = expr_type.max(self.type_of(*component))
793                }
794                expr_type
795            }
796            Expression::Splat { value, .. } => self.type_of(value),
797            Expression::AccessIndex { base, .. } => self.type_of(base),
798            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
799            Expression::Swizzle { vector, .. } => self.type_of(vector),
800            Expression::Unary { expr, .. } => self.type_of(expr),
801            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
802            Expression::Math {
803                arg,
804                arg1,
805                arg2,
806                arg3,
807                ..
808            } => self
809                .type_of(arg)
810                .max(
811                    arg1.map(|arg| self.type_of(arg))
812                        .unwrap_or(ExpressionKind::Const),
813                )
814                .max(
815                    arg2.map(|arg| self.type_of(arg))
816                        .unwrap_or(ExpressionKind::Const),
817                )
818                .max(
819                    arg3.map(|arg| self.type_of(arg))
820                        .unwrap_or(ExpressionKind::Const),
821                ),
822            Expression::As { expr, .. } => self.type_of(expr),
823            Expression::Select {
824                condition,
825                accept,
826                reject,
827            } => self
828                .type_of(condition)
829                .max(self.type_of(accept))
830                .max(self.type_of(reject)),
831            Expression::Relational { argument, .. } => self.type_of(argument),
832            Expression::ArrayLength(expr) => self.type_of(expr),
833            _ => ExpressionKind::Runtime,
834        }
835    }
836}
837
838#[derive(Clone, Debug, thiserror::Error)]
839#[cfg_attr(test, derive(PartialEq))]
840pub enum ConstantEvaluatorError {
841    #[error("Constants cannot access function arguments")]
842    FunctionArg,
843    #[error("Constants cannot access global variables")]
844    GlobalVariable,
845    #[error("Constants cannot access local variables")]
846    LocalVariable,
847    #[error("Cannot get the array length of a non array type")]
848    InvalidArrayLengthArg,
849    #[error("Constants cannot get the array length of a dynamically sized array")]
850    ArrayLengthDynamic,
851    #[error("Cannot call arrayLength on array sized by override-expression")]
852    ArrayLengthOverridden,
853    #[error("Constants cannot call functions")]
854    Call,
855    #[error("Constants don't support workGroupUniformLoad")]
856    WorkGroupUniformLoadResult,
857    #[error("Constants don't support atomic functions")]
858    Atomic,
859    #[error("Constants don't support derivative functions")]
860    Derivative,
861    #[error("Constants don't support load expressions")]
862    Load,
863    #[error("Constants don't support image expressions")]
864    ImageExpression,
865    #[error("Constants don't support ray query expressions")]
866    RayQueryExpression,
867    #[error("Constants don't support subgroup expressions")]
868    SubgroupExpression,
869    #[error("Cannot access the type")]
870    InvalidAccessBase,
871    #[error("Cannot access at the index")]
872    InvalidAccessIndex,
873    #[error("Cannot access with index of type")]
874    InvalidAccessIndexTy,
875    #[error("Constants don't support array length expressions")]
876    ArrayLength,
877    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
878    InvalidCastArg { from: String, to: String },
879    #[error("Cannot apply the unary op to the argument")]
880    InvalidUnaryOpArg,
881    #[error("Cannot apply the binary op to the arguments")]
882    InvalidBinaryOpArgs,
883    #[error("Cannot apply math function to type")]
884    InvalidMathArg,
885    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
886    InvalidMathArgCount(crate::MathFunction, usize, usize),
887    #[error("{0} built-in function argument is out of valid range")]
888    InvalidMathArgValue(String),
889    #[error("Cannot apply relational function to type")]
890    InvalidRelationalArg(RelationalFunction),
891    #[error("value of `low` is greater than `high` for clamp built-in function")]
892    InvalidClamp,
893    #[error("Constructor expects {expected} components, found {actual}")]
894    InvalidVectorComposeLength { expected: usize, actual: usize },
895    #[error("Constructor must only contain vector or scalar arguments")]
896    InvalidVectorComposeComponent,
897    #[error("Splat is defined only on scalar values")]
898    SplatScalarOnly,
899    #[error("Can only swizzle vector constants")]
900    SwizzleVectorOnly,
901    #[error("swizzle component not present in source expression")]
902    SwizzleOutOfBounds,
903    #[error("Type is not constructible")]
904    TypeNotConstructible,
905    #[error("Subexpression(s) are not constant")]
906    SubexpressionsAreNotConstant,
907    #[error("Not implemented as constant expression: {0}")]
908    NotImplemented(String),
909    #[error("{0} operation overflowed")]
910    Overflow(String),
911    #[error(
912        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
913    )]
914    AutomaticConversionLossy {
915        value: String,
916        to_type: &'static str,
917    },
918    #[error("Division by zero")]
919    DivisionByZero,
920    #[error("Remainder by zero")]
921    RemainderByZero,
922    #[error("RHS of shift operation is greater than or equal to 32")]
923    ShiftedMoreThan32Bits,
924    #[error(transparent)]
925    Literal(#[from] crate::valid::LiteralError),
926    #[error("Can't use pipeline-overridable constants in const-expressions")]
927    Override,
928    #[error("Unexpected runtime-expression")]
929    RuntimeExpr,
930    #[error("Unexpected override-expression")]
931    OverrideExpr,
932    #[error("Expected boolean expression for condition argument of `select`, got something else")]
933    SelectScalarConditionNotABool,
934    #[error(
935        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
936        reject,
937        accept
938    )]
939    SelectVecRejectAcceptSizeMismatch {
940        reject: crate::VectorSize,
941        accept: crate::VectorSize,
942    },
943    #[error("Expected boolean vector for condition arg., got something else")]
944    SelectConditionNotAVecBool,
945    #[error(
946        "Expected same number of vector components between condition, accept, and reject args., got something else",
947    )]
948    SelectConditionVecSizeMismatch,
949    #[error(
950        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
951    )]
952    SelectAcceptRejectTypeMismatch,
953    #[error("Cooperative operations can't be constant")]
954    CooperativeOperation,
955    #[error("Type is too large")]
956    TypeTooLarge(Handle<Type>),
957}
958
959impl<'a> ConstantEvaluator<'a> {
960    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
961    /// constant expression arena.
962    ///
963    /// Report errors according to WGSL's rules for constant evaluation.
964    pub const fn for_wgsl_module(
965        module: &'a mut crate::Module,
966        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
967        layouter: &'a mut crate::proc::Layouter,
968        in_override_ctx: bool,
969    ) -> Self {
970        Self::for_module(
971            Behavior::Wgsl(if in_override_ctx {
972                WgslRestrictions::Override
973            } else {
974                WgslRestrictions::Const(None)
975            }),
976            module,
977            global_expression_kind_tracker,
978            layouter,
979        )
980    }
981
982    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
983    /// constant expression arena.
984    ///
985    /// Report errors according to GLSL's rules for constant evaluation.
986    pub const fn for_glsl_module(
987        module: &'a mut crate::Module,
988        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
989        layouter: &'a mut crate::proc::Layouter,
990    ) -> Self {
991        Self::for_module(
992            Behavior::Glsl(GlslRestrictions::Const),
993            module,
994            global_expression_kind_tracker,
995            layouter,
996        )
997    }
998
999    const fn for_module(
1000        behavior: Behavior<'a>,
1001        module: &'a mut crate::Module,
1002        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1003        layouter: &'a mut crate::proc::Layouter,
1004    ) -> Self {
1005        Self {
1006            behavior,
1007            types: &mut module.types,
1008            constants: &module.constants,
1009            overrides: &module.overrides,
1010            expressions: &mut module.global_expressions,
1011            expression_kind_tracker: global_expression_kind_tracker,
1012            layouter,
1013        }
1014    }
1015
1016    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1017    /// expression arena.
1018    ///
1019    /// Report errors according to WGSL's rules for constant evaluation.
1020    pub const fn for_wgsl_function(
1021        module: &'a mut crate::Module,
1022        expressions: &'a mut Arena<Expression>,
1023        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1024        layouter: &'a mut crate::proc::Layouter,
1025        emitter: &'a mut super::Emitter,
1026        block: &'a mut crate::Block,
1027        is_const: bool,
1028    ) -> Self {
1029        let local_data = FunctionLocalData {
1030            global_expressions: &module.global_expressions,
1031            emitter,
1032            block,
1033        };
1034        Self {
1035            behavior: Behavior::Wgsl(if is_const {
1036                WgslRestrictions::Const(Some(local_data))
1037            } else {
1038                WgslRestrictions::Runtime(local_data)
1039            }),
1040            types: &mut module.types,
1041            constants: &module.constants,
1042            overrides: &module.overrides,
1043            expressions,
1044            expression_kind_tracker: local_expression_kind_tracker,
1045            layouter,
1046        }
1047    }
1048
1049    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1050    /// expression arena.
1051    ///
1052    /// Report errors according to GLSL's rules for constant evaluation.
1053    pub const fn for_glsl_function(
1054        module: &'a mut crate::Module,
1055        expressions: &'a mut Arena<Expression>,
1056        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1057        layouter: &'a mut crate::proc::Layouter,
1058        emitter: &'a mut super::Emitter,
1059        block: &'a mut crate::Block,
1060    ) -> Self {
1061        Self {
1062            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1063                global_expressions: &module.global_expressions,
1064                emitter,
1065                block,
1066            })),
1067            types: &mut module.types,
1068            constants: &module.constants,
1069            overrides: &module.overrides,
1070            expressions,
1071            expression_kind_tracker: local_expression_kind_tracker,
1072            layouter,
1073        }
1074    }
1075
1076    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1077        crate::proc::GlobalCtx {
1078            types: self.types,
1079            constants: self.constants,
1080            overrides: self.overrides,
1081            global_expressions: match self.function_local_data() {
1082                Some(data) => data.global_expressions,
1083                None => self.expressions,
1084            },
1085        }
1086    }
1087
1088    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1089        if !self.expression_kind_tracker.is_const(expr) {
1090            log::debug!("check: SubexpressionsAreNotConstant");
1091            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1092        }
1093        Ok(())
1094    }
1095
1096    fn check_and_get(
1097        &mut self,
1098        expr: Handle<Expression>,
1099    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1100        match self.expressions[expr] {
1101            Expression::Constant(c) => {
1102                // Are we working in a function's expression arena, or the
1103                // module's constant expression arena?
1104                if let Some(function_local_data) = self.function_local_data() {
1105                    // Deep-copy the constant's value into our arena.
1106                    self.copy_from(
1107                        self.constants[c].init,
1108                        function_local_data.global_expressions,
1109                    )
1110                } else {
1111                    // "See through" the constant and use its initializer.
1112                    Ok(self.constants[c].init)
1113                }
1114            }
1115            _ => {
1116                self.check(expr)?;
1117                Ok(expr)
1118            }
1119        }
1120    }
1121
1122    /// Try to evaluate `expr` at compile time.
1123    ///
1124    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1125    /// we can determine its value at compile time, we append an expression
1126    /// representing its value - a tree of [`Literal`], [`Compose`],
1127    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1128    /// `self` contributes to.
1129    ///
1130    /// If `expr`'s value cannot be determined at compile time, and `self` is
1131    /// contributing to some function's expression arena, then append `expr` to
1132    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1133    /// contributing to the module's constant expression arena; since `expr`'s
1134    /// value is not a constant, return an error.
1135    ///
1136    /// We only consider `expr` itself, without recursing into its operands. Its
1137    /// operands must all have been produced by prior calls to
1138    /// `try_eval_and_append`, to ensure that they have already been reduced to
1139    /// an evaluated form if possible.
1140    ///
1141    /// [`Literal`]: Expression::Literal
1142    /// [`Compose`]: Expression::Compose
1143    /// [`ZeroValue`]: Expression::ZeroValue
1144    /// [`Swizzle`]: Expression::Swizzle
1145    pub fn try_eval_and_append(
1146        &mut self,
1147        expr: Expression,
1148        span: Span,
1149    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1150        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1151            ExpressionKind::Const => {
1152                let eval_result = self.try_eval_and_append_impl(&expr, span);
1153                // We should be able to evaluate `Const` expressions at this
1154                // point. If we failed to, then that probably means we just
1155                // haven't implemented that part of constant evaluation. Work
1156                // around this by simply emitting it as a run-time expression.
1157                if self.behavior.has_runtime_restrictions()
1158                    && matches!(
1159                        eval_result,
1160                        Err(ConstantEvaluatorError::NotImplemented(_)
1161                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1162                    )
1163                {
1164                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1165                } else {
1166                    eval_result
1167                }
1168            }
1169            ExpressionKind::Override => match self.behavior {
1170                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1171                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1172                }
1173                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1174                    Err(ConstantEvaluatorError::OverrideExpr)
1175                }
1176
1177                // GLSL specialization constants (constant_id) become Override expressions
1178                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1179                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1180                }
1181                Behavior::Glsl(GlslRestrictions::Const) => {
1182                    Err(ConstantEvaluatorError::OverrideExpr)
1183                }
1184            },
1185            ExpressionKind::Runtime => {
1186                if self.behavior.has_runtime_restrictions() {
1187                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1188                } else {
1189                    Err(ConstantEvaluatorError::RuntimeExpr)
1190                }
1191            }
1192        }
1193    }
1194
1195    /// Is the [`Self::expressions`] arena the global module expression arena?
1196    const fn is_global_arena(&self) -> bool {
1197        matches!(
1198            self.behavior,
1199            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1200                | Behavior::Glsl(GlslRestrictions::Const)
1201        )
1202    }
1203
1204    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1205        match self.behavior {
1206            Behavior::Wgsl(
1207                WgslRestrictions::Runtime(ref function_local_data)
1208                | WgslRestrictions::Const(Some(ref function_local_data)),
1209            )
1210            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1211                Some(function_local_data)
1212            }
1213            _ => None,
1214        }
1215    }
1216
1217    fn try_eval_and_append_impl(
1218        &mut self,
1219        expr: &Expression,
1220        span: Span,
1221    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1222        log::trace!("try_eval_and_append: {expr:?}");
1223        match *expr {
1224            Expression::Constant(c) if self.is_global_arena() => {
1225                // "See through" the constant and use its initializer.
1226                // This is mainly done to avoid having constants pointing to other constants.
1227                Ok(self.constants[c].init)
1228            }
1229            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1230            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1231                self.register_evaluated_expr(expr.clone(), span)
1232            }
1233            Expression::Compose { ty, ref components } => {
1234                let components = components
1235                    .iter()
1236                    .map(|component| self.check_and_get(*component))
1237                    .collect::<Result<Vec<_>, _>>()?;
1238                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1239            }
1240            Expression::Splat { size, value } => {
1241                let value = self.check_and_get(value)?;
1242                self.register_evaluated_expr(Expression::Splat { size, value }, span)
1243            }
1244            Expression::AccessIndex { base, index } => {
1245                let base = self.check_and_get(base)?;
1246
1247                self.access(base, index as usize, span)
1248            }
1249            Expression::Access { base, index } => {
1250                let base = self.check_and_get(base)?;
1251                let index = self.check_and_get(index)?;
1252
1253                let index_val: u32 = self
1254                    .to_ctx()
1255                    .get_const_val_from(index, self.expressions)
1256                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1257                self.access(base, index_val as usize, span)
1258            }
1259            Expression::Swizzle {
1260                size,
1261                vector,
1262                pattern,
1263            } => {
1264                let vector = self.check_and_get(vector)?;
1265
1266                self.swizzle(size, span, vector, pattern)
1267            }
1268            Expression::Unary { expr, op } => {
1269                let expr = self.check_and_get(expr)?;
1270
1271                self.unary_op(op, expr, span)
1272            }
1273            Expression::Binary { left, right, op } => {
1274                let left = self.check_and_get(left)?;
1275                let right = self.check_and_get(right)?;
1276
1277                self.binary_op(op, left, right, span)
1278            }
1279            Expression::Math {
1280                fun,
1281                arg,
1282                arg1,
1283                arg2,
1284                arg3,
1285            } => {
1286                let arg = self.check_and_get(arg)?;
1287                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1288                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1289                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1290
1291                self.math(arg, arg1, arg2, arg3, fun, span)
1292            }
1293            Expression::As {
1294                convert,
1295                expr,
1296                kind,
1297            } => {
1298                let expr = self.check_and_get(expr)?;
1299
1300                match convert {
1301                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1302                    None => Err(ConstantEvaluatorError::NotImplemented(
1303                        "bitcast built-in function".into(),
1304                    )),
1305                }
1306            }
1307            Expression::Select {
1308                reject,
1309                accept,
1310                condition,
1311            } => {
1312                let mut arg = |expr| self.check_and_get(expr);
1313
1314                let reject = arg(reject)?;
1315                let accept = arg(accept)?;
1316                let condition = arg(condition)?;
1317
1318                self.select(reject, accept, condition, span)
1319            }
1320            Expression::Relational { fun, argument } => {
1321                let argument = self.check_and_get(argument)?;
1322                self.relational(fun, argument, span)
1323            }
1324            Expression::ArrayLength(expr) => match self.behavior {
1325                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1326                Behavior::Glsl(_) => {
1327                    let expr = self.check_and_get(expr)?;
1328                    self.array_length(expr, span)
1329                }
1330            },
1331            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1332            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1333            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1334            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1335            Expression::WorkGroupUniformLoadResult { .. } => {
1336                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1337            }
1338            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1339            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1340            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1341            Expression::ImageSample { .. }
1342            | Expression::ImageLoad { .. }
1343            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1344            Expression::RayQueryProceedResult
1345            | Expression::RayQueryGetIntersection { .. }
1346            | Expression::RayQueryVertexPositions { .. } => {
1347                Err(ConstantEvaluatorError::RayQueryExpression)
1348            }
1349            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1350            Expression::SubgroupOperationResult { .. } => {
1351                Err(ConstantEvaluatorError::SubgroupExpression)
1352            }
1353            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1354                Err(ConstantEvaluatorError::CooperativeOperation)
1355            }
1356        }
1357    }
1358
1359    /// Splat `value` to `size`, without using [`Splat`] expressions.
1360    ///
1361    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1362    /// build a vector with the given `size` whose components are all
1363    /// `value`.
1364    ///
1365    /// Use `span` as the span of the inserted expressions and
1366    /// resulting types.
1367    ///
1368    /// [`Splat`]: Expression::Splat
1369    /// [`Compose`]: Expression::Compose
1370    /// [`ZeroValue`]: Expression::ZeroValue
1371    fn splat(
1372        &mut self,
1373        value: Handle<Expression>,
1374        size: crate::VectorSize,
1375        span: Span,
1376    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1377        match self.expressions[value] {
1378            Expression::Literal(literal) => {
1379                let scalar = literal.scalar();
1380                let ty = self.types.insert(
1381                    Type {
1382                        name: None,
1383                        inner: TypeInner::Vector { size, scalar },
1384                    },
1385                    span,
1386                );
1387                let expr = Expression::Compose {
1388                    ty,
1389                    components: vec![value; size as usize],
1390                };
1391                self.register_evaluated_expr(expr, span)
1392            }
1393            Expression::ZeroValue(ty) => {
1394                let inner = match self.types[ty].inner {
1395                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1396                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1397                };
1398                let res_ty = self.types.insert(Type { name: None, inner }, span);
1399                let expr = Expression::ZeroValue(res_ty);
1400                self.register_evaluated_expr(expr, span)
1401            }
1402            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1403        }
1404    }
1405
1406    fn swizzle(
1407        &mut self,
1408        size: crate::VectorSize,
1409        span: Span,
1410        src_constant: Handle<Expression>,
1411        pattern: [crate::SwizzleComponent; 4],
1412    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1413        let mut get_dst_ty = |ty| match self.types[ty].inner {
1414            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1415                Type {
1416                    name: None,
1417                    inner: TypeInner::Vector { size, scalar },
1418                },
1419                span,
1420            )),
1421            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1422        };
1423
1424        match self.expressions[src_constant] {
1425            Expression::ZeroValue(ty) => {
1426                let dst_ty = get_dst_ty(ty)?;
1427                let expr = Expression::ZeroValue(dst_ty);
1428                self.register_evaluated_expr(expr, span)
1429            }
1430            Expression::Splat { value, .. } => {
1431                let expr = Expression::Splat { size, value };
1432                self.register_evaluated_expr(expr, span)
1433            }
1434            Expression::Compose { ty, ref components } => {
1435                let dst_ty = get_dst_ty(ty)?;
1436
1437                let mut flattened = [src_constant; 4]; // dummy value
1438                let len =
1439                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1440                        .zip(flattened.iter_mut())
1441                        .map(|(component, elt)| *elt = component)
1442                        .count();
1443                let flattened = &flattened[..len];
1444
1445                let swizzled_components = pattern[..size as usize]
1446                    .iter()
1447                    .map(|&sc| {
1448                        let sc = sc as usize;
1449                        if let Some(elt) = flattened.get(sc) {
1450                            Ok(*elt)
1451                        } else {
1452                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1453                        }
1454                    })
1455                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1456                let expr = Expression::Compose {
1457                    ty: dst_ty,
1458                    components: swizzled_components,
1459                };
1460                self.register_evaluated_expr(expr, span)
1461            }
1462            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1463        }
1464    }
1465
1466    fn math(
1467        &mut self,
1468        arg: Handle<Expression>,
1469        arg1: Option<Handle<Expression>>,
1470        arg2: Option<Handle<Expression>>,
1471        arg3: Option<Handle<Expression>>,
1472        fun: crate::MathFunction,
1473        span: Span,
1474    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1475        let expected = fun.argument_count();
1476        let given = Some(arg)
1477            .into_iter()
1478            .chain(arg1)
1479            .chain(arg2)
1480            .chain(arg3)
1481            .count();
1482        if expected != given {
1483            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1484                fun, expected, given,
1485            ));
1486        }
1487
1488        // NOTE: We try to match the declaration order of `MathFunction` here.
1489        match fun {
1490            // comparison
1491            crate::MathFunction::Abs => {
1492                component_wise_scalar(self, span, [arg], |args| match args {
1493                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1494                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1495                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1496                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1497                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1498                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1499                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1500                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1501                })
1502            }
1503            crate::MathFunction::Min => {
1504                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1505                    Ok([e1.min(e2)])
1506                })
1507            }
1508            crate::MathFunction::Max => {
1509                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1510                    Ok([e1.max(e2)])
1511                })
1512            }
1513            crate::MathFunction::Clamp => {
1514                component_wise_scalar!(
1515                    self,
1516                    span,
1517                    [arg, arg1.unwrap(), arg2.unwrap()],
1518                    |e, low, high| {
1519                        if low > high {
1520                            Err(ConstantEvaluatorError::InvalidClamp)
1521                        } else {
1522                            Ok([e.clamp(low, high)])
1523                        }
1524                    }
1525                )
1526            }
1527            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1528                Float::F16([e]) => Ok(Float::F16(
1529                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1530                )),
1531                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1532                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1533            }),
1534
1535            // trigonometry
1536            crate::MathFunction::Cos => {
1537                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1538            }
1539            crate::MathFunction::Cosh => {
1540                component_wise_float!(self, span, [arg], |e| {
1541                    let result = e.cosh();
1542                    if result.is_finite() {
1543                        Ok([result])
1544                    } else {
1545                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
1546                    }
1547                })
1548            }
1549            crate::MathFunction::Sin => {
1550                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1551            }
1552            crate::MathFunction::Sinh => {
1553                component_wise_float!(self, span, [arg], |e| {
1554                    let result = e.sinh();
1555                    if result.is_finite() {
1556                        Ok([result])
1557                    } else {
1558                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
1559                    }
1560                })
1561            }
1562            crate::MathFunction::Tan => {
1563                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1564            }
1565            crate::MathFunction::Tanh => {
1566                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1567            }
1568            crate::MathFunction::Acos => {
1569                component_wise_float!(self, span, [arg], |e| {
1570                    if e.abs() <= One::one() {
1571                        Ok([e.acos()])
1572                    } else {
1573                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1574                    }
1575                })
1576            }
1577            crate::MathFunction::Asin => {
1578                component_wise_float!(self, span, [arg], |e| {
1579                    if e.abs() <= One::one() {
1580                        Ok([e.asin()])
1581                    } else {
1582                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1583                    }
1584                })
1585            }
1586            crate::MathFunction::Atan => {
1587                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1588            }
1589            crate::MathFunction::Atan2 => {
1590                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1591                    Ok([y.atan2(x)])
1592                })
1593            }
1594            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1595                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1596                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1597                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1598            }),
1599            crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
1600                Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
1601                Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
1602                Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
1603                _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
1604            }),
1605            crate::MathFunction::Atanh => {
1606                component_wise_float!(self, span, [arg], |e| {
1607                    if e.abs() < One::one() {
1608                        Ok([e.atanh()])
1609                    } else {
1610                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1611                    }
1612                })
1613            }
1614            crate::MathFunction::Radians => {
1615                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1616            }
1617            crate::MathFunction::Degrees => {
1618                component_wise_float!(self, span, [arg], |e| {
1619                    let result = e.to_degrees();
1620                    if result.is_finite() {
1621                        Ok([result])
1622                    } else {
1623                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
1624                    }
1625                })
1626            }
1627
1628            // decomposition
1629            crate::MathFunction::Ceil => {
1630                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1631            }
1632            crate::MathFunction::Floor => {
1633                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1634            }
1635            crate::MathFunction::Round => {
1636                component_wise_float(self, span, [arg], |e| match e {
1637                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1638                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1639                    Float::F16([e]) => {
1640                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1641                        //
1642                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1643                        // which has licensing compatible with ours. See also
1644                        // <https://github.com/rust-lang/rust/issues/96710>.
1645                        //
1646                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1647                        fn round_ties_even(x: f64) -> f64 {
1648                            let i = x as i64;
1649                            let f = (x - i as f64).abs();
1650                            if f == 0.5 {
1651                                if i & 1 == 1 {
1652                                    // -1.5, 1.5, 3.5, ...
1653                                    (x.abs() + 0.5).copysign(x)
1654                                } else {
1655                                    (x.abs() - 0.5).copysign(x)
1656                                }
1657                            } else {
1658                                x.round()
1659                            }
1660                        }
1661
1662                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1663                    }
1664                })
1665            }
1666            crate::MathFunction::Fract => {
1667                component_wise_float!(self, span, [arg], |e| {
1668                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1669                    // here.
1670                    Ok([e - e.floor()])
1671                })
1672            }
1673            crate::MathFunction::Trunc => {
1674                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1675            }
1676
1677            // exponent
1678            crate::MathFunction::Exp => {
1679                component_wise_float!(self, span, [arg], |e| {
1680                    let result = e.exp();
1681                    if result.is_finite() {
1682                        Ok([result])
1683                    } else {
1684                        Err(ConstantEvaluatorError::Overflow("exp".into()))
1685                    }
1686                })
1687            }
1688            crate::MathFunction::Exp2 => {
1689                component_wise_float!(self, span, [arg], |e| {
1690                    let result = e.exp2();
1691                    if result.is_finite() {
1692                        Ok([result])
1693                    } else {
1694                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
1695                    }
1696                })
1697            }
1698            crate::MathFunction::Log => {
1699                component_wise_float!(self, span, [arg], |e| {
1700                    if e > Zero::zero() {
1701                        Ok([e.ln()])
1702                    } else {
1703                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1704                    }
1705                })
1706            }
1707            crate::MathFunction::Log2 => {
1708                component_wise_float!(self, span, [arg], |e| {
1709                    if e > Zero::zero() {
1710                        Ok([e.log2()])
1711                    } else {
1712                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1713                    }
1714                })
1715            }
1716            crate::MathFunction::Pow => {
1717                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1718                    // 0.pow(0) is an error since exp2(0 * log2(0)) is NaN.
1719                    // https://www.w3.org/TR/WGSL/#pow-builtin
1720                    if e1 < Zero::zero()
1721                        || e1.is_one() && e2.is_infinite()
1722                        || e1.is_infinite() && e2.is_zero()
1723                        || e1.is_zero() && e2.is_zero()
1724                    {
1725                        Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
1726                    } else {
1727                        let result = e1.powf(e2);
1728                        if result.is_finite() {
1729                            Ok([result])
1730                        } else {
1731                            Err(ConstantEvaluatorError::Overflow("pow".into()))
1732                        }
1733                    }
1734                })
1735            }
1736
1737            // computational
1738            crate::MathFunction::Sign => {
1739                component_wise_signed!(self, span, [arg], |e| {
1740                    Ok([if e.is_zero() {
1741                        Zero::zero()
1742                    } else {
1743                        e.signum()
1744                    }])
1745                })
1746            }
1747            crate::MathFunction::Fma => {
1748                component_wise_float!(
1749                    self,
1750                    span,
1751                    [arg, arg1.unwrap(), arg2.unwrap()],
1752                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1753                )
1754            }
1755            crate::MathFunction::Step => {
1756                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1757                    Float::Abstract([edge, x]) => {
1758                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1759                    }
1760                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1761                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1762                        f16::one()
1763                    } else {
1764                        f16::zero()
1765                    }])),
1766                })
1767            }
1768            crate::MathFunction::Sqrt => {
1769                component_wise_float!(self, span, [arg], |e| {
1770                    if e >= Zero::zero() {
1771                        Ok([e.sqrt()])
1772                    } else {
1773                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1774                    }
1775                })
1776            }
1777            crate::MathFunction::InverseSqrt => {
1778                component_wise_float(self, span, [arg], |e| match e {
1779                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1780                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1781                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1782                })
1783            }
1784
1785            // bits
1786            crate::MathFunction::CountTrailingZeros => {
1787                component_wise_concrete_int!(self, span, [arg], |e| {
1788                    #[allow(clippy::useless_conversion)]
1789                    Ok([e
1790                        .trailing_zeros()
1791                        .try_into()
1792                        .expect("bit count overflowed 32 bits, somehow!?")])
1793                })
1794            }
1795            crate::MathFunction::CountLeadingZeros => {
1796                component_wise_concrete_int!(self, span, [arg], |e| {
1797                    #[allow(clippy::useless_conversion)]
1798                    Ok([e
1799                        .leading_zeros()
1800                        .try_into()
1801                        .expect("bit count overflowed 32 bits, somehow!?")])
1802                })
1803            }
1804            crate::MathFunction::CountOneBits => {
1805                component_wise_concrete_int!(self, span, [arg], |e| {
1806                    #[allow(clippy::useless_conversion)]
1807                    Ok([e
1808                        .count_ones()
1809                        .try_into()
1810                        .expect("bit count overflowed 32 bits, somehow!?")])
1811                })
1812            }
1813            crate::MathFunction::ReverseBits => {
1814                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1815            }
1816            crate::MathFunction::FirstTrailingBit => {
1817                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1818            }
1819            crate::MathFunction::FirstLeadingBit => {
1820                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1821            }
1822
1823            // vector
1824            crate::MathFunction::Dot4I8Packed => {
1825                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1826            }
1827            crate::MathFunction::Dot4U8Packed => {
1828                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1829            }
1830            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1831            crate::MathFunction::Dot => {
1832                // https://www.w3.org/TR/WGSL/#dot-builtin
1833                let e1 = self.extract_vec(arg, false)?;
1834                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1835                if e1.len() != e2.len() {
1836                    return Err(ConstantEvaluatorError::InvalidMathArg);
1837                }
1838
1839                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1840                where
1841                    P: num_traits::Float,
1842                {
1843                    let result = a
1844                        .iter()
1845                        .zip(b.iter())
1846                        .map(|(&aa, &bb)| aa * bb)
1847                        .fold(P::zero(), |acc, x| acc + x);
1848                    if result.is_finite() {
1849                        Ok(result)
1850                    } else {
1851                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1852                    }
1853                }
1854
1855                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1856                where
1857                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1858                {
1859                    a.iter()
1860                        .zip(b.iter())
1861                        .map(|(&aa, bb)| aa.checked_mul(bb))
1862                        .try_fold(P::zero(), |acc, x| {
1863                            if let Some(x) = x {
1864                                acc.checked_add(&x)
1865                            } else {
1866                                None
1867                            }
1868                        })
1869                        .ok_or(ConstantEvaluatorError::Overflow(
1870                            "in dot built-in".to_string(),
1871                        ))
1872                }
1873
1874                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1875                where
1876                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1877                {
1878                    a.iter()
1879                        .zip(b.iter())
1880                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
1881                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1882                }
1883
1884                let result = match_literal_vector!(match (e1, e2) => Literal {
1885                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
1886                    AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1887                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1888                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1889                })?;
1890                self.register_evaluated_expr(Expression::Literal(result), span)
1891            }
1892            crate::MathFunction::Length => {
1893                // https://www.w3.org/TR/WGSL/#length-builtin
1894                let e1 = self.extract_vec(arg, true)?;
1895
1896                let result = match_literal_vector!(match e1 => Literal {
1897                    Float => |e1| {
1898                        float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
1899                    },
1900                })?;
1901                self.register_evaluated_expr(Expression::Literal(result), span)
1902            }
1903            crate::MathFunction::Distance => {
1904                // https://www.w3.org/TR/WGSL/#distance-builtin
1905                let e1 = self.extract_vec(arg, true)?;
1906                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1907                if e1.len() != e2.len() {
1908                    return Err(ConstantEvaluatorError::InvalidMathArg);
1909                }
1910
1911                fn float_distance<F>(a: &[F], b: &[F]) -> F
1912                where
1913                    F: core::ops::Mul<F>,
1914                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1915                {
1916                    if a.len() == 1 {
1917                        // Avoids possible overflow in squaring
1918                        (a[0] - b[0]).abs()
1919                    } else {
1920                        a.iter()
1921                            .zip(b.iter())
1922                            .map(|(&aa, &bb)| aa - bb)
1923                            .map(|ei| ei * ei)
1924                            .sum::<F>()
1925                            .sqrt()
1926                    }
1927                }
1928                let result = match_literal_vector!(match (e1, e2) => Literal {
1929                    Float => |e1, e2| { float_distance(e1, e2) },
1930                })?;
1931                self.register_evaluated_expr(Expression::Literal(result), span)
1932            }
1933            crate::MathFunction::Normalize => {
1934                // https://www.w3.org/TR/WGSL/#normalize-builtin
1935                let e1 = self.extract_vec(arg, true)?;
1936
1937                fn float_normalize<F>(
1938                    e: &[F],
1939                ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
1940                where
1941                    F: core::ops::Mul<F>,
1942                    F: num_traits::Float + iter::Sum,
1943                {
1944                    let len = match float_length(e) {
1945                        Some(len) if !len.is_zero() => Ok(len),
1946                        Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
1947                            "normalize".into(),
1948                        )),
1949                        None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
1950                    }?;
1951
1952                    let mut out = ArrayVec::new();
1953                    for &ei in e {
1954                        out.push(ei / len);
1955                    }
1956                    Ok(out)
1957                }
1958
1959                let result = match_literal_vector!(match e1 => LiteralVector {
1960                    Float => |e1| { float_normalize(e1)? },
1961                })?;
1962                result.register_as_evaluated_expr(self, span)
1963            }
1964
1965            // unimplemented
1966            crate::MathFunction::Modf
1967            | crate::MathFunction::Frexp
1968            | crate::MathFunction::Ldexp
1969            | crate::MathFunction::Outer
1970            | crate::MathFunction::FaceForward
1971            | crate::MathFunction::Reflect
1972            | crate::MathFunction::Refract
1973            | crate::MathFunction::Mix
1974            | crate::MathFunction::SmoothStep
1975            | crate::MathFunction::Inverse
1976            | crate::MathFunction::Transpose
1977            | crate::MathFunction::Determinant
1978            | crate::MathFunction::QuantizeToF16
1979            | crate::MathFunction::ExtractBits
1980            | crate::MathFunction::InsertBits
1981            | crate::MathFunction::Pack4x8snorm
1982            | crate::MathFunction::Pack4x8unorm
1983            | crate::MathFunction::Pack2x16snorm
1984            | crate::MathFunction::Pack2x16unorm
1985            | crate::MathFunction::Pack2x16float
1986            | crate::MathFunction::Pack4xI8
1987            | crate::MathFunction::Pack4xU8
1988            | crate::MathFunction::Pack4xI8Clamp
1989            | crate::MathFunction::Pack4xU8Clamp
1990            | crate::MathFunction::Unpack4x8snorm
1991            | crate::MathFunction::Unpack4x8unorm
1992            | crate::MathFunction::Unpack2x16snorm
1993            | crate::MathFunction::Unpack2x16unorm
1994            | crate::MathFunction::Unpack2x16float
1995            | crate::MathFunction::Unpack4xI8
1996            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1997                format!("{fun:?} built-in function"),
1998            )),
1999        }
2000    }
2001
2002    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
2003    fn packed_dot_product(
2004        &mut self,
2005        a: Handle<Expression>,
2006        b: Handle<Expression>,
2007        span: Span,
2008        signed: bool,
2009    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2010        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
2011            return Err(ConstantEvaluatorError::InvalidMathArg);
2012        };
2013        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2014            return Err(ConstantEvaluatorError::InvalidMathArg);
2015        };
2016
2017        let result = if signed {
2018            Literal::I32(
2019                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2020                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2021                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2022                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2023            )
2024        } else {
2025            Literal::U32(
2026                (a & 0xFF) * (b & 0xFF)
2027                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2028                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2029                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2030            )
2031        };
2032
2033        self.register_evaluated_expr(Expression::Literal(result), span)
2034    }
2035
2036    /// Vector cross product.
2037    fn cross_product(
2038        &mut self,
2039        a: Handle<Expression>,
2040        b: Handle<Expression>,
2041        span: Span,
2042    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2043        use Literal as Li;
2044
2045        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2046        let (b, _) = self.extract_vec_with_size::<3>(b)?;
2047
2048        let product = match (a, b) {
2049            (
2050                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2051                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2052            ) => {
2053                // `cross` has no overload for AbstractInt, so AbstractInt
2054                // arguments are automatically converted to AbstractFloat. Since
2055                // `f64` has a much wider range than `i64`, there's no danger of
2056                // overflow here.
2057                let p = cross_product(
2058                    [a0 as f64, a1 as f64, a2 as f64],
2059                    [b0 as f64, b1 as f64, b2 as f64],
2060                );
2061                [
2062                    Li::AbstractFloat(p[0]),
2063                    Li::AbstractFloat(p[1]),
2064                    Li::AbstractFloat(p[2]),
2065                ]
2066            }
2067            (
2068                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2069                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2070            ) => {
2071                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2072                [
2073                    Li::AbstractFloat(p[0]),
2074                    Li::AbstractFloat(p[1]),
2075                    Li::AbstractFloat(p[2]),
2076                ]
2077            }
2078            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2079                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2080                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2081            }
2082            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2083                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2084                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2085            }
2086            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2087                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2088                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2089            }
2090            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2091        };
2092
2093        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2094        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2095        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2096
2097        self.register_evaluated_expr(
2098            Expression::Compose {
2099                ty,
2100                components: vec![p0, p1, p2],
2101            },
2102            span,
2103        )
2104    }
2105
2106    /// Extract the values of a `vecN` from `expr`.
2107    ///
2108    /// Return the value of `expr`, whose type is `vecN<S>` for some
2109    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2110    /// values.
2111    ///
2112    /// Also return the type handle from the `Compose` expression.
2113    fn extract_vec_with_size<const N: usize>(
2114        &mut self,
2115        expr: Handle<Expression>,
2116    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2117        let span = self.expressions.get_span(expr);
2118        let expr = self.eval_zero_value_and_splat(expr, span)?;
2119        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2120            return Err(ConstantEvaluatorError::InvalidMathArg);
2121        };
2122
2123        let mut value = [Literal::Bool(false); N];
2124        for (component, elt) in
2125            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2126                .zip(value.iter_mut())
2127        {
2128            let Expression::Literal(literal) = self.expressions[component] else {
2129                return Err(ConstantEvaluatorError::InvalidMathArg);
2130            };
2131            *elt = literal;
2132        }
2133
2134        Ok((value, ty))
2135    }
2136
2137    /// Extract the values of a `vecN` from `expr`.
2138    ///
2139    /// Return the value of `expr`, whose type is `vecN<S>` for some
2140    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2141    /// values.
2142    ///
2143    /// Also return the type handle from the `Compose` expression.
2144    fn extract_vec(
2145        &mut self,
2146        expr: Handle<Expression>,
2147        allow_single: bool,
2148    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2149        let span = self.expressions.get_span(expr);
2150        let expr = self.eval_zero_value_and_splat(expr, span)?;
2151
2152        match self.expressions[expr] {
2153            Expression::Literal(literal) if allow_single => {
2154                Ok(LiteralVector::from_literal(literal))
2155            }
2156            Expression::Compose { ty, ref components } => {
2157                let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2158                for expr in
2159                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2160                {
2161                    match self.expressions[expr] {
2162                        Expression::Literal(l) => components_out.push(l),
2163                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2164                    }
2165                }
2166                LiteralVector::from_literal_vec(components_out)
2167            }
2168            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2169        }
2170    }
2171
2172    fn array_length(
2173        &mut self,
2174        array: Handle<Expression>,
2175        span: Span,
2176    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2177        match self.expressions[array] {
2178            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2179                match self.types[ty].inner {
2180                    TypeInner::Array { size, .. } => match size {
2181                        ArraySize::Constant(len) => {
2182                            let expr = Expression::Literal(Literal::U32(len.get()));
2183                            self.register_evaluated_expr(expr, span)
2184                        }
2185                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2186                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2187                    },
2188                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2189                }
2190            }
2191            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2192        }
2193    }
2194
2195    fn access(
2196        &mut self,
2197        base: Handle<Expression>,
2198        index: usize,
2199        span: Span,
2200    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2201        match self.expressions[base] {
2202            Expression::ZeroValue(ty) => {
2203                let ty_inner = &self.types[ty].inner;
2204                let components = ty_inner
2205                    .components()
2206                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2207
2208                if index >= components as usize {
2209                    Err(ConstantEvaluatorError::InvalidAccessBase)
2210                } else {
2211                    let ty_res = ty_inner
2212                        .component_type(index)
2213                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2214                    let ty = match ty_res {
2215                        crate::proc::TypeResolution::Handle(ty) => ty,
2216                        crate::proc::TypeResolution::Value(inner) => {
2217                            self.types.insert(Type { name: None, inner }, span)
2218                        }
2219                    };
2220                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2221                }
2222            }
2223            Expression::Splat { size, value } => {
2224                if index >= size as usize {
2225                    Err(ConstantEvaluatorError::InvalidAccessBase)
2226                } else {
2227                    Ok(value)
2228                }
2229            }
2230            Expression::Compose { ty, ref components } => {
2231                let _ = self.types[ty]
2232                    .inner
2233                    .components()
2234                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2235
2236                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2237                    .nth(index)
2238                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2239            }
2240            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2241        }
2242    }
2243
2244    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2245    ///
2246    /// [`ZeroValue`]: Expression::ZeroValue
2247    /// [`Splat`]: Expression::Splat
2248    /// [`Literal`]: Expression::Literal
2249    /// [`Compose`]: Expression::Compose
2250    fn eval_zero_value_and_splat(
2251        &mut self,
2252        mut expr: Handle<Expression>,
2253        span: Span,
2254    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2255        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2256        // each of its components.
2257        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2258            let components = components
2259                .clone()
2260                .iter()
2261                .map(|component| self.eval_zero_value_and_splat(*component, span))
2262                .collect::<Result<_, _>>()?;
2263            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2264        }
2265
2266        // The result of the splat() for a Splat of a scalar ZeroValue is a
2267        // vector ZeroValue, so we must call eval_zero_value_impl() after
2268        // splat() in order to ensure we have no ZeroValues remaining.
2269        if let Expression::Splat { size, value } = self.expressions[expr] {
2270            expr = self.splat(value, size, span)?;
2271        }
2272        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2273            expr = self.eval_zero_value_impl(ty, span)?;
2274        }
2275        Ok(expr)
2276    }
2277
2278    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2279    ///
2280    /// [`ZeroValue`]: Expression::ZeroValue
2281    /// [`Literal`]: Expression::Literal
2282    /// [`Compose`]: Expression::Compose
2283    fn eval_zero_value(
2284        &mut self,
2285        expr: Handle<Expression>,
2286        span: Span,
2287    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2288        match self.expressions[expr] {
2289            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2290            _ => Ok(expr),
2291        }
2292    }
2293
2294    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2295    ///
2296    /// [`ZeroValue`]: Expression::ZeroValue
2297    /// [`Literal`]: Expression::Literal
2298    /// [`Compose`]: Expression::Compose
2299    fn eval_zero_value_impl(
2300        &mut self,
2301        ty: Handle<Type>,
2302        span: Span,
2303    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2304        match self.types[ty].inner {
2305            TypeInner::Scalar(scalar) => {
2306                let expr = Expression::Literal(
2307                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2308                );
2309                self.register_evaluated_expr(expr, span)
2310            }
2311            TypeInner::Vector { size, scalar } => {
2312                let scalar_ty = self.types.insert(
2313                    Type {
2314                        name: None,
2315                        inner: TypeInner::Scalar(scalar),
2316                    },
2317                    span,
2318                );
2319                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2320                let expr = Expression::Compose {
2321                    ty,
2322                    components: vec![el; size as usize],
2323                };
2324                self.register_evaluated_expr(expr, span)
2325            }
2326            TypeInner::Matrix {
2327                columns,
2328                rows,
2329                scalar,
2330            } => {
2331                let vec_ty = self.types.insert(
2332                    Type {
2333                        name: None,
2334                        inner: TypeInner::Vector { size: rows, scalar },
2335                    },
2336                    span,
2337                );
2338                let el = self.eval_zero_value_impl(vec_ty, span)?;
2339                let expr = Expression::Compose {
2340                    ty,
2341                    components: vec![el; columns as usize],
2342                };
2343                self.register_evaluated_expr(expr, span)
2344            }
2345            TypeInner::Array {
2346                base,
2347                size: ArraySize::Constant(size),
2348                ..
2349            } => {
2350                let el = self.eval_zero_value_impl(base, span)?;
2351                let expr = Expression::Compose {
2352                    ty,
2353                    components: vec![el; size.get() as usize],
2354                };
2355                self.register_evaluated_expr(expr, span)
2356            }
2357            TypeInner::Struct { ref members, .. } => {
2358                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2359                let mut components = Vec::with_capacity(members.len());
2360                for ty in types {
2361                    components.push(self.eval_zero_value_impl(ty, span)?);
2362                }
2363                let expr = Expression::Compose { ty, components };
2364                self.register_evaluated_expr(expr, span)
2365            }
2366            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2367        }
2368    }
2369
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// A [`ZeroValue`] `expr` is first lowered to [`Literal`]/[`Compose`] form.
    /// After that:
    ///
    /// - a [`Literal`] is converted according to the per-target table below;
    /// - a vector or matrix [`Compose`] has each component cast recursively and
    ///   is given a new vector/matrix type of the same shape whose scalar is
    ///   `target`;
    /// - a [`Splat`] has its splatted value cast (at that value's own span);
    /// - anything else, including a `Compose` of any other type, is rejected.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// # Errors
    ///
    /// Returns [`ConstantEvaluatorError::InvalidCastArg`] for unsupported
    /// source/target combinations. It also propagates three other kinds of
    /// error: from `try_from_abstract` when converting abstract literals, from
    /// lowering a `ZeroValue`, and from registering the result.
    ///
    /// [`ZeroValue`]: Expression::ZeroValue
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    /// [`Splat`]: Expression::Splat
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        // Lower a top-level `ZeroValue` first so that it is converted exactly
        // like the equivalent `Literal`/`Compose` would be.
        let expr = self.eval_zero_value(expr, span)?;

        // Built lazily: the `from`/`to` strings are only formatted when a
        // conversion is actually rejected.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // Outer match: the target scalar. Inner match: the source literal.
                let literal = match target {
                    Sc::I16 => Literal::I16(match literal {
                        Literal::I16(v) => v,
                        Literal::U16(v) => v as i16,
                        Literal::I32(v) => v as i16,
                        Literal::U32(v) => v as i16,
                        Literal::F32(v) => v.clamp(i16::MIN as f32, i16::MAX as f32) as i16,
                        Literal::F16(v) => {
                            f16::to_f32(v).clamp(i16::MIN as f32, i16::MAX as f32) as i16
                        }
                        Literal::Bool(v) => v as i16,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        // NOTE(review): unlike the I32/U32/I64/U64 arms, which go
                        // through `try_from_abstract`, these use `as`, so an
                        // out-of-range abstract value is truncated/saturated
                        // rather than rejected. Confirm this is intended.
                        Literal::AbstractInt(v) => v as i16,
                        Literal::AbstractFloat(v) => v as i16,
                    }),
                    Sc::U16 => Literal::U16(match literal {
                        Literal::I16(v) => v as u16,
                        Literal::U16(v) => v,
                        Literal::I32(v) => v as u16,
                        Literal::U32(v) => v as u16,
                        Literal::F32(v) => v.clamp(u16::MIN as f32, u16::MAX as f32) as u16,
                        // max(0) avoids None due to negative values, as in the
                        // U32/U64 arms. NOTE(review): the unwrap presumably
                        // relies on the value being finite — confirm.
                        Literal::F16(v) => f16::to_u16(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u16,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        // NOTE(review): see the I16 arm above — these `as` casts
                        // do not reject out-of-range abstract values.
                        Literal::AbstractInt(v) => v as u16,
                        Literal::AbstractFloat(v) => v as u16,
                    }),
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I16(v) => v as i32,
                        Literal::U16(v) => v as i32,
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I16(v) => v as u32,
                        Literal::U16(v) => v as u32,
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I16(v) => v as i64,
                        Literal::U16(v) => v as i64,
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I16(v) => v as u64,
                        Literal::U16(v) => v as u64,
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    // NOTE(review): the `unwrap`s in this arm presumably never
                    // fire (the `half` crate is expected to round large
                    // magnitudes to infinity rather than return None) — confirm.
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I16(v) => f16::from_f32(v as f32),
                        Literal::U16(v) => f16::from_f32(v as f32),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I16(v) => v as f32,
                        Literal::U16(v) => v as f32,
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I16(v) => v as f64,
                        Literal::U16(v) => v as f64,
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I16(v) => v != 0,
                        Literal::U16(v) => v != 0,
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract sources may be converted to abstract targets.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices can be cast component-wise; the
                // result keeps the shape but takes `target` as its scalar.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Cast the splatted value at its own span; the new `Splat`
                // itself is registered at `span` below.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
2610
2611    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2612    ///
2613    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2614    /// matrix, or nested arrays of such.
2615    ///
2616    /// This is basically the same as the [`cast`] method, except that that
2617    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2618    ///
2619    /// Treat `span` as the location of the resulting expression.
2620    ///
2621    /// [`cast`]: ConstantEvaluator::cast
2622    /// [`As`]: crate::Expression::As
2623    pub fn cast_array(
2624        &mut self,
2625        expr: Handle<Expression>,
2626        target: crate::Scalar,
2627        span: Span,
2628    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2629        let expr = self.check_and_get(expr)?;
2630
2631        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2632            return self.cast(expr, target, span);
2633        };
2634
2635        let TypeInner::Array {
2636            base: _,
2637            size,
2638            stride: _,
2639        } = self.types[ty].inner
2640        else {
2641            return self.cast(expr, target, span);
2642        };
2643
2644        let mut components = components.clone();
2645        for component in &mut components {
2646            *component = self.cast_array(*component, target, span)?;
2647        }
2648
2649        let first = components.first().unwrap();
2650        let new_base = match self.resolve_type(*first)? {
2651            crate::proc::TypeResolution::Handle(ty) => ty,
2652            crate::proc::TypeResolution::Value(inner) => {
2653                self.types.insert(Type { name: None, inner }, span)
2654            }
2655        };
2656        let mut layouter = core::mem::take(self.layouter);
2657        layouter.update(self.to_ctx()).map_err(|err| {
2658            // The layouter operates lazily, so the error
2659            // could be for any pending type.
2660            let crate::proc::LayoutErrorInner::TooLarge = err.inner else {
2661                unreachable!("unexpected layout error: {err:?}");
2662            };
2663            ConstantEvaluatorError::TypeTooLarge(err.ty)
2664        })?;
2665        *self.layouter = layouter;
2666
2667        let new_base_stride = self.layouter[new_base].to_stride();
2668        let new_array_ty = self.types.insert(
2669            Type {
2670                name: None,
2671                inner: TypeInner::Array {
2672                    base: new_base,
2673                    size,
2674                    stride: new_base_stride,
2675                },
2676            },
2677            span,
2678        );
2679
2680        let compose = Expression::Compose {
2681            ty: new_array_ty,
2682            components,
2683        };
2684        self.register_evaluated_expr(compose, span)
2685    }
2686
2687    fn unary_op(
2688        &mut self,
2689        op: UnaryOperator,
2690        expr: Handle<Expression>,
2691        span: Span,
2692    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2693        let expr = self.eval_zero_value_and_splat(expr, span)?;
2694
2695        let expr = match self.expressions[expr] {
2696            Expression::Literal(value) => Expression::Literal(match op {
2697                UnaryOperator::Negate => match value {
2698                    Literal::I16(v) => Literal::I16(v.wrapping_neg()),
2699                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2700                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2701                    Literal::F32(v) => Literal::F32(-v),
2702                    Literal::F16(v) => Literal::F16(-v),
2703                    Literal::F64(v) => Literal::F64(-v),
2704                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2705                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2706                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2707                },
2708                UnaryOperator::LogicalNot => match value {
2709                    Literal::Bool(v) => Literal::Bool(!v),
2710                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2711                },
2712                UnaryOperator::BitwiseNot => match value {
2713                    Literal::I16(v) => Literal::I16(!v),
2714                    Literal::I32(v) => Literal::I32(!v),
2715                    Literal::I64(v) => Literal::I64(!v),
2716                    Literal::U16(v) => Literal::U16(!v),
2717                    Literal::U32(v) => Literal::U32(!v),
2718                    Literal::U64(v) => Literal::U64(!v),
2719                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2720                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2721                },
2722            }),
2723            Expression::Compose {
2724                ty,
2725                components: ref src_components,
2726            } => {
2727                match self.types[ty].inner {
2728                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2729                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2730                }
2731
2732                let mut components = src_components.clone();
2733                for component in &mut components {
2734                    *component = self.unary_op(op, *component, span)?;
2735                }
2736
2737                Expression::Compose { ty, components }
2738            }
2739            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2740        };
2741
2742        self.register_evaluated_expr(expr, span)
2743    }
2744
2745    fn binary_op(
2746        &mut self,
2747        op: BinaryOperator,
2748        left: Handle<Expression>,
2749        right: Handle<Expression>,
2750        span: Span,
2751    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2752        let left = self.eval_zero_value_and_splat(left, span)?;
2753        let right = self.eval_zero_value_and_splat(right, span)?;
2754
2755        // Note: in most cases constant evaluation checks for overflow, but for
2756        // i32/u32, it uses wrapping arithmetic. See
2757        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.
2758
2759        let expr = match (&self.expressions[left], &self.expressions[right]) {
2760            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2761                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2762                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2763                {
2764                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2765                }
2766
2767                if matches!(
2768                    (left_value, op),
2769                    (
2770                        Literal::Bool(_),
2771                        BinaryOperator::Less
2772                            | BinaryOperator::LessEqual
2773                            | BinaryOperator::Greater
2774                            | BinaryOperator::GreaterEqual
2775                    )
2776                ) {
2777                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2778                }
2779
2780                let literal = match op {
2781                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2782                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2783                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2784                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2785                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2786                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2787
2788                    _ => match (left_value, right_value) {
2789                        (Literal::I16(a), Literal::I16(b)) => Literal::I16(match op {
2790                            BinaryOperator::Add => a.wrapping_add(b),
2791                            BinaryOperator::Subtract => a.wrapping_sub(b),
2792                            BinaryOperator::Multiply => a.wrapping_mul(b),
2793                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2794                                if b == 0 {
2795                                    ConstantEvaluatorError::DivisionByZero
2796                                } else {
2797                                    ConstantEvaluatorError::Overflow("division".into())
2798                                }
2799                            })?,
2800                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2801                                if b == 0 {
2802                                    ConstantEvaluatorError::RemainderByZero
2803                                } else {
2804                                    ConstantEvaluatorError::Overflow("remainder".into())
2805                                }
2806                            })?,
2807                            BinaryOperator::And => a & b,
2808                            BinaryOperator::ExclusiveOr => a ^ b,
2809                            BinaryOperator::InclusiveOr => a | b,
2810                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2811                        }),
2812                        (Literal::I16(a), Literal::U32(b)) => Literal::I16(match op {
2813                            BinaryOperator::ShiftLeft => {
2814                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2815                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2816                                }
2817                                a.checked_shl(b)
2818                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2819                            }
2820                            BinaryOperator::ShiftRight => a
2821                                .checked_shr(b)
2822                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2823                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2824                        }),
2825                        (Literal::U16(a), Literal::U16(b)) => Literal::U16(match op {
2826                            BinaryOperator::Add => a.wrapping_add(b),
2827                            BinaryOperator::Subtract => a.wrapping_sub(b),
2828                            BinaryOperator::Multiply => a.wrapping_mul(b),
2829                            BinaryOperator::Divide => a
2830                                .checked_div(b)
2831                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2832                            BinaryOperator::Modulo => a
2833                                .checked_rem(b)
2834                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2835                            BinaryOperator::And => a & b,
2836                            BinaryOperator::ExclusiveOr => a ^ b,
2837                            BinaryOperator::InclusiveOr => a | b,
2838                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2839                        }),
2840                        (Literal::U16(a), Literal::U32(b)) => Literal::U16(match op {
2841                            BinaryOperator::ShiftLeft => a
2842                                .checked_mul(
2843                                    1u16.checked_shl(b)
2844                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2845                                )
2846                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2847                            BinaryOperator::ShiftRight => a
2848                                .checked_shr(b)
2849                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2850                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2851                        }),
2852                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2853                            BinaryOperator::Add => a.wrapping_add(b),
2854                            BinaryOperator::Subtract => a.wrapping_sub(b),
2855                            BinaryOperator::Multiply => a.wrapping_mul(b),
2856                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2857                                if b == 0 {
2858                                    ConstantEvaluatorError::DivisionByZero
2859                                } else {
2860                                    ConstantEvaluatorError::Overflow("division".into())
2861                                }
2862                            })?,
2863                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2864                                if b == 0 {
2865                                    ConstantEvaluatorError::RemainderByZero
2866                                } else {
2867                                    ConstantEvaluatorError::Overflow("remainder".into())
2868                                }
2869                            })?,
2870                            BinaryOperator::And => a & b,
2871                            BinaryOperator::ExclusiveOr => a ^ b,
2872                            BinaryOperator::InclusiveOr => a | b,
2873                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2874                        }),
2875                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2876                            BinaryOperator::ShiftLeft => {
2877                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2878                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2879                                }
2880                                a.checked_shl(b)
2881                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2882                            }
2883                            BinaryOperator::ShiftRight => a
2884                                .checked_shr(b)
2885                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2886                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2887                        }),
2888                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2889                            BinaryOperator::Add => a.wrapping_add(b),
2890                            BinaryOperator::Subtract => a.wrapping_sub(b),
2891                            BinaryOperator::Multiply => a.wrapping_mul(b),
2892                            BinaryOperator::Divide => a
2893                                .checked_div(b)
2894                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2895                            BinaryOperator::Modulo => a
2896                                .checked_rem(b)
2897                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2898                            BinaryOperator::And => a & b,
2899                            BinaryOperator::ExclusiveOr => a ^ b,
2900                            BinaryOperator::InclusiveOr => a | b,
2901                            BinaryOperator::ShiftLeft => a
2902                                .checked_mul(
2903                                    1u32.checked_shl(b)
2904                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2905                                )
2906                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2907                            BinaryOperator::ShiftRight => a
2908                                .checked_shr(b)
2909                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2910                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2911                        }),
2912                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2913                            BinaryOperator::Add => a + b,
2914                            BinaryOperator::Subtract => a - b,
2915                            BinaryOperator::Multiply => a * b,
2916                            BinaryOperator::Divide => a / b,
2917                            BinaryOperator::Modulo => a % b,
2918                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2919                        }),
2920                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2921                            Literal::AbstractInt(match op {
2922                                BinaryOperator::ShiftLeft => {
2923                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2924                                        return Err(ConstantEvaluatorError::Overflow(
2925                                            "<<".to_string(),
2926                                        ));
2927                                    }
2928                                    a.checked_shl(b).unwrap_or(0)
2929                                }
2930                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2931                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2932                            })
2933                        }
2934                        (Literal::F16(a), Literal::F16(b)) => {
2935                            let result = match op {
2936                                BinaryOperator::Add => a + b,
2937                                BinaryOperator::Subtract => a - b,
2938                                BinaryOperator::Multiply => a * b,
2939                                BinaryOperator::Divide => a / b,
2940                                BinaryOperator::Modulo => a % b,
2941                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2942                            };
2943                            if !result.is_finite() {
2944                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2945                            }
2946                            Literal::F16(result)
2947                        }
2948                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2949                            Literal::AbstractInt(match op {
2950                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2951                                    ConstantEvaluatorError::Overflow("addition".into())
2952                                })?,
2953                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2954                                    ConstantEvaluatorError::Overflow("subtraction".into())
2955                                })?,
2956                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2957                                    ConstantEvaluatorError::Overflow("multiplication".into())
2958                                })?,
2959                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2960                                    if b == 0 {
2961                                        ConstantEvaluatorError::DivisionByZero
2962                                    } else {
2963                                        ConstantEvaluatorError::Overflow("division".into())
2964                                    }
2965                                })?,
2966                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2967                                    if b == 0 {
2968                                        ConstantEvaluatorError::RemainderByZero
2969                                    } else {
2970                                        ConstantEvaluatorError::Overflow("remainder".into())
2971                                    }
2972                                })?,
2973                                BinaryOperator::And => a & b,
2974                                BinaryOperator::ExclusiveOr => a ^ b,
2975                                BinaryOperator::InclusiveOr => a | b,
2976                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2977                            })
2978                        }
2979                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2980                            let result = match op {
2981                                BinaryOperator::Add => a + b,
2982                                BinaryOperator::Subtract => a - b,
2983                                BinaryOperator::Multiply => a * b,
2984                                BinaryOperator::Divide => a / b,
2985                                BinaryOperator::Modulo => a % b,
2986                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2987                            };
2988                            if !result.is_finite() {
2989                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2990                            }
2991                            Literal::AbstractFloat(result)
2992                        }
2993                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2994                            BinaryOperator::LogicalAnd => a && b,
2995                            BinaryOperator::LogicalOr => a || b,
2996                            BinaryOperator::And => a & b,
2997                            BinaryOperator::InclusiveOr => a | b,
2998                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2999                        }),
3000                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3001                    },
3002                };
3003                Expression::Literal(literal)
3004            }
3005            (
3006                &Expression::Compose {
3007                    components: ref src_components,
3008                    ty,
3009                },
3010                &Expression::Literal(_),
3011            ) => {
3012                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
3013                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3014                }
3015                let mut components = src_components.clone();
3016                for component in &mut components {
3017                    *component = self.binary_op(op, *component, right, span)?;
3018                }
3019                Expression::Compose { ty, components }
3020            }
3021            (
3022                &Expression::Literal(_),
3023                &Expression::Compose {
3024                    components: ref src_components,
3025                    ty,
3026                },
3027            ) => {
3028                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
3029                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3030                }
3031                let mut components = src_components.clone();
3032                for component in &mut components {
3033                    *component = self.binary_op(op, left, *component, span)?;
3034                }
3035                Expression::Compose { ty, components }
3036            }
3037            (
3038                &Expression::Compose {
3039                    components: ref left_components,
3040                    ty: left_ty,
3041                },
3042                &Expression::Compose {
3043                    components: ref right_components,
3044                    ty: right_ty,
3045                },
3046            ) => {
3047                // We have to make a copy of the component lists, because the
3048                // call to `binary_op_vector` needs `&mut self`, but `self` owns
3049                // the component lists.
3050                let left_flattened = crate::proc::flatten_compose(
3051                    left_ty,
3052                    left_components,
3053                    self.expressions,
3054                    self.types,
3055                )
3056                .collect::<Vec<_>>();
3057                let right_flattened = crate::proc::flatten_compose(
3058                    right_ty,
3059                    right_components,
3060                    self.expressions,
3061                    self.types,
3062                )
3063                .collect::<Vec<_>>();
3064
3065                self.binary_op_compose(
3066                    op,
3067                    &left_flattened,
3068                    &right_flattened,
3069                    left_ty,
3070                    right_ty,
3071                    span,
3072                )?
3073            }
3074            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3075        };
3076
3077        return self.register_evaluated_expr(expr, span);
3078
3079        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
3080            let is_numeric_vec = matches!(
3081                compose_ty, TypeInner::Vector { scalar, .. }
3082                if scalar.kind != ScalarKind::Bool
3083            );
3084            let is_allowed_vec_scalar_op = matches!(
3085                op,
3086                BinaryOperator::Add
3087                    | BinaryOperator::Subtract
3088                    | BinaryOperator::Multiply
3089                    | BinaryOperator::Divide
3090                    | BinaryOperator::Modulo
3091            );
3092            let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
3093            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
3094            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
3095        }
3096    }
3097
3098    fn binary_op_compose(
3099        &mut self,
3100        op: BinaryOperator,
3101        left_components: &[Handle<Expression>],
3102        right_components: &[Handle<Expression>],
3103        left_ty: Handle<Type>,
3104        right_ty: Handle<Type>,
3105        span: Span,
3106    ) -> Result<Expression, ConstantEvaluatorError> {
3107        match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
3108            // Binary operation on vector-vector
3109            (
3110                &TypeInner::Vector {
3111                    size: left_size, ..
3112                },
3113                &TypeInner::Vector {
3114                    size: right_size, ..
3115                },
3116            ) if left_size == right_size => self.binary_op_vector(
3117                op,
3118                left_size,
3119                left_components,
3120                right_components,
3121                left_ty,
3122                span,
3123            ),
3124            // Binary operation on vector-matrix
3125            (
3126                &TypeInner::Vector { size, .. },
3127                &TypeInner::Matrix {
3128                    columns,
3129                    rows,
3130                    scalar,
3131                },
3132            ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3133                left_components,
3134                right_components,
3135                columns,
3136                scalar,
3137                span,
3138            ),
3139            // Binary operation on matrix-vector
3140            (
3141                &TypeInner::Matrix {
3142                    columns,
3143                    rows,
3144                    scalar,
3145                },
3146                &TypeInner::Vector { size, .. },
3147            ) if op == BinaryOperator::Multiply && size == columns => {
3148                self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3149            }
3150            // Binary operation on matrix-matrix
3151            (
3152                &TypeInner::Matrix {
3153                    columns: left_columns,
3154                    rows: left_rows,
3155                    scalar,
3156                },
3157                &TypeInner::Matrix {
3158                    columns: right_columns,
3159                    rows: right_rows,
3160                    ..
3161                },
3162            ) => match op {
3163                BinaryOperator::Add | BinaryOperator::Subtract
3164                    if left_columns == right_columns && left_rows == right_rows =>
3165                {
3166                    let components = left_components
3167                        .iter()
3168                        .zip(right_components)
3169                        .map(|(&left, &right)| self.binary_op(op, left, right, span))
3170                        .collect::<Result<Vec<_>, _>>()?;
3171                    Ok(Expression::Compose {
3172                        ty: left_ty,
3173                        components,
3174                    })
3175                }
3176                BinaryOperator::Multiply if left_columns == right_rows => self
3177                    .multiply_matrix_matrix(
3178                        left_components,
3179                        right_components,
3180                        left_rows,
3181                        right_columns,
3182                        scalar,
3183                        span,
3184                    ),
3185                _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3186            },
3187            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3188        }
3189    }
3190
3191    fn binary_op_vector(
3192        &mut self,
3193        op: BinaryOperator,
3194        size: crate::VectorSize,
3195        left_components: &[Handle<Expression>],
3196        right_components: &[Handle<Expression>],
3197        left_ty: Handle<Type>,
3198        span: Span,
3199    ) -> Result<Expression, ConstantEvaluatorError> {
3200        let ty = match op {
3201            // Relational operators produce vectors of booleans.
3202            BinaryOperator::Equal
3203            | BinaryOperator::NotEqual
3204            | BinaryOperator::Less
3205            | BinaryOperator::LessEqual
3206            | BinaryOperator::Greater
3207            | BinaryOperator::GreaterEqual => self.types.insert(
3208                Type {
3209                    name: None,
3210                    inner: TypeInner::Vector {
3211                        size,
3212                        scalar: crate::Scalar::BOOL,
3213                    },
3214                },
3215                span,
3216            ),
3217
3218            // Other operators produce the same type as their left
3219            // operand.
3220            BinaryOperator::Add
3221            | BinaryOperator::Subtract
3222            | BinaryOperator::Multiply
3223            | BinaryOperator::Divide
3224            | BinaryOperator::Modulo
3225            | BinaryOperator::And
3226            | BinaryOperator::ExclusiveOr
3227            | BinaryOperator::InclusiveOr
3228            | BinaryOperator::ShiftLeft
3229            | BinaryOperator::ShiftRight => left_ty,
3230
3231            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3232                // Not supported on vectors
3233                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3234            }
3235        };
3236
3237        let components = left_components
3238            .iter()
3239            .zip(right_components)
3240            .map(|(&left, &right)| self.binary_op(op, left, right, span))
3241            .collect::<Result<Vec<_>, _>>()?;
3242
3243        Ok(Expression::Compose { ty, components })
3244    }
3245
3246    fn multiply_vector_matrix(
3247        &mut self,
3248        vec_components: &[Handle<Expression>],
3249        mat_components: &[Handle<Expression>],
3250        mat_columns: crate::VectorSize,
3251        scalar: crate::Scalar,
3252        span: Span,
3253    ) -> Result<Expression, ConstantEvaluatorError> {
3254        let ty = self.types.insert(
3255            Type {
3256                name: None,
3257                inner: TypeInner::Vector {
3258                    size: mat_columns,
3259                    scalar,
3260                },
3261            },
3262            span,
3263        );
3264        let components = mat_components
3265            .iter()
3266            .map(|&column| {
3267                let Expression::Compose { ref components, .. } = self.expressions[column] else {
3268                    unreachable!()
3269                };
3270                self.dot_exprs(
3271                    vec_components.iter().cloned(),
3272                    components.clone().into_iter(),
3273                    span,
3274                )
3275            })
3276            .collect::<Result<Vec<_>, _>>()?;
3277        Ok(Expression::Compose { ty, components })
3278    }
3279
3280    fn multiply_matrix_vector(
3281        &mut self,
3282        mat_components: &[Handle<Expression>],
3283        vec_components: &[Handle<Expression>],
3284        mat_rows: crate::VectorSize,
3285        scalar: crate::Scalar,
3286        span: Span,
3287    ) -> Result<Expression, ConstantEvaluatorError> {
3288        let ty = self.types.insert(
3289            Type {
3290                name: None,
3291                inner: TypeInner::Vector {
3292                    size: mat_rows,
3293                    scalar,
3294                },
3295            },
3296            span,
3297        );
3298
3299        let flatten = self.flatten_matrix(mat_components);
3300        let nr = mat_rows as usize;
3301        let components = (0..nr)
3302            .map(|r| {
3303                let row = flatten.iter().skip(r).step_by(nr).cloned();
3304                self.dot_exprs(row, vec_components.iter().cloned(), span)
3305            })
3306            .collect::<Result<Vec<_>, _>>()?;
3307        Ok(Expression::Compose { ty, components })
3308    }
3309
3310    fn multiply_matrix_matrix(
3311        &mut self,
3312        left_components: &[Handle<Expression>],
3313        right_components: &[Handle<Expression>],
3314        left_rows: crate::VectorSize,
3315        right_columns: crate::VectorSize,
3316        scalar: crate::Scalar,
3317        span: Span,
3318    ) -> Result<Expression, ConstantEvaluatorError> {
3319        let left_nc = left_components.len();
3320        let left_nr = left_rows as usize;
3321        let right_nc = right_columns as usize;
3322        let right_nr = left_nc;
3323
3324        let mut result = Vec::with_capacity(right_nc);
3325        let result_ty = self.types.insert(
3326            Type {
3327                name: None,
3328                inner: TypeInner::Matrix {
3329                    columns: right_columns,
3330                    rows: left_rows,
3331                    scalar,
3332                },
3333            },
3334            span,
3335        );
3336        let result_column_ty = self.types.insert(
3337            Type {
3338                name: None,
3339                inner: TypeInner::Vector {
3340                    size: left_rows,
3341                    scalar,
3342                },
3343            },
3344            span,
3345        );
3346
3347        let left_flattened = self.flatten_matrix(left_components);
3348        let right_flattened = self.flatten_matrix(right_components);
3349        for c in 0..right_nc {
3350            let result_column = (0..left_nr)
3351                .map(|r| {
3352                    let row = left_flattened.iter().skip(r).step_by(left_nr);
3353                    let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3354                    self.dot_exprs(row.cloned(), column.cloned(), span)
3355                })
3356                .collect::<Result<Vec<_>, _>>()?;
3357            let expr = Expression::Compose {
3358                ty: result_column_ty,
3359                components: result_column,
3360            };
3361            let handle = self.register_evaluated_expr(expr, span)?;
3362            result.push(handle);
3363        }
3364        Ok(Expression::Compose {
3365            ty: result_ty,
3366            components: result,
3367        })
3368    }
3369
3370    fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3371        let mut flattened = ArrayVec::<_, 16>::new();
3372        for &column in columns {
3373            let Expression::Compose { ref components, .. } = self.expressions[column] else {
3374                unreachable!()
3375            };
3376            flattened.extend(components.iter().cloned());
3377        }
3378        flattened
3379    }
3380
3381    fn dot_exprs(
3382        &mut self,
3383        left: impl Iterator<Item = Handle<Expression>>,
3384        right: impl Iterator<Item = Handle<Expression>>,
3385        span: Span,
3386    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3387        let mut acc = None;
3388        for (l, r) in left.zip(right) {
3389            let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3390            match acc.as_mut() {
3391                Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3392                None => acc = Some(result),
3393            }
3394        }
3395        Ok(acc.unwrap())
3396    }
3397
3398    fn relational(
3399        &mut self,
3400        fun: RelationalFunction,
3401        arg: Handle<Expression>,
3402        span: Span,
3403    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3404        let arg = self.eval_zero_value_and_splat(arg, span)?;
3405        match fun {
3406            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3407                Expression::Literal(Literal::Bool(_)) => Ok(arg),
3408                Expression::Compose { ty, ref components }
3409                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3410                {
3411                    let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3412                    for component in
3413                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3414                    {
3415                        match self.expressions[component] {
3416                            Expression::Literal(Literal::Bool(val)) => {
3417                                bool_components.push(val);
3418                            }
3419                            _ => {
3420                                return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3421                            }
3422                        }
3423                    }
3424                    let components = bool_components;
3425                    let result = match fun {
3426                        RelationalFunction::All => components.iter().all(|c| *c),
3427                        RelationalFunction::Any => components.iter().any(|c| *c),
3428                        _ => unreachable!(),
3429                    };
3430                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3431                }
3432                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3433            },
3434            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3435                "{fun:?} built-in function"
3436            ))),
3437        }
3438    }
3439
3440    /// Deep copy `expr` from `expressions` into `self.expressions`.
3441    ///
3442    /// Return the root of the new copy.
3443    ///
3444    /// This is used when we're evaluating expressions in a function's
3445    /// expression arena that refer to a constant: we need to copy the
3446    /// constant's value into the function's arena so we can operate on it.
3447    fn copy_from(
3448        &mut self,
3449        expr: Handle<Expression>,
3450        expressions: &Arena<Expression>,
3451    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3452        let span = expressions.get_span(expr);
3453        match expressions[expr] {
3454            ref expr @ (Expression::Literal(_)
3455            | Expression::Constant(_)
3456            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3457            Expression::Compose { ty, ref components } => {
3458                let mut components = components.clone();
3459                for component in &mut components {
3460                    *component = self.copy_from(*component, expressions)?;
3461                }
3462                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3463            }
3464            Expression::Splat { size, value } => {
3465                let value = self.copy_from(value, expressions)?;
3466                self.register_evaluated_expr(Expression::Splat { size, value }, span)
3467            }
3468            _ => {
3469                log::debug!("copy_from: SubexpressionsAreNotConstant");
3470                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3471            }
3472        }
3473    }
3474
3475    /// Returns the total number of components, after flattening, of a vector compose expression.
3476    fn vector_compose_flattened_size(
3477        &self,
3478        components: &[Handle<Expression>],
3479    ) -> Result<usize, ConstantEvaluatorError> {
3480        components
3481            .iter()
3482            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3483                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3484                    TypeInner::Scalar(_) => 1,
3485                    // We trust that the vector size of `component` is correct,
3486                    // as it will have already been validated when `component`
3487                    // was registered.
3488                    TypeInner::Vector { size, .. } => size as usize,
3489                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3490                };
3491                Ok(acc + size)
3492            })
3493    }
3494
3495    fn register_evaluated_expr(
3496        &mut self,
3497        expr: Expression,
3498        span: Span,
3499    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3500        // It suffices to only check_literal_value() for `Literal` expressions,
3501        // since we only register one expression at a time, `Compose`
3502        // expressions can only refer to other expressions, and `ZeroValue`
3503        // expressions are always okay.
3504        if let Expression::Literal(literal) = expr {
3505            crate::valid::check_literal_value(literal)?;
3506        }
3507
3508        // Ensure vector composes contain the correct number of components. We
3509        // do so here when each compose is registered to avoid having to deal
3510        // with the mess each time the compose is used in another expression.
3511        if let Expression::Compose { ty, ref components } = expr {
3512            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3513                let expected = size as usize;
3514                let actual = self.vector_compose_flattened_size(components)?;
3515                if expected != actual {
3516                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3517                        expected,
3518                        actual,
3519                    });
3520                }
3521            }
3522        }
3523
3524        Ok(self.append_expr(expr, span, ExpressionKind::Const))
3525    }
3526
    /// Append `expr` to `self.expressions`, record it in the expression kind
    /// tracker as `expr_type`, and return its handle.
    ///
    /// When evaluation has function-local data (WGSL runtime, WGSL const with
    /// function-local data, or GLSL runtime), the emitter is running, and
    /// `expr.needs_pre_emit()` is true, the current emitter range is finished
    /// and its output is appended to the function-local block. `expr` is then
    /// appended, and a fresh range is started. As a result, `expr` falls
    /// between the two emitter ranges. NOTE(review): presumably this keeps
    /// pre-emit expressions out of `Emit` statements; confirm against
    /// `Expression::needs_pre_emit`.
    ///
    /// In every other case `expr` is appended directly.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the running emit range first so that `expr` is
                    // appended outside of it, then reopen a new range.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // No function-local emitter to manage: append directly.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
3557
3558    /// Resolve the type of `expr` if it is a constant expression.
3559    ///
3560    /// If `expr` was evaluated to a constant, returns its type.
3561    /// Otherwise, returns an error.
3562    fn resolve_type(
3563        &self,
3564        expr: Handle<Expression>,
3565    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3566        use crate::proc::TypeResolution as Tr;
3567        use crate::Expression as Ex;
3568        let resolution = match self.expressions[expr] {
3569            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3570            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3571            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3572            Ex::Splat { size, value } => {
3573                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3574                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3575                };
3576                Tr::Value(TypeInner::Vector { scalar, size })
3577            }
3578            _ => {
3579                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3580                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3581            }
3582        };
3583
3584        Ok(resolution)
3585    }
3586
3587    fn select(
3588        &mut self,
3589        reject: Handle<Expression>,
3590        accept: Handle<Expression>,
3591        condition: Handle<Expression>,
3592        span: Span,
3593    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3594        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3595
3596        let reject = arg(reject)?;
3597        let accept = arg(accept)?;
3598        let condition = arg(condition)?;
3599
3600        let select_single_component =
3601            |this: &mut Self, reject_scalar, reject, accept, condition| {
3602                let accept = this.cast(accept, reject_scalar, span)?;
3603                if condition {
3604                    Ok(accept)
3605                } else {
3606                    Ok(reject)
3607                }
3608            };
3609
3610        match (&self.expressions[reject], &self.expressions[accept]) {
3611            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3612                let reject_scalar = reject_lit.scalar();
3613                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3614                else {
3615                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3616                };
3617                select_single_component(self, reject_scalar, reject, accept, condition)
3618            }
3619            (
3620                &Expression::Compose {
3621                    ty: reject_ty,
3622                    components: ref reject_components,
3623                },
3624                &Expression::Compose {
3625                    ty: accept_ty,
3626                    components: ref accept_components,
3627                },
3628            ) => {
3629                let ty_deets = |ty| {
3630                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3631                    (size.unwrap(), scalar)
3632                };
3633
3634                let expected_vec_size = {
3635                    let [(reject_vec_size, _), (accept_vec_size, _)] =
3636                        [reject_ty, accept_ty].map(ty_deets);
3637
3638                    if reject_vec_size != accept_vec_size {
3639                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3640                            reject: reject_vec_size,
3641                            accept: accept_vec_size,
3642                        });
3643                    }
3644                    reject_vec_size
3645                };
3646
3647                let condition_components = match self.expressions[condition] {
3648                    Expression::Literal(Literal::Bool(condition)) => {
3649                        vec![condition; (expected_vec_size as u8).into()]
3650                    }
3651                    Expression::Compose {
3652                        ty: condition_ty,
3653                        components: ref condition_components,
3654                    } => {
3655                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3656                        if condition_scalar.kind != ScalarKind::Bool {
3657                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3658                        }
3659                        if condition_vec_size != expected_vec_size {
3660                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3661                        }
3662                        condition_components
3663                            .iter()
3664                            .copied()
3665                            .map(|component| match &self.expressions[component] {
3666                                &Expression::Literal(Literal::Bool(condition)) => condition,
3667                                _ => unreachable!(),
3668                            })
3669                            .collect()
3670                    }
3671
3672                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3673                };
3674
3675                let evaluated = Expression::Compose {
3676                    ty: reject_ty,
3677                    components: reject_components
3678                        .clone()
3679                        .into_iter()
3680                        .zip(accept_components.clone().into_iter())
3681                        .zip(condition_components.into_iter())
3682                        .map(|((reject, accept), condition)| {
3683                            let reject_scalar = match &self.expressions[reject] {
3684                                &Expression::Literal(lit) => lit.scalar(),
3685                                _ => unreachable!(),
3686                            };
3687                            select_single_component(self, reject_scalar, reject, accept, condition)
3688                        })
3689                        .collect::<Result<_, _>>()?,
3690                };
3691                self.register_evaluated_expr(evaluated, span)
3692            }
3693            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3694        }
3695    }
3696}
3697
3698fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3699    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3700    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3701    // return a right-to-left bit index of 0.
3702    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3703        match e {
3704            idx @ 0..=31 => idx,
3705            32 => u32::MAX,
3706            _ => unreachable!(),
3707        }
3708    };
3709    match concrete_int {
3710        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3711        ConcreteInt::I32([e]) => {
3712            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3713        }
3714    }
3715}
3716
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected position of the lowest set bit) for `i32`.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    // A value with exactly one bit set reports that bit's position.
    let signed_single_bits = (0..32).map(|bit| (1i32 << bit, bit));
    for (input, expected) in signed_cases.into_iter().chain(signed_single_bits) {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }

    // Same checks for `u32`.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    let unsigned_single_bits = (0..32).map(|bit| (1u32 << bit, bit));
    for (input, expected) in unsigned_cases.into_iter().chain(unsigned_single_bits) {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
}
3777
3778fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3779    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3780    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3781    // index of 0.
3782    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3783        match e {
3784            idx @ 0..=31 => 31 - idx,
3785            32 => u32::MAX,
3786            _ => unreachable!(),
3787        }
3788    };
3789    match concrete_int {
3790        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3791            let rtl_bit_index = if e.is_negative() {
3792                e.leading_ones()
3793            } else {
3794                e.leading_zeros()
3795            };
3796            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3797        }]),
3798        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3799    }
3800}
3801
#[test]
fn first_leading_bit_smoke() {
    // (input, expected position) for `i32`.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    // Positive powers of two, excluding the sign bit (covered by `i32::MIN` above).
    let positive_powers = (0..(32 - 1)).map(|bit| (1i32 << bit, bit));
    // `-(2^bit)` is ones down to `bit`, so its highest zero bit sits at `bit - 1`.
    let negative_powers = (1..(32 - 1)).map(|bit| (-(1i32 << bit), bit - 1));
    for (input, expected) in signed_cases
        .into_iter()
        .chain(positive_powers)
        .chain(negative_powers)
    {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }

    // (input, expected position) for `u32`.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    let unsigned_powers = (0..32).map(|bit| (1u32 << bit, bit));
    for (input, expected) in unsigned_cases.into_iter().chain(unsigned_powers) {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
}
3865
3866/// Trait for conversions of abstract values to concrete types.
3867trait TryFromAbstract<T>: Sized {
3868    /// Convert an abstract literal `value` to `Self`.
3869    ///
3870    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
3871    /// WGSL, we follow WGSL's conversion rules here:
3872    ///
3873    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
3874    ///   from [`AbstractInt`] to an integer type are either lossless or an
3875    ///   error.
3876    ///
3877    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
3878    ///   to floating point in constant expressions and override
3879    ///   expressions are errors if the value is out of range for the
3880    ///   destination type, but rounding is okay.
3881    ///
3882    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
3883    ///   other floating point type, following the scalar floating point to
3884    ///   integral conversion algorithm (§15.7.6). There is no automatic
3885    ///   conversion from AbstractFloat to integer types.
3886    ///
3887    /// [`AbstractInt`]: crate::Literal::AbstractInt
3888    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
3889    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3890}
3891
3892impl TryFromAbstract<i64> for i32 {
3893    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3894        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3895            value: format!("{value:?}"),
3896            to_type: "i32",
3897        })
3898    }
3899}
3900
3901impl TryFromAbstract<i64> for u32 {
3902    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3903        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3904            value: format!("{value:?}"),
3905            to_type: "u32",
3906        })
3907    }
3908}
3909
3910impl TryFromAbstract<i64> for u64 {
3911    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3912        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3913            value: format!("{value:?}"),
3914            to_type: "u64",
3915        })
3916    }
3917}
3918
3919impl TryFromAbstract<i64> for i64 {
3920    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3921        Ok(value)
3922    }
3923}
3924
3925impl TryFromAbstract<i64> for f32 {
3926    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3927        let f = value as f32;
3928        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3929        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3930        // overflow here.
3931        Ok(f)
3932    }
3933}
3934
3935impl TryFromAbstract<f64> for f32 {
3936    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3937        let f = value as f32;
3938        if f.is_infinite() {
3939            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3940                value: format!("{value:?}"),
3941                to_type: "f32",
3942            });
3943        }
3944        Ok(f)
3945    }
3946}
3947
3948impl TryFromAbstract<i64> for f64 {
3949    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3950        let f = value as f64;
3951        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3952        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3953        // overflow here.
3954        Ok(f)
3955    }
3956}
3957
3958impl TryFromAbstract<f64> for f64 {
3959    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3960        Ok(value)
3961    }
3962}
3963
3964impl TryFromAbstract<f64> for i32 {
3965    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3966        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3967        // To convert a floating point scalar value X to an integer scalar type T:
3968        // * If X is a NaN, the result is an indeterminate value in T.
3969        // * If X is exactly representable in the target type T, then the
3970        //   result is that value.
3971        // * Otherwise, the result is the value in T closest to truncate(X) and
3972        //   also exactly representable in the original floating point type.
3973        //
3974        // A rust cast satisfies these requirements apart from "the result
3975        // is... exactly representable in the original floating point type".
3976        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3977        // we're all good.
3978        Ok(value as i32)
3979    }
3980}
3981
3982impl TryFromAbstract<f64> for u32 {
3983    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3984        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3985        // so a simple rust cast is sufficient.
3986        Ok(value as u32)
3987    }
3988}
3989
3990impl TryFromAbstract<f64> for i64 {
3991    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3992        // As above, except we clamp to the minimum and maximum values
3993        // representable by both f64 and i64.
3994        use crate::proc::type_methods::IntFloatLimits;
3995        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3996    }
3997}
3998
3999impl TryFromAbstract<f64> for u64 {
4000    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
4001        // As above, this time clamping to the minimum and maximum values
4002        // representable by both f64 and u64.
4003        use crate::proc::type_methods::IntFloatLimits;
4004        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
4005    }
4006}
4007
4008impl TryFromAbstract<f64> for f16 {
4009    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
4010        let f = f16::from_f64(value);
4011        if f.is_infinite() {
4012            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4013                value: format!("{value:?}"),
4014                to_type: "f16",
4015            });
4016        }
4017        Ok(f)
4018    }
4019}
4020
4021impl TryFromAbstract<i64> for f16 {
4022    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
4023        let f = f16::from_i64(value);
4024        if f.is_none() {
4025            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4026                value: format!("{value:?}"),
4027                to_type: "f16",
4028            });
4029        }
4030        Ok(f.unwrap())
4031    }
4032}
4033
/// Computes the cross product `a × b` of two three-component vectors.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
4046
4047#[cfg(test)]
4048mod tests {
4049    use alloc::{vec, vec::Vec};
4050
4051    use crate::{
4052        Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
4053        Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
4054    };
4055
4056    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
4057
4058    #[test]
4059    fn unary_op() {
4060        let mut types = UniqueArena::new();
4061        let mut constants = Arena::new();
4062        let overrides = Arena::new();
4063        let mut global_expressions = Arena::new();
4064
4065        let scalar_ty = types.insert(
4066            Type {
4067                name: None,
4068                inner: TypeInner::Scalar(crate::Scalar::I32),
4069            },
4070            Default::default(),
4071        );
4072
4073        let vec_ty = types.insert(
4074            Type {
4075                name: None,
4076                inner: TypeInner::Vector {
4077                    size: VectorSize::Bi,
4078                    scalar: crate::Scalar::I32,
4079                },
4080            },
4081            Default::default(),
4082        );
4083
4084        let h = constants.append(
4085            Constant {
4086                name: None,
4087                ty: scalar_ty,
4088                init: global_expressions
4089                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4090            },
4091            Default::default(),
4092        );
4093
4094        let h1 = constants.append(
4095            Constant {
4096                name: None,
4097                ty: scalar_ty,
4098                init: global_expressions
4099                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
4100            },
4101            Default::default(),
4102        );
4103
4104        let vec_h = constants.append(
4105            Constant {
4106                name: None,
4107                ty: vec_ty,
4108                init: global_expressions.append(
4109                    Expression::Compose {
4110                        ty: vec_ty,
4111                        components: vec![constants[h].init, constants[h1].init],
4112                    },
4113                    Default::default(),
4114                ),
4115            },
4116            Default::default(),
4117        );
4118
4119        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4120        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
4121
4122        let expr2 = Expression::Unary {
4123            op: UnaryOperator::Negate,
4124            expr,
4125        };
4126
4127        let expr3 = Expression::Unary {
4128            op: UnaryOperator::BitwiseNot,
4129            expr,
4130        };
4131
4132        let expr4 = Expression::Unary {
4133            op: UnaryOperator::BitwiseNot,
4134            expr: expr1,
4135        };
4136
4137        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4138        let mut solver = ConstantEvaluator {
4139            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4140            types: &mut types,
4141            constants: &constants,
4142            overrides: &overrides,
4143            expressions: &mut global_expressions,
4144            expression_kind_tracker,
4145            layouter: &mut crate::proc::Layouter::default(),
4146        };
4147
4148        let res1 = solver
4149            .try_eval_and_append(expr2, Default::default())
4150            .unwrap();
4151        let res2 = solver
4152            .try_eval_and_append(expr3, Default::default())
4153            .unwrap();
4154        let res3 = solver
4155            .try_eval_and_append(expr4, Default::default())
4156            .unwrap();
4157
4158        assert_eq!(
4159            global_expressions[res1],
4160            Expression::Literal(Literal::I32(-4))
4161        );
4162
4163        assert_eq!(
4164            global_expressions[res2],
4165            Expression::Literal(Literal::I32(!4))
4166        );
4167
4168        let res3_inner = &global_expressions[res3];
4169
4170        match *res3_inner {
4171            Expression::Compose {
4172                ref ty,
4173                ref components,
4174            } => {
4175                assert_eq!(*ty, vec_ty);
4176                let mut components_iter = components.iter().copied();
4177                assert_eq!(
4178                    global_expressions[components_iter.next().unwrap()],
4179                    Expression::Literal(Literal::I32(!4))
4180                );
4181                assert_eq!(
4182                    global_expressions[components_iter.next().unwrap()],
4183                    Expression::Literal(Literal::I32(!8))
4184                );
4185                assert!(components_iter.next().is_none());
4186            }
4187            _ => panic!("Expected vector"),
4188        }
4189    }
4190
4191    #[test]
4192    fn matrix_op() {
4193        let mut helper = MatrixTestHelper::new();
4194
4195        for nc in 2..=4 {
4196            for nr in 2..=4 {
4197                // Validates multiplication on vector-matrix.
4198                // vecR(0, 1, .., r) * matCxR(0, 1, .., nc * nr)
4199                let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4200                let expected = (0..nc)
4201                    .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4202                    .collect::<Vec<f32>>();
4203                assert_eq!(evaluated, expected);
4204
4205                // Validates multiplication on matrix-vector.
4206                // matCxR(0, 1, .., nc * nr) * vecC(0, 1, .., nc)
4207                let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4208                let expected = (0..nr)
4209                    .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4210                    .collect::<Vec<f32>>();
4211                assert_eq!(evaluated, expected);
4212
4213                for k in 2..=4 {
4214                    // Validates multiplication on matrix-matrix.
4215                    // matKxR(0, 1, .., k * nr) * matCxK(0, 1, .., nc * k)
4216                    let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4217                    let expected = (0..nc)
4218                        .flat_map(|c| {
4219                            (0..nr).map(move |r| {
4220                                (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4221                            })
4222                        })
4223                        .collect::<Vec<f32>>();
4224                    assert_eq!(evaluated, expected);
4225                }
4226            }
4227        }
4228    }
4229
4230    /// Test fixture providing pre-built f32 vector and matrix constant
4231    /// expressions with sequential element values, used to evaluate and verify
4232    /// matrix operations.
4233    struct MatrixTestHelper {
4234        types: UniqueArena<Type>,
4235        expressions: Arena<Expression>,
4236        /// Vector expressions from [0, 1] to [0, 1, 2, 3].
4237        vec_exprs: FastHashMap<usize, Handle<Expression>>,
4238        /// Matrix expressions from [0, .., 3] to [0, .., 15].
4239        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
4240    }
4241
4242    impl MatrixTestHelper {
4243        fn new() -> Self {
4244            let mut types = UniqueArena::new();
4245            let mut expressions = Arena::new();
4246            let span = crate::Span::default();
4247
4248            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4249            for c in 2..=4 {
4250                let vec_ty = types.insert(
4251                    Type {
4252                        name: None,
4253                        inner: TypeInner::Vector {
4254                            size: Self::int_to_vector_size(c),
4255                            scalar: crate::Scalar::F32,
4256                        },
4257                    },
4258                    span,
4259                );
4260                vec_tys.insert(c, vec_ty);
4261                for r in 2..=4 {
4262                    let mat_ty = types.insert(
4263                        Type {
4264                            name: None,
4265                            inner: TypeInner::Matrix {
4266                                columns: Self::int_to_vector_size(c),
4267                                rows: Self::int_to_vector_size(r),
4268                                scalar: crate::Scalar::F32,
4269                            },
4270                        },
4271                        span,
4272                    );
4273                    mat_tys.insert((c, r), mat_ty);
4274                }
4275            }
4276
4277            let mut lit_exprs = FastHashMap::default();
4278            for i in 0..16 {
4279                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4280                lit_exprs.insert(i, expr);
4281            }
4282
4283            let mut vec_exprs = FastHashMap::default();
4284            for c in 2..=4 {
4285                let expr = expressions.append(
4286                    Expression::Compose {
4287                        ty: *vec_tys.get(&c).unwrap(),
4288                        components: (0..c)
4289                            .map(|i| *lit_exprs.get(&i).unwrap())
4290                            .collect::<Vec<_>>(),
4291                    },
4292                    span,
4293                );
4294                vec_exprs.insert(c, expr);
4295            }
4296
4297            let mut mat_exprs = FastHashMap::default();
4298            for c in 2..=4 {
4299                for r in 2..=4 {
4300                    let mut columns = Vec::with_capacity(c);
4301                    for cc in 0..c {
4302                        let start = cc * r;
4303                        let expr = expressions.append(
4304                            Expression::Compose {
4305                                ty: *vec_tys.get(&r).unwrap(),
4306                                components: (start..start + r)
4307                                    .map(|i| *lit_exprs.get(&i).unwrap())
4308                                    .collect::<Vec<_>>(),
4309                            },
4310                            span,
4311                        );
4312                        columns.push(expr);
4313                    }
4314
4315                    let expr = expressions.append(
4316                        Expression::Compose {
4317                            ty: *mat_tys.get(&(c, r)).unwrap(),
4318                            components: columns,
4319                        },
4320                        span,
4321                    );
4322                    mat_exprs.insert((c, r), expr);
4323                }
4324            }
4325
4326            Self {
4327                types,
4328                expressions,
4329                vec_exprs,
4330                mat_exprs,
4331            }
4332        }
4333
4334        /// Evaluates vec[0..nr] * mat[0..nc*nr] and returns the result as f32s.
4335        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4336            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4337            let mut solver = ConstantEvaluator {
4338                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4339                types: &mut self.types,
4340                constants: &Arena::new(),
4341                overrides: &Arena::new(),
4342                expressions: &mut self.expressions,
4343                expression_kind_tracker,
4344                layouter: &mut crate::proc::Layouter::default(),
4345            };
4346
4347            let result = solver
4348                .try_eval_and_append(
4349                    Expression::Binary {
4350                        op: BinaryOperator::Multiply,
4351                        left: *self.vec_exprs.get(&nr).unwrap(),
4352                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4353                    },
4354                    Default::default(),
4355                )
4356                .unwrap();
4357            self.flatten(result)
4358        }
4359
4360        /// Evaluates mat[0..nc*nr] * vec[0..nc] and returns the result as f32s.
4361        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4362            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4363            let mut solver = ConstantEvaluator {
4364                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4365                types: &mut self.types,
4366                constants: &Arena::new(),
4367                overrides: &Arena::new(),
4368                expressions: &mut self.expressions,
4369                expression_kind_tracker,
4370                layouter: &mut crate::proc::Layouter::default(),
4371            };
4372
4373            let result = solver
4374                .try_eval_and_append(
4375                    Expression::Binary {
4376                        op: BinaryOperator::Multiply,
4377                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4378                        right: *self.vec_exprs.get(&nc).unwrap(),
4379                    },
4380                    Default::default(),
4381                )
4382                .unwrap();
4383            self.flatten(result)
4384        }
4385
4386        /// Evaluates mat[0..k*l_nr] * mat[0..r_nc*k] and returns the result as
4387        /// f32s.
4388        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4389            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4390            let mut solver = ConstantEvaluator {
4391                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4392                types: &mut self.types,
4393                constants: &Arena::new(),
4394                overrides: &Arena::new(),
4395                expressions: &mut self.expressions,
4396                expression_kind_tracker,
4397                layouter: &mut crate::proc::Layouter::default(),
4398            };
4399
4400            let result = solver
4401                .try_eval_and_append(
4402                    Expression::Binary {
4403                        op: BinaryOperator::Multiply,
4404                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4405                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4406                    },
4407                    Default::default(),
4408                )
4409                .unwrap();
4410            self.flatten(result)
4411        }
4412
4413        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4414            let Expression::Compose {
4415                ref components,
4416                ref ty,
4417            } = self.expressions[expr]
4418            else {
4419                unreachable!()
4420            };
4421
4422            match self.types[*ty].inner {
4423                TypeInner::Vector { .. } => components
4424                    .iter()
4425                    .map(|&comp| {
4426                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4427                            unreachable!()
4428                        };
4429                        v
4430                    })
4431                    .collect(),
4432                TypeInner::Matrix { .. } => components
4433                    .iter()
4434                    .flat_map(|&comp| self.flatten(comp))
4435                    .collect(),
4436                _ => unreachable!(),
4437            }
4438        }
4439
4440        fn int_to_vector_size(int: usize) -> VectorSize {
4441            match int {
4442                2 => VectorSize::Bi,
4443                3 => VectorSize::Tri,
4444                4 => VectorSize::Quad,
4445                _ => unreachable!(),
4446            }
4447        }
4448    }
4449
4450    #[test]
4451    fn cast() {
4452        let mut types = UniqueArena::new();
4453        let mut constants = Arena::new();
4454        let overrides = Arena::new();
4455        let mut global_expressions = Arena::new();
4456
4457        let scalar_ty = types.insert(
4458            Type {
4459                name: None,
4460                inner: TypeInner::Scalar(crate::Scalar::I32),
4461            },
4462            Default::default(),
4463        );
4464
4465        let h = constants.append(
4466            Constant {
4467                name: None,
4468                ty: scalar_ty,
4469                init: global_expressions
4470                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4471            },
4472            Default::default(),
4473        );
4474
4475        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4476
4477        let root = Expression::As {
4478            expr,
4479            kind: ScalarKind::Bool,
4480            convert: Some(crate::BOOL_WIDTH),
4481        };
4482
4483        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4484        let mut solver = ConstantEvaluator {
4485            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4486            types: &mut types,
4487            constants: &constants,
4488            overrides: &overrides,
4489            expressions: &mut global_expressions,
4490            expression_kind_tracker,
4491            layouter: &mut crate::proc::Layouter::default(),
4492        };
4493
4494        let res = solver
4495            .try_eval_and_append(root, Default::default())
4496            .unwrap();
4497
4498        assert_eq!(
4499            global_expressions[res],
4500            Expression::Literal(Literal::Bool(true))
4501        );
4502    }
4503
4504    #[test]
4505    fn access() {
4506        let mut types = UniqueArena::new();
4507        let mut constants = Arena::new();
4508        let overrides = Arena::new();
4509        let mut global_expressions = Arena::new();
4510
4511        let matrix_ty = types.insert(
4512            Type {
4513                name: None,
4514                inner: TypeInner::Matrix {
4515                    columns: VectorSize::Bi,
4516                    rows: VectorSize::Tri,
4517                    scalar: crate::Scalar::F32,
4518                },
4519            },
4520            Default::default(),
4521        );
4522
4523        let vec_ty = types.insert(
4524            Type {
4525                name: None,
4526                inner: TypeInner::Vector {
4527                    size: VectorSize::Tri,
4528                    scalar: crate::Scalar::F32,
4529                },
4530            },
4531            Default::default(),
4532        );
4533
4534        let mut vec1_components = Vec::with_capacity(3);
4535        let mut vec2_components = Vec::with_capacity(3);
4536
4537        for i in 0..3 {
4538            let h = global_expressions.append(
4539                Expression::Literal(Literal::F32(i as f32)),
4540                Default::default(),
4541            );
4542
4543            vec1_components.push(h)
4544        }
4545
4546        for i in 3..6 {
4547            let h = global_expressions.append(
4548                Expression::Literal(Literal::F32(i as f32)),
4549                Default::default(),
4550            );
4551
4552            vec2_components.push(h)
4553        }
4554
4555        let vec1 = constants.append(
4556            Constant {
4557                name: None,
4558                ty: vec_ty,
4559                init: global_expressions.append(
4560                    Expression::Compose {
4561                        ty: vec_ty,
4562                        components: vec1_components,
4563                    },
4564                    Default::default(),
4565                ),
4566            },
4567            Default::default(),
4568        );
4569
4570        let vec2 = constants.append(
4571            Constant {
4572                name: None,
4573                ty: vec_ty,
4574                init: global_expressions.append(
4575                    Expression::Compose {
4576                        ty: vec_ty,
4577                        components: vec2_components,
4578                    },
4579                    Default::default(),
4580                ),
4581            },
4582            Default::default(),
4583        );
4584
4585        let h = constants.append(
4586            Constant {
4587                name: None,
4588                ty: matrix_ty,
4589                init: global_expressions.append(
4590                    Expression::Compose {
4591                        ty: matrix_ty,
4592                        components: vec![constants[vec1].init, constants[vec2].init],
4593                    },
4594                    Default::default(),
4595                ),
4596            },
4597            Default::default(),
4598        );
4599
4600        let base = global_expressions.append(Expression::Constant(h), Default::default());
4601
4602        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4603        let mut solver = ConstantEvaluator {
4604            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4605            types: &mut types,
4606            constants: &constants,
4607            overrides: &overrides,
4608            expressions: &mut global_expressions,
4609            expression_kind_tracker,
4610            layouter: &mut crate::proc::Layouter::default(),
4611        };
4612
4613        let root1 = Expression::AccessIndex { base, index: 1 };
4614
4615        let res1 = solver
4616            .try_eval_and_append(root1, Default::default())
4617            .unwrap();
4618
4619        let root2 = Expression::AccessIndex {
4620            base: res1,
4621            index: 2,
4622        };
4623
4624        let res2 = solver
4625            .try_eval_and_append(root2, Default::default())
4626            .unwrap();
4627
4628        match global_expressions[res1] {
4629            Expression::Compose {
4630                ref ty,
4631                ref components,
4632            } => {
4633                assert_eq!(*ty, vec_ty);
4634                let mut components_iter = components.iter().copied();
4635                assert_eq!(
4636                    global_expressions[components_iter.next().unwrap()],
4637                    Expression::Literal(Literal::F32(3.))
4638                );
4639                assert_eq!(
4640                    global_expressions[components_iter.next().unwrap()],
4641                    Expression::Literal(Literal::F32(4.))
4642                );
4643                assert_eq!(
4644                    global_expressions[components_iter.next().unwrap()],
4645                    Expression::Literal(Literal::F32(5.))
4646                );
4647                assert!(components_iter.next().is_none());
4648            }
4649            _ => panic!("Expected vector"),
4650        }
4651
4652        assert_eq!(
4653            global_expressions[res2],
4654            Expression::Literal(Literal::F32(5.))
4655        );
4656    }
4657
4658    #[test]
4659    fn compose_of_constants() {
4660        let mut types = UniqueArena::new();
4661        let mut constants = Arena::new();
4662        let overrides = Arena::new();
4663        let mut global_expressions = Arena::new();
4664
4665        let i32_ty = types.insert(
4666            Type {
4667                name: None,
4668                inner: TypeInner::Scalar(crate::Scalar::I32),
4669            },
4670            Default::default(),
4671        );
4672
4673        let vec2_i32_ty = types.insert(
4674            Type {
4675                name: None,
4676                inner: TypeInner::Vector {
4677                    size: VectorSize::Bi,
4678                    scalar: crate::Scalar::I32,
4679                },
4680            },
4681            Default::default(),
4682        );
4683
4684        let h = constants.append(
4685            Constant {
4686                name: None,
4687                ty: i32_ty,
4688                init: global_expressions
4689                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4690            },
4691            Default::default(),
4692        );
4693
4694        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4695
4696        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4697        let mut solver = ConstantEvaluator {
4698            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4699            types: &mut types,
4700            constants: &constants,
4701            overrides: &overrides,
4702            expressions: &mut global_expressions,
4703            expression_kind_tracker,
4704            layouter: &mut crate::proc::Layouter::default(),
4705        };
4706
4707        let solved_compose = solver
4708            .try_eval_and_append(
4709                Expression::Compose {
4710                    ty: vec2_i32_ty,
4711                    components: vec![h_expr, h_expr],
4712                },
4713                Default::default(),
4714            )
4715            .unwrap();
4716        let solved_negate = solver
4717            .try_eval_and_append(
4718                Expression::Unary {
4719                    op: UnaryOperator::Negate,
4720                    expr: solved_compose,
4721                },
4722                Default::default(),
4723            )
4724            .unwrap();
4725
4726        let pass = match global_expressions[solved_negate] {
4727            Expression::Compose { ty, ref components } => {
4728                ty == vec2_i32_ty
4729                    && components.iter().all(|&component| {
4730                        let component = &global_expressions[component];
4731                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4732                    })
4733            }
4734            _ => false,
4735        };
4736        if !pass {
4737            panic!("unexpected evaluation result")
4738        }
4739    }
4740
4741    #[test]
4742    fn splat_of_constant() {
4743        let mut types = UniqueArena::new();
4744        let mut constants = Arena::new();
4745        let overrides = Arena::new();
4746        let mut global_expressions = Arena::new();
4747
4748        let i32_ty = types.insert(
4749            Type {
4750                name: None,
4751                inner: TypeInner::Scalar(crate::Scalar::I32),
4752            },
4753            Default::default(),
4754        );
4755
4756        let vec2_i32_ty = types.insert(
4757            Type {
4758                name: None,
4759                inner: TypeInner::Vector {
4760                    size: VectorSize::Bi,
4761                    scalar: crate::Scalar::I32,
4762                },
4763            },
4764            Default::default(),
4765        );
4766
4767        let h = constants.append(
4768            Constant {
4769                name: None,
4770                ty: i32_ty,
4771                init: global_expressions
4772                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4773            },
4774            Default::default(),
4775        );
4776
4777        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4778
4779        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4780        let mut solver = ConstantEvaluator {
4781            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4782            types: &mut types,
4783            constants: &constants,
4784            overrides: &overrides,
4785            expressions: &mut global_expressions,
4786            expression_kind_tracker,
4787            layouter: &mut crate::proc::Layouter::default(),
4788        };
4789
4790        let solved_compose = solver
4791            .try_eval_and_append(
4792                Expression::Splat {
4793                    size: VectorSize::Bi,
4794                    value: h_expr,
4795                },
4796                Default::default(),
4797            )
4798            .unwrap();
4799        let solved_negate = solver
4800            .try_eval_and_append(
4801                Expression::Unary {
4802                    op: UnaryOperator::Negate,
4803                    expr: solved_compose,
4804                },
4805                Default::default(),
4806            )
4807            .unwrap();
4808
4809        let pass = match global_expressions[solved_negate] {
4810            Expression::Compose { ty, ref components } => {
4811                ty == vec2_i32_ty
4812                    && components.iter().all(|&component| {
4813                        let component = &global_expressions[component];
4814                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4815                    })
4816            }
4817            _ => false,
4818        };
4819        if !pass {
4820            panic!("unexpected evaluation result")
4821        }
4822    }
4823
4824    #[test]
4825    fn splat_of_zero_value() {
4826        let mut types = UniqueArena::new();
4827        let constants = Arena::new();
4828        let overrides = Arena::new();
4829        let mut global_expressions = Arena::new();
4830
4831        let f32_ty = types.insert(
4832            Type {
4833                name: None,
4834                inner: TypeInner::Scalar(crate::Scalar::F32),
4835            },
4836            Default::default(),
4837        );
4838
4839        let vec2_f32_ty = types.insert(
4840            Type {
4841                name: None,
4842                inner: TypeInner::Vector {
4843                    size: VectorSize::Bi,
4844                    scalar: crate::Scalar::F32,
4845                },
4846            },
4847            Default::default(),
4848        );
4849
4850        let five =
4851            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4852        let five_splat = global_expressions.append(
4853            Expression::Splat {
4854                size: VectorSize::Bi,
4855                value: five,
4856            },
4857            Default::default(),
4858        );
4859        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4860        let zero_splat = global_expressions.append(
4861            Expression::Splat {
4862                size: VectorSize::Bi,
4863                value: zero,
4864            },
4865            Default::default(),
4866        );
4867
4868        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4869        let mut solver = ConstantEvaluator {
4870            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4871            types: &mut types,
4872            constants: &constants,
4873            overrides: &overrides,
4874            expressions: &mut global_expressions,
4875            expression_kind_tracker,
4876            layouter: &mut crate::proc::Layouter::default(),
4877        };
4878
4879        let solved_add = solver
4880            .try_eval_and_append(
4881                Expression::Binary {
4882                    op: BinaryOperator::Add,
4883                    left: zero_splat,
4884                    right: five_splat,
4885                },
4886                Default::default(),
4887            )
4888            .unwrap();
4889
4890        let pass = match global_expressions[solved_add] {
4891            Expression::Compose { ty, ref components } => {
4892                ty == vec2_f32_ty
4893                    && components.iter().all(|&component| {
4894                        let component = &global_expressions[component];
4895                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4896                    })
4897            }
4898            _ => false,
4899        };
4900        if !pass {
4901            panic!("unexpected evaluation result")
4902        }
4903    }
4904}