naga/proc/
constant_evaluator.rs

// Code in this file intentionally uses `for` loops and `.push()` rather than
// `ArrayVec::from_iter`, because the latter is monomorphized by all three of
// the item type, the capacity, and the iterator type, which can easily bloat
// the compiled executable (removing those uses saved ~260 KiB).
5
6use alloc::{
7    format,
8    string::{String, ToString},
9    vec,
10    vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19    arena::{Arena, Handle, HandleVec, UniqueArena},
20    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21    ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// It defines a throwaway helper macro, `__with_dollar_sign`, whose arms are the caller's
/// tokens, and then immediately invokes that helper with a literal `$` token. The caller's
/// tokens are therefore expected to be `macro_rules!` arm(s) whose matcher captures that `$`
/// (for example `($d:tt) => { ... }`), so the arm body can write `$d` wherever a dollar sign
/// must appear in the generated output. Each use redefines the helper, shadowing any earlier
/// definition.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
38
/// Generates a family of helpers for evaluating numeric built-ins component-wise.
///
/// An invocation
///
/// ```text
/// gen_component_wise_extractor! {
///     $ident -> $target,
///     literals: [$literal => $mapping: $ty, ...],
///     scalar_kinds: [$scalar_kind, ...],
/// }
/// ```
///
/// expands to:
///
/// - an enum `$target<const N: usize>` with one variant `$mapping([$ty; N])` per entry in
///   `literals`, each corresponding to `Literal::$literal`;
/// - a function `$ident` that gathers `N` expressions into one `$target<N>`, passes it to a
///   handler, and registers the handler's result as a new expression. Scalar literals are
///   handled directly; `Compose` vectors are decomposed only when their scalar kind is listed
///   in `scalar_kinds`;
/// - a convenience macro, also named `$ident`, that applies one handler body to every variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-component value converts directly back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            // The first expression is always taken below, so at least one is required.
            assert!(N > 0);
            // Every shape or type mismatch below is reported as this single error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an expression through `eval_zero_value_and_splat` and borrows the resulting
            // expression from the arena, so the matches below inspect that form.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first expression decides the shape (scalar literal or vector `Compose`) that
            // every remaining expression must also have.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: all `N` expressions must be `Literal::$literal`.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        // Exactly `N` values were pushed, so `arr` is full and this cannot fail.
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector case: all `N` expressions must be `Compose`s whose type has the same
                // `TypeInner` as the first one, with a scalar kind listed in `scalar_kinds`.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per input expression.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            // Exactly `N` groups were pushed, so this cannot fail.
                            let component_groups = component_groups.into_inner().unwrap();
                            // Transpose: for each component index, gather that component from
                            // every input and recurse on the resulting `N` scalars.
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    // An input that flattened to fewer than `size` components
                                    // is reported as `err` instead of panicking.
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            // NOTE(review): the result reuses the first input's vector type,
                            // which assumes `handler` yields components of that same scalar
                            // type — confirm for handlers that return a different variant.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // Generates a `macro_rules!` macro named `$ident`. `with_dollar_sign!` supplies the `$`
        // token as `$d`, so the generated macro can declare its own metavariables.
        //
        // `$ident!(eval, span, [e0, e1, ...], |a0, a1, ...| body)` calls the function `$ident`
        // with a handler that, for every variant, binds the array elements to `a0, a1, ...`,
        // evaluates the same `body` (a single token tree, e.g. a `{ ... }` block), and wraps the
        // success value back into that variant via `Result::map`. `body` must therefore produce
        // a `Result` whose `Ok` value is an array of the variant's element type.
        with_dollar_sign! {
            ($d:tt) => {
                // Not every generated convenience macro is used.
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// `Scalar`: every numeric literal type except the 16-bit integers (`I16`/`U16`); `Bool` is
// excluded as well. `Compose` vectors of every numeric scalar kind are decomposed.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
239
// `Float`: floating-point literals only (`AbstractFloat`, `F32`, `F16`). `F64` is not listed,
// so 64-bit float arguments (scalars, or components of decomposed vectors) are rejected with
// `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
252
// `ConcreteInt`: 32-bit concrete integers only (`U32`, `I32`). Vectors of any `Sint`/`Uint`
// kind are decomposed, but components other than `U32`/`I32` (e.g. 64-bit ones) are then
// rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
264
// `Signed`: signed numeric literals. This covers floats (`AbstractFloat`, `F32`, `F16`, but
// not `F64`) and signed integers (`AbstractInt`, `I32`, but not `I16`/`I64`). Unsupported
// component types are rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// Up to [`crate::VectorSize::MAX`] literal values of a single element type.
///
/// There is one variant per [`Literal`] kind, including the abstract ones. A single scalar is
/// represented as a one-element vector (see [`LiteralVector::from_literal`]).
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U16(ArrayVec<u16, { crate::VectorSize::MAX }>),
    I16(ArrayVec<i16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
298
299impl LiteralVector {
300    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
301    fn len(&self) -> usize {
302        match *self {
303            LiteralVector::F64(ref v) => v.len(),
304            LiteralVector::F32(ref v) => v.len(),
305            LiteralVector::F16(ref v) => v.len(),
306            LiteralVector::U16(ref v) => v.len(),
307            LiteralVector::I16(ref v) => v.len(),
308            LiteralVector::U32(ref v) => v.len(),
309            LiteralVector::I32(ref v) => v.len(),
310            LiteralVector::U64(ref v) => v.len(),
311            LiteralVector::I64(ref v) => v.len(),
312            LiteralVector::Bool(ref v) => v.len(),
313            LiteralVector::AbstractInt(ref v) => v.len(),
314            LiteralVector::AbstractFloat(ref v) => v.len(),
315        }
316    }
317
318    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
319    fn from_literal(literal: Literal) -> Self {
320        fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
321            let mut v = ArrayVec::new();
322            v.push(val);
323            v
324        }
325        match literal {
326            Literal::F64(e) => Self::F64(arrayvec_of(e)),
327            Literal::F32(e) => Self::F32(arrayvec_of(e)),
328            Literal::U16(e) => Self::U16(arrayvec_of(e)),
329            Literal::I16(e) => Self::I16(arrayvec_of(e)),
330            Literal::U32(e) => Self::U32(arrayvec_of(e)),
331            Literal::I32(e) => Self::I32(arrayvec_of(e)),
332            Literal::U64(e) => Self::U64(arrayvec_of(e)),
333            Literal::I64(e) => Self::I64(arrayvec_of(e)),
334            Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
335            Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
336            Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
337            Literal::F16(e) => Self::F16(arrayvec_of(e)),
338        }
339    }
340
341    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
342    /// Returns error if components types do not match.
343    /// # Panics
344    /// Panics if vector is empty
345    fn from_literal_vec(
346        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
347    ) -> Result<Self, ConstantEvaluatorError> {
348        assert!(!components.is_empty());
349        // TODO: should a vector of i32 be constructible from abstract int?
350        macro_rules! compose_literals {
351            ($components:expr, $variant:ident, $self_variant:ident) => {{
352                let mut out = ArrayVec::new();
353                for l in &$components {
354                    match l {
355                        &Literal::$variant(v) => out.push(v),
356                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
357                    }
358                }
359                Self::$self_variant(out)
360            }};
361        }
362        Ok(match components[0] {
363            Literal::I16(_) => compose_literals!(components, I16, I16),
364            Literal::U16(_) => compose_literals!(components, U16, U16),
365            Literal::I32(_) => compose_literals!(components, I32, I32),
366            Literal::U32(_) => compose_literals!(components, U32, U32),
367            Literal::I64(_) => compose_literals!(components, I64, I64),
368            Literal::U64(_) => compose_literals!(components, U64, U64),
369            Literal::F32(_) => compose_literals!(components, F32, F32),
370            Literal::F64(_) => compose_literals!(components, F64, F64),
371            Literal::Bool(_) => compose_literals!(components, Bool, Bool),
372            Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
373            Literal::AbstractFloat(_) => {
374                compose_literals!(components, AbstractFloat, AbstractFloat)
375            }
376            Literal::F16(_) => compose_literals!(components, F16, F16),
377        })
378    }
379
380    #[allow(dead_code)]
381    /// Returns [`ArrayVec`] of [`Literal`]s
382    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
383        macro_rules! decompose_literals {
384            ($v:expr, $variant:ident) => {{
385                let mut out = ArrayVec::new();
386                for e in $v {
387                    out.push(Literal::$variant(*e));
388                }
389                out
390            }};
391        }
392        match *self {
393            LiteralVector::F64(ref v) => decompose_literals!(v, F64),
394            LiteralVector::F32(ref v) => decompose_literals!(v, F32),
395            LiteralVector::F16(ref v) => decompose_literals!(v, F16),
396            LiteralVector::U16(ref v) => decompose_literals!(v, U16),
397            LiteralVector::I16(ref v) => decompose_literals!(v, I16),
398            LiteralVector::U32(ref v) => decompose_literals!(v, U32),
399            LiteralVector::I32(ref v) => decompose_literals!(v, I32),
400            LiteralVector::U64(ref v) => decompose_literals!(v, U64),
401            LiteralVector::I64(ref v) => decompose_literals!(v, I64),
402            LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
403            LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
404            LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
405        }
406    }
407
408    #[allow(dead_code)]
409    /// Puts self into eval's expressions arena and returns handle to it
410    fn register_as_evaluated_expr(
411        &self,
412        eval: &mut ConstantEvaluator<'_>,
413        span: Span,
414    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
415        let lit_vec = self.to_literal_vec();
416        assert!(!lit_vec.is_empty());
417        let expr = if lit_vec.len() == 1 {
418            Expression::Literal(lit_vec[0])
419        } else {
420            Expression::Compose {
421                ty: eval.types.insert(
422                    Type {
423                        name: None,
424                        inner: TypeInner::Vector {
425                            size: match lit_vec.len() {
426                                2 => crate::VectorSize::Bi,
427                                3 => crate::VectorSize::Tri,
428                                4 => crate::VectorSize::Quad,
429                                _ => unreachable!(),
430                            },
431                            scalar: lit_vec[0].scalar(),
432                        },
433                    },
434                    Span::UNDEFINED,
435                ),
436                components: lit_vec
437                    .iter()
438                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
439                    .collect::<Result<_, _>>()?,
440            }
441        };
442        eval.register_evaluated_expr(expr, span)
443    }
444}
445
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
///
/// The expansion is a `match` whose arms evaluate to `Ok($out::Variant(body))`, or to
/// `Ok($out::Ret(body))` when `-> Ret` is given. Any input not covered by an arm evaluates to
/// `Err(ConstantEvaluatorError::InvalidMathArg)`. This includes `I16`, `U16` and `Bool`, which
/// neither `Integer` nor `Float` covers.
///
/// When several bindings are listed (`|e1, e2|`), `$lit_vec` is matched as a tuple. Every
/// element must be the same [`LiteralVector`] variant for an arm to apply.
macro_rules! match_literal_vector {
    // Entry point: split the arms into a list of variant names and a parallel list of
    // `{ bindings ; return-override? ; body }` groups.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Begin walking both lists with an empty accumulator (the part before `<>`).
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next variant name off the list and dispatch on it, passing it in front.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer`: replace it with the five concrete integer variants, copying its arm for each.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float`: replace it with the four float variants, copying its arm for each.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // A concrete variant with more names still to process: move `{ variant ; arm }` into
    // the accumulator and continue with the next name.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // The last concrete variant: move it into the accumulator and emit the `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the `match`. With a single binding the pattern is only parenthesized (not a tuple),
    // hence `unused_parens`; with several bindings every tuple element must be `$ty`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the body in the arm's own variant of `$out`...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...or in the variant named by the `-> $ret` override.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
621
622fn float_length<F>(e: &[F]) -> Option<F>
623where
624    F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
625{
626    if e.len() == 1 {
627        // Avoids possible overflow in squaring
628        Some(e[0].abs())
629    } else {
630        let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
631        result.is_finite().then_some(result)
632    }
633}
634
/// Which language's constant-evaluation rules apply, together with the restrictions of the
/// context being evaluated in.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL rules under the given restrictions.
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL rules under the given restrictions.
    Glsl(GlslRestrictions<'a>),
}
640
641impl Behavior<'_> {
642    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
643    const fn has_runtime_restrictions(&self) -> bool {
644        matches!(
645            self,
646            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
647                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
648        )
649    }
650}
651
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// The layouter for the module's types.
    ///
    /// NOTE(review): presumably consulted when evaluation needs type sizes or alignments;
    /// confirm against its uses.
    layouter: &'a mut crate::proc::Layouter,
}
696
/// Which kinds of expressions WGSL evaluation may produce in the current context.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// NOTE(review): the `Option` presumably distinguishes evaluation inside a function body
    /// (`Some`) from module scope (`None`); confirm.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
709
/// Which kinds of expressions GLSL evaluation may produce in the current context.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
719
/// State carried when the evaluator works inside a function body.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably the emitter tracking the current run of emitted function
    /// expressions; confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the block that receives statements (such as emits) for the
    /// current function; confirm.
    block: &'a mut crate::Block,
}
727
/// How "constant" an expression is.
///
/// The derived `Ord` follows declaration order (`Const < Override < Runtime`), so the maximum
/// of several kinds is the least constant of them. [`ExpressionKindTracker`] relies on this to
/// compute the kind of compound expressions, so do not reorder the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values and constants.
    Const,
    /// Depends on at least one override, but on no runtime value.
    Override,
    /// Everything else.
    Runtime,
}
734
/// Records the [`ExpressionKind`] of each expression in an expression arena, keyed by handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
739
740impl ExpressionKindTracker {
741    pub const fn new() -> Self {
742        Self {
743            inner: HandleVec::new(),
744        }
745    }
746
747    /// Forces the the expression to not be const
748    pub fn force_non_const(&mut self, value: Handle<Expression>) {
749        self.inner[value] = ExpressionKind::Runtime;
750    }
751
752    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
753        self.inner.insert(value, expr_type);
754    }
755
756    pub fn is_const(&self, h: Handle<Expression>) -> bool {
757        matches!(self.type_of(h), ExpressionKind::Const)
758    }
759
760    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
761        matches!(
762            self.type_of(h),
763            ExpressionKind::Const | ExpressionKind::Override
764        )
765    }
766
767    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
768        self.inner[value]
769    }
770
771    pub fn from_arena(arena: &Arena<Expression>) -> Self {
772        let mut tracker = Self {
773            inner: HandleVec::with_capacity(arena.len()),
774        };
775        for (handle, expr) in arena.iter() {
776            tracker
777                .inner
778                .insert(handle, tracker.type_of_with_expr(expr));
779        }
780        tracker
781    }
782
783    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
784        match *expr {
785            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
786                ExpressionKind::Const
787            }
788            Expression::Override(_) => ExpressionKind::Override,
789            Expression::Compose { ref components, .. } => {
790                let mut expr_type = ExpressionKind::Const;
791                for component in components {
792                    expr_type = expr_type.max(self.type_of(*component))
793                }
794                expr_type
795            }
796            Expression::Splat { value, .. } => self.type_of(value),
797            Expression::AccessIndex { base, .. } => self.type_of(base),
798            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
799            Expression::Swizzle { vector, .. } => self.type_of(vector),
800            Expression::Unary { expr, .. } => self.type_of(expr),
801            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
802            Expression::Math {
803                arg,
804                arg1,
805                arg2,
806                arg3,
807                ..
808            } => self
809                .type_of(arg)
810                .max(
811                    arg1.map(|arg| self.type_of(arg))
812                        .unwrap_or(ExpressionKind::Const),
813                )
814                .max(
815                    arg2.map(|arg| self.type_of(arg))
816                        .unwrap_or(ExpressionKind::Const),
817                )
818                .max(
819                    arg3.map(|arg| self.type_of(arg))
820                        .unwrap_or(ExpressionKind::Const),
821                ),
822            Expression::As { expr, .. } => self.type_of(expr),
823            Expression::Select {
824                condition,
825                accept,
826                reject,
827            } => self
828                .type_of(condition)
829                .max(self.type_of(accept))
830                .max(self.type_of(reject)),
831            Expression::Relational { argument, .. } => self.type_of(argument),
832            Expression::ArrayLength(expr) => self.type_of(expr),
833            _ => ExpressionKind::Runtime,
834        }
835    }
836}
837
838#[derive(Clone, Debug, thiserror::Error)]
839#[cfg_attr(test, derive(PartialEq))]
840pub enum ConstantEvaluatorError {
841    #[error("Constants cannot access function arguments")]
842    FunctionArg,
843    #[error("Constants cannot access global variables")]
844    GlobalVariable,
845    #[error("Constants cannot access local variables")]
846    LocalVariable,
847    #[error("Cannot get the array length of a non array type")]
848    InvalidArrayLengthArg,
849    #[error("Constants cannot get the array length of a dynamically sized array")]
850    ArrayLengthDynamic,
851    #[error("Cannot call arrayLength on array sized by override-expression")]
852    ArrayLengthOverridden,
853    #[error("Constants cannot call functions")]
854    Call,
855    #[error("Constants don't support workGroupUniformLoad")]
856    WorkGroupUniformLoadResult,
857    #[error("Constants don't support atomic functions")]
858    Atomic,
859    #[error("Constants don't support derivative functions")]
860    Derivative,
861    #[error("Constants don't support load expressions")]
862    Load,
863    #[error("Constants don't support image expressions")]
864    ImageExpression,
865    #[error("Constants don't support ray query expressions")]
866    RayQueryExpression,
867    #[error("Constants don't support subgroup expressions")]
868    SubgroupExpression,
869    #[error("Cannot access the type")]
870    InvalidAccessBase,
871    #[error("Cannot access at the index")]
872    InvalidAccessIndex,
873    #[error("Cannot access with index of type")]
874    InvalidAccessIndexTy,
875    #[error("Constants don't support array length expressions")]
876    ArrayLength,
877    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
878    InvalidCastArg { from: String, to: String },
879    #[error("Cannot apply the unary op to the argument")]
880    InvalidUnaryOpArg,
881    #[error("Cannot apply the binary op to the arguments")]
882    InvalidBinaryOpArgs,
883    #[error("Cannot apply math function to type")]
884    InvalidMathArg,
885    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
886    InvalidMathArgCount(crate::MathFunction, usize, usize),
887    #[error("{0} built-in function argument is out of valid range")]
888    InvalidMathArgValue(String),
889    #[error("Cannot apply relational function to type")]
890    InvalidRelationalArg(RelationalFunction),
891    #[error("value of `low` is greater than `high` for clamp built-in function")]
892    InvalidClamp,
893    #[error("Constructor expects {expected} components, found {actual}")]
894    InvalidVectorComposeLength { expected: usize, actual: usize },
895    #[error("Constructor must only contain vector or scalar arguments")]
896    InvalidVectorComposeComponent,
897    #[error("Splat is defined only on scalar values")]
898    SplatScalarOnly,
899    #[error("Can only swizzle vector constants")]
900    SwizzleVectorOnly,
901    #[error("swizzle component not present in source expression")]
902    SwizzleOutOfBounds,
903    #[error("Type is not constructible")]
904    TypeNotConstructible,
905    #[error("Subexpression(s) are not constant")]
906    SubexpressionsAreNotConstant,
907    #[error("Not implemented as constant expression: {0}")]
908    NotImplemented(String),
909    #[error("{0} operation overflowed")]
910    Overflow(String),
911    #[error(
912        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
913    )]
914    AutomaticConversionLossy {
915        value: String,
916        to_type: &'static str,
917    },
918    #[error("Division by zero")]
919    DivisionByZero,
920    #[error("Remainder by zero")]
921    RemainderByZero,
922    #[error("RHS of shift operation is greater than or equal to 32")]
923    ShiftedMoreThan32Bits,
924    #[error(transparent)]
925    Literal(#[from] crate::valid::LiteralError),
926    #[error("Can't use pipeline-overridable constants in const-expressions")]
927    Override,
928    #[error("Unexpected runtime-expression")]
929    RuntimeExpr,
930    #[error("Unexpected override-expression")]
931    OverrideExpr,
932    #[error("Expected boolean expression for condition argument of `select`, got something else")]
933    SelectScalarConditionNotABool,
934    #[error(
935        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
936        reject,
937        accept
938    )]
939    SelectVecRejectAcceptSizeMismatch {
940        reject: crate::VectorSize,
941        accept: crate::VectorSize,
942    },
943    #[error("Expected boolean vector for condition arg., got something else")]
944    SelectConditionNotAVecBool,
945    #[error(
946        "Expected same number of vector components between condition, accept, and reject args., got something else",
947    )]
948    SelectConditionVecSizeMismatch,
949    #[error(
950        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
951    )]
952    SelectAcceptRejectTypeMismatch,
953    #[error("Cooperative operations can't be constant")]
954    CooperativeOperation,
955}
956
957impl<'a> ConstantEvaluator<'a> {
958    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
959    /// constant expression arena.
960    ///
961    /// Report errors according to WGSL's rules for constant evaluation.
962    pub const fn for_wgsl_module(
963        module: &'a mut crate::Module,
964        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
965        layouter: &'a mut crate::proc::Layouter,
966        in_override_ctx: bool,
967    ) -> Self {
968        Self::for_module(
969            Behavior::Wgsl(if in_override_ctx {
970                WgslRestrictions::Override
971            } else {
972                WgslRestrictions::Const(None)
973            }),
974            module,
975            global_expression_kind_tracker,
976            layouter,
977        )
978    }
979
980    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
981    /// constant expression arena.
982    ///
983    /// Report errors according to GLSL's rules for constant evaluation.
984    pub const fn for_glsl_module(
985        module: &'a mut crate::Module,
986        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
987        layouter: &'a mut crate::proc::Layouter,
988    ) -> Self {
989        Self::for_module(
990            Behavior::Glsl(GlslRestrictions::Const),
991            module,
992            global_expression_kind_tracker,
993            layouter,
994        )
995    }
996
997    const fn for_module(
998        behavior: Behavior<'a>,
999        module: &'a mut crate::Module,
1000        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1001        layouter: &'a mut crate::proc::Layouter,
1002    ) -> Self {
1003        Self {
1004            behavior,
1005            types: &mut module.types,
1006            constants: &module.constants,
1007            overrides: &module.overrides,
1008            expressions: &mut module.global_expressions,
1009            expression_kind_tracker: global_expression_kind_tracker,
1010            layouter,
1011        }
1012    }
1013
1014    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1015    /// expression arena.
1016    ///
1017    /// Report errors according to WGSL's rules for constant evaluation.
1018    pub const fn for_wgsl_function(
1019        module: &'a mut crate::Module,
1020        expressions: &'a mut Arena<Expression>,
1021        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1022        layouter: &'a mut crate::proc::Layouter,
1023        emitter: &'a mut super::Emitter,
1024        block: &'a mut crate::Block,
1025        is_const: bool,
1026    ) -> Self {
1027        let local_data = FunctionLocalData {
1028            global_expressions: &module.global_expressions,
1029            emitter,
1030            block,
1031        };
1032        Self {
1033            behavior: Behavior::Wgsl(if is_const {
1034                WgslRestrictions::Const(Some(local_data))
1035            } else {
1036                WgslRestrictions::Runtime(local_data)
1037            }),
1038            types: &mut module.types,
1039            constants: &module.constants,
1040            overrides: &module.overrides,
1041            expressions,
1042            expression_kind_tracker: local_expression_kind_tracker,
1043            layouter,
1044        }
1045    }
1046
1047    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1048    /// expression arena.
1049    ///
1050    /// Report errors according to GLSL's rules for constant evaluation.
1051    pub const fn for_glsl_function(
1052        module: &'a mut crate::Module,
1053        expressions: &'a mut Arena<Expression>,
1054        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1055        layouter: &'a mut crate::proc::Layouter,
1056        emitter: &'a mut super::Emitter,
1057        block: &'a mut crate::Block,
1058    ) -> Self {
1059        Self {
1060            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1061                global_expressions: &module.global_expressions,
1062                emitter,
1063                block,
1064            })),
1065            types: &mut module.types,
1066            constants: &module.constants,
1067            overrides: &module.overrides,
1068            expressions,
1069            expression_kind_tracker: local_expression_kind_tracker,
1070            layouter,
1071        }
1072    }
1073
1074    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1075        crate::proc::GlobalCtx {
1076            types: self.types,
1077            constants: self.constants,
1078            overrides: self.overrides,
1079            global_expressions: match self.function_local_data() {
1080                Some(data) => data.global_expressions,
1081                None => self.expressions,
1082            },
1083        }
1084    }
1085
1086    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1087        if !self.expression_kind_tracker.is_const(expr) {
1088            log::debug!("check: SubexpressionsAreNotConstant");
1089            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1090        }
1091        Ok(())
1092    }
1093
1094    fn check_and_get(
1095        &mut self,
1096        expr: Handle<Expression>,
1097    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1098        match self.expressions[expr] {
1099            Expression::Constant(c) => {
1100                // Are we working in a function's expression arena, or the
1101                // module's constant expression arena?
1102                if let Some(function_local_data) = self.function_local_data() {
1103                    // Deep-copy the constant's value into our arena.
1104                    self.copy_from(
1105                        self.constants[c].init,
1106                        function_local_data.global_expressions,
1107                    )
1108                } else {
1109                    // "See through" the constant and use its initializer.
1110                    Ok(self.constants[c].init)
1111                }
1112            }
1113            _ => {
1114                self.check(expr)?;
1115                Ok(expr)
1116            }
1117        }
1118    }
1119
1120    /// Try to evaluate `expr` at compile time.
1121    ///
1122    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1123    /// we can determine its value at compile time, we append an expression
1124    /// representing its value - a tree of [`Literal`], [`Compose`],
1125    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1126    /// `self` contributes to.
1127    ///
1128    /// If `expr`'s value cannot be determined at compile time, and `self` is
1129    /// contributing to some function's expression arena, then append `expr` to
1130    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1131    /// contributing to the module's constant expression arena; since `expr`'s
1132    /// value is not a constant, return an error.
1133    ///
1134    /// We only consider `expr` itself, without recursing into its operands. Its
1135    /// operands must all have been produced by prior calls to
1136    /// `try_eval_and_append`, to ensure that they have already been reduced to
1137    /// an evaluated form if possible.
1138    ///
1139    /// [`Literal`]: Expression::Literal
1140    /// [`Compose`]: Expression::Compose
1141    /// [`ZeroValue`]: Expression::ZeroValue
1142    /// [`Swizzle`]: Expression::Swizzle
1143    pub fn try_eval_and_append(
1144        &mut self,
1145        expr: Expression,
1146        span: Span,
1147    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1148        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1149            ExpressionKind::Const => {
1150                let eval_result = self.try_eval_and_append_impl(&expr, span);
1151                // We should be able to evaluate `Const` expressions at this
1152                // point. If we failed to, then that probably means we just
1153                // haven't implemented that part of constant evaluation. Work
1154                // around this by simply emitting it as a run-time expression.
1155                if self.behavior.has_runtime_restrictions()
1156                    && matches!(
1157                        eval_result,
1158                        Err(ConstantEvaluatorError::NotImplemented(_)
1159                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1160                    )
1161                {
1162                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1163                } else {
1164                    eval_result
1165                }
1166            }
1167            ExpressionKind::Override => match self.behavior {
1168                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1169                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1170                }
1171                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1172                    Err(ConstantEvaluatorError::OverrideExpr)
1173                }
1174
1175                // GLSL specialization constants (constant_id) become Override expressions
1176                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1177                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1178                }
1179                Behavior::Glsl(GlslRestrictions::Const) => {
1180                    Err(ConstantEvaluatorError::OverrideExpr)
1181                }
1182            },
1183            ExpressionKind::Runtime => {
1184                if self.behavior.has_runtime_restrictions() {
1185                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1186                } else {
1187                    Err(ConstantEvaluatorError::RuntimeExpr)
1188                }
1189            }
1190        }
1191    }
1192
1193    /// Is the [`Self::expressions`] arena the global module expression arena?
1194    const fn is_global_arena(&self) -> bool {
1195        matches!(
1196            self.behavior,
1197            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1198                | Behavior::Glsl(GlslRestrictions::Const)
1199        )
1200    }
1201
1202    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1203        match self.behavior {
1204            Behavior::Wgsl(
1205                WgslRestrictions::Runtime(ref function_local_data)
1206                | WgslRestrictions::Const(Some(ref function_local_data)),
1207            )
1208            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1209                Some(function_local_data)
1210            }
1211            _ => None,
1212        }
1213    }
1214
    /// Evaluate a single expression classified as constant, returning a handle
    /// to its value in [`Self::expressions`].
    ///
    /// This is the worker behind `try_eval_and_append`. Each operand handle is
    /// first passed through `check_and_get`. That call rejects non-constant
    /// operands and replaces [`Constant`] references with their initializers,
    /// deep-copying the initializer into this arena when evaluating inside a
    /// function. Because that may append new expressions, the order in which
    /// operands are resolved below determines where such copies land.
    ///
    /// Expression forms that depend on run-time state (loads, variables,
    /// calls, image, ray-query, subgroup, and cooperative operations, ...) are
    /// rejected with a matching [`ConstantEvaluatorError`] variant.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer.
                // This is mainly done to avoid having constants pointing to other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Leaves (and, in a function arena, constant references) need no
            // folding; a copy is handed to `register_evaluated_expr` as-is.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // Read the (already constant) index back as a `u32`; any
                // failure is reported as an invalid index type.
                let index_val: u32 = self
                    .to_ctx()
                    .get_const_val_from(index, self.expressions)
                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
                self.access(base, index_val as usize, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: None` is a bitcast, which is not folded here. The
                // `NotImplemented` error lets `try_eval_and_append` fall back
                // to a run-time expression where that is permitted.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                // Operands are resolved in the order reject, accept, condition.
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            // WGSL rejects array length in constant evaluation outright;
            // GLSL resolves the operand and evaluates it via `array_length`.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below depends on run-time state and can never be a
            // constant; each form maps to its own error variant.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
                Err(ConstantEvaluatorError::CooperativeOperation)
            }
        }
    }
1356
1357    /// Splat `value` to `size`, without using [`Splat`] expressions.
1358    ///
1359    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1360    /// build a vector with the given `size` whose components are all
1361    /// `value`.
1362    ///
1363    /// Use `span` as the span of the inserted expressions and
1364    /// resulting types.
1365    ///
1366    /// [`Splat`]: Expression::Splat
1367    /// [`Compose`]: Expression::Compose
1368    /// [`ZeroValue`]: Expression::ZeroValue
1369    fn splat(
1370        &mut self,
1371        value: Handle<Expression>,
1372        size: crate::VectorSize,
1373        span: Span,
1374    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1375        match self.expressions[value] {
1376            Expression::Literal(literal) => {
1377                let scalar = literal.scalar();
1378                let ty = self.types.insert(
1379                    Type {
1380                        name: None,
1381                        inner: TypeInner::Vector { size, scalar },
1382                    },
1383                    span,
1384                );
1385                let expr = Expression::Compose {
1386                    ty,
1387                    components: vec![value; size as usize],
1388                };
1389                self.register_evaluated_expr(expr, span)
1390            }
1391            Expression::ZeroValue(ty) => {
1392                let inner = match self.types[ty].inner {
1393                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1394                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1395                };
1396                let res_ty = self.types.insert(Type { name: None, inner }, span);
1397                let expr = Expression::ZeroValue(res_ty);
1398                self.register_evaluated_expr(expr, span)
1399            }
1400            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1401        }
1402    }
1403
1404    fn swizzle(
1405        &mut self,
1406        size: crate::VectorSize,
1407        span: Span,
1408        src_constant: Handle<Expression>,
1409        pattern: [crate::SwizzleComponent; 4],
1410    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1411        let mut get_dst_ty = |ty| match self.types[ty].inner {
1412            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1413                Type {
1414                    name: None,
1415                    inner: TypeInner::Vector { size, scalar },
1416                },
1417                span,
1418            )),
1419            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1420        };
1421
1422        match self.expressions[src_constant] {
1423            Expression::ZeroValue(ty) => {
1424                let dst_ty = get_dst_ty(ty)?;
1425                let expr = Expression::ZeroValue(dst_ty);
1426                self.register_evaluated_expr(expr, span)
1427            }
1428            Expression::Splat { value, .. } => {
1429                let expr = Expression::Splat { size, value };
1430                self.register_evaluated_expr(expr, span)
1431            }
1432            Expression::Compose { ty, ref components } => {
1433                let dst_ty = get_dst_ty(ty)?;
1434
1435                let mut flattened = [src_constant; 4]; // dummy value
1436                let len =
1437                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1438                        .zip(flattened.iter_mut())
1439                        .map(|(component, elt)| *elt = component)
1440                        .count();
1441                let flattened = &flattened[..len];
1442
1443                let swizzled_components = pattern[..size as usize]
1444                    .iter()
1445                    .map(|&sc| {
1446                        let sc = sc as usize;
1447                        if let Some(elt) = flattened.get(sc) {
1448                            Ok(*elt)
1449                        } else {
1450                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1451                        }
1452                    })
1453                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1454                let expr = Expression::Compose {
1455                    ty: dst_ty,
1456                    components: swizzled_components,
1457                };
1458                self.register_evaluated_expr(expr, span)
1459            }
1460            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1461        }
1462    }
1463
1464    fn math(
1465        &mut self,
1466        arg: Handle<Expression>,
1467        arg1: Option<Handle<Expression>>,
1468        arg2: Option<Handle<Expression>>,
1469        arg3: Option<Handle<Expression>>,
1470        fun: crate::MathFunction,
1471        span: Span,
1472    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1473        let expected = fun.argument_count();
1474        let given = Some(arg)
1475            .into_iter()
1476            .chain(arg1)
1477            .chain(arg2)
1478            .chain(arg3)
1479            .count();
1480        if expected != given {
1481            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1482                fun, expected, given,
1483            ));
1484        }
1485
1486        // NOTE: We try to match the declaration order of `MathFunction` here.
1487        match fun {
1488            // comparison
1489            crate::MathFunction::Abs => {
1490                component_wise_scalar(self, span, [arg], |args| match args {
1491                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1492                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1493                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1494                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1495                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1496                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1497                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1498                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1499                })
1500            }
1501            crate::MathFunction::Min => {
1502                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1503                    Ok([e1.min(e2)])
1504                })
1505            }
1506            crate::MathFunction::Max => {
1507                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1508                    Ok([e1.max(e2)])
1509                })
1510            }
1511            crate::MathFunction::Clamp => {
1512                component_wise_scalar!(
1513                    self,
1514                    span,
1515                    [arg, arg1.unwrap(), arg2.unwrap()],
1516                    |e, low, high| {
1517                        if low > high {
1518                            Err(ConstantEvaluatorError::InvalidClamp)
1519                        } else {
1520                            Ok([e.clamp(low, high)])
1521                        }
1522                    }
1523                )
1524            }
1525            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1526                Float::F16([e]) => Ok(Float::F16(
1527                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1528                )),
1529                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1530                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1531            }),
1532
1533            // trigonometry
1534            crate::MathFunction::Cos => {
1535                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1536            }
1537            crate::MathFunction::Cosh => {
1538                component_wise_float!(self, span, [arg], |e| {
1539                    let result = e.cosh();
1540                    if result.is_finite() {
1541                        Ok([result])
1542                    } else {
1543                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
1544                    }
1545                })
1546            }
1547            crate::MathFunction::Sin => {
1548                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1549            }
1550            crate::MathFunction::Sinh => {
1551                component_wise_float!(self, span, [arg], |e| {
1552                    let result = e.sinh();
1553                    if result.is_finite() {
1554                        Ok([result])
1555                    } else {
1556                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
1557                    }
1558                })
1559            }
1560            crate::MathFunction::Tan => {
1561                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1562            }
1563            crate::MathFunction::Tanh => {
1564                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1565            }
1566            crate::MathFunction::Acos => {
1567                component_wise_float!(self, span, [arg], |e| {
1568                    if e.abs() <= One::one() {
1569                        Ok([e.acos()])
1570                    } else {
1571                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1572                    }
1573                })
1574            }
1575            crate::MathFunction::Asin => {
1576                component_wise_float!(self, span, [arg], |e| {
1577                    if e.abs() <= One::one() {
1578                        Ok([e.asin()])
1579                    } else {
1580                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1581                    }
1582                })
1583            }
1584            crate::MathFunction::Atan => {
1585                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1586            }
1587            crate::MathFunction::Atan2 => {
1588                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1589                    Ok([y.atan2(x)])
1590                })
1591            }
1592            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1593                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1594                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1595                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1596            }),
1597            crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
1598                Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
1599                Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
1600                Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
1601                _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
1602            }),
1603            crate::MathFunction::Atanh => {
1604                component_wise_float!(self, span, [arg], |e| {
1605                    if e.abs() < One::one() {
1606                        Ok([e.atanh()])
1607                    } else {
1608                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1609                    }
1610                })
1611            }
1612            crate::MathFunction::Radians => {
1613                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1614            }
1615            crate::MathFunction::Degrees => {
1616                component_wise_float!(self, span, [arg], |e| {
1617                    let result = e.to_degrees();
1618                    if result.is_finite() {
1619                        Ok([result])
1620                    } else {
1621                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
1622                    }
1623                })
1624            }
1625
1626            // decomposition
1627            crate::MathFunction::Ceil => {
1628                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1629            }
1630            crate::MathFunction::Floor => {
1631                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1632            }
1633            crate::MathFunction::Round => {
1634                component_wise_float(self, span, [arg], |e| match e {
1635                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1636                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1637                    Float::F16([e]) => {
1638                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1639                        //
1640                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1641                        // which has licensing compatible with ours. See also
1642                        // <https://github.com/rust-lang/rust/issues/96710>.
1643                        //
1644                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1645                        fn round_ties_even(x: f64) -> f64 {
1646                            let i = x as i64;
1647                            let f = (x - i as f64).abs();
1648                            if f == 0.5 {
1649                                if i & 1 == 1 {
1650                                    // -1.5, 1.5, 3.5, ...
1651                                    (x.abs() + 0.5).copysign(x)
1652                                } else {
1653                                    (x.abs() - 0.5).copysign(x)
1654                                }
1655                            } else {
1656                                x.round()
1657                            }
1658                        }
1659
1660                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1661                    }
1662                })
1663            }
1664            crate::MathFunction::Fract => {
1665                component_wise_float!(self, span, [arg], |e| {
1666                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1667                    // here.
1668                    Ok([e - e.floor()])
1669                })
1670            }
1671            crate::MathFunction::Trunc => {
1672                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1673            }
1674
1675            // exponent
1676            crate::MathFunction::Exp => {
1677                component_wise_float!(self, span, [arg], |e| {
1678                    let result = e.exp();
1679                    if result.is_finite() {
1680                        Ok([result])
1681                    } else {
1682                        Err(ConstantEvaluatorError::Overflow("exp".into()))
1683                    }
1684                })
1685            }
1686            crate::MathFunction::Exp2 => {
1687                component_wise_float!(self, span, [arg], |e| {
1688                    let result = e.exp2();
1689                    if result.is_finite() {
1690                        Ok([result])
1691                    } else {
1692                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
1693                    }
1694                })
1695            }
1696            crate::MathFunction::Log => {
1697                component_wise_float!(self, span, [arg], |e| {
1698                    if e > Zero::zero() {
1699                        Ok([e.ln()])
1700                    } else {
1701                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1702                    }
1703                })
1704            }
1705            crate::MathFunction::Log2 => {
1706                component_wise_float!(self, span, [arg], |e| {
1707                    if e > Zero::zero() {
1708                        Ok([e.log2()])
1709                    } else {
1710                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1711                    }
1712                })
1713            }
1714            crate::MathFunction::Pow => {
1715                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1716                    // 0.pow(0) is an error since exp2(0 * log2(0)) is NaN.
1717                    // https://www.w3.org/TR/WGSL/#pow-builtin
1718                    if e1 < Zero::zero()
1719                        || e1.is_one() && e2.is_infinite()
1720                        || e1.is_infinite() && e2.is_zero()
1721                        || e1.is_zero() && e2.is_zero()
1722                    {
1723                        Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
1724                    } else {
1725                        let result = e1.powf(e2);
1726                        if result.is_finite() {
1727                            Ok([result])
1728                        } else {
1729                            Err(ConstantEvaluatorError::Overflow("pow".into()))
1730                        }
1731                    }
1732                })
1733            }
1734
1735            // computational
1736            crate::MathFunction::Sign => {
1737                component_wise_signed!(self, span, [arg], |e| {
1738                    Ok([if e.is_zero() {
1739                        Zero::zero()
1740                    } else {
1741                        e.signum()
1742                    }])
1743                })
1744            }
1745            crate::MathFunction::Fma => {
1746                component_wise_float!(
1747                    self,
1748                    span,
1749                    [arg, arg1.unwrap(), arg2.unwrap()],
1750                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1751                )
1752            }
1753            crate::MathFunction::Step => {
1754                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1755                    Float::Abstract([edge, x]) => {
1756                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1757                    }
1758                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1759                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1760                        f16::one()
1761                    } else {
1762                        f16::zero()
1763                    }])),
1764                })
1765            }
1766            crate::MathFunction::Sqrt => {
1767                component_wise_float!(self, span, [arg], |e| {
1768                    if e >= Zero::zero() {
1769                        Ok([e.sqrt()])
1770                    } else {
1771                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1772                    }
1773                })
1774            }
1775            crate::MathFunction::InverseSqrt => {
1776                component_wise_float(self, span, [arg], |e| match e {
1777                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1778                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1779                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1780                })
1781            }
1782
1783            // bits
1784            crate::MathFunction::CountTrailingZeros => {
1785                component_wise_concrete_int!(self, span, [arg], |e| {
1786                    #[allow(clippy::useless_conversion)]
1787                    Ok([e
1788                        .trailing_zeros()
1789                        .try_into()
1790                        .expect("bit count overflowed 32 bits, somehow!?")])
1791                })
1792            }
1793            crate::MathFunction::CountLeadingZeros => {
1794                component_wise_concrete_int!(self, span, [arg], |e| {
1795                    #[allow(clippy::useless_conversion)]
1796                    Ok([e
1797                        .leading_zeros()
1798                        .try_into()
1799                        .expect("bit count overflowed 32 bits, somehow!?")])
1800                })
1801            }
1802            crate::MathFunction::CountOneBits => {
1803                component_wise_concrete_int!(self, span, [arg], |e| {
1804                    #[allow(clippy::useless_conversion)]
1805                    Ok([e
1806                        .count_ones()
1807                        .try_into()
1808                        .expect("bit count overflowed 32 bits, somehow!?")])
1809                })
1810            }
1811            crate::MathFunction::ReverseBits => {
1812                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1813            }
1814            crate::MathFunction::FirstTrailingBit => {
1815                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1816            }
1817            crate::MathFunction::FirstLeadingBit => {
1818                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1819            }
1820
1821            // vector
1822            crate::MathFunction::Dot4I8Packed => {
1823                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1824            }
1825            crate::MathFunction::Dot4U8Packed => {
1826                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1827            }
1828            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1829            crate::MathFunction::Dot => {
1830                // https://www.w3.org/TR/WGSL/#dot-builtin
1831                let e1 = self.extract_vec(arg, false)?;
1832                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1833                if e1.len() != e2.len() {
1834                    return Err(ConstantEvaluatorError::InvalidMathArg);
1835                }
1836
1837                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1838                where
1839                    P: num_traits::Float,
1840                {
1841                    let result = a
1842                        .iter()
1843                        .zip(b.iter())
1844                        .map(|(&aa, &bb)| aa * bb)
1845                        .fold(P::zero(), |acc, x| acc + x);
1846                    if result.is_finite() {
1847                        Ok(result)
1848                    } else {
1849                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1850                    }
1851                }
1852
1853                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1854                where
1855                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1856                {
1857                    a.iter()
1858                        .zip(b.iter())
1859                        .map(|(&aa, bb)| aa.checked_mul(bb))
1860                        .try_fold(P::zero(), |acc, x| {
1861                            if let Some(x) = x {
1862                                acc.checked_add(&x)
1863                            } else {
1864                                None
1865                            }
1866                        })
1867                        .ok_or(ConstantEvaluatorError::Overflow(
1868                            "in dot built-in".to_string(),
1869                        ))
1870                }
1871
1872                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1873                where
1874                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1875                {
1876                    a.iter()
1877                        .zip(b.iter())
1878                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
1879                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1880                }
1881
1882                let result = match_literal_vector!(match (e1, e2) => Literal {
1883                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
1884                    AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1885                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1886                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1887                })?;
1888                self.register_evaluated_expr(Expression::Literal(result), span)
1889            }
1890            crate::MathFunction::Length => {
1891                // https://www.w3.org/TR/WGSL/#length-builtin
1892                let e1 = self.extract_vec(arg, true)?;
1893
1894                let result = match_literal_vector!(match e1 => Literal {
1895                    Float => |e1| {
1896                        float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
1897                    },
1898                })?;
1899                self.register_evaluated_expr(Expression::Literal(result), span)
1900            }
1901            crate::MathFunction::Distance => {
1902                // https://www.w3.org/TR/WGSL/#distance-builtin
1903                let e1 = self.extract_vec(arg, true)?;
1904                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1905                if e1.len() != e2.len() {
1906                    return Err(ConstantEvaluatorError::InvalidMathArg);
1907                }
1908
1909                fn float_distance<F>(a: &[F], b: &[F]) -> F
1910                where
1911                    F: core::ops::Mul<F>,
1912                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1913                {
1914                    if a.len() == 1 {
1915                        // Avoids possible overflow in squaring
1916                        (a[0] - b[0]).abs()
1917                    } else {
1918                        a.iter()
1919                            .zip(b.iter())
1920                            .map(|(&aa, &bb)| aa - bb)
1921                            .map(|ei| ei * ei)
1922                            .sum::<F>()
1923                            .sqrt()
1924                    }
1925                }
1926                let result = match_literal_vector!(match (e1, e2) => Literal {
1927                    Float => |e1, e2| { float_distance(e1, e2) },
1928                })?;
1929                self.register_evaluated_expr(Expression::Literal(result), span)
1930            }
1931            crate::MathFunction::Normalize => {
1932                // https://www.w3.org/TR/WGSL/#normalize-builtin
1933                let e1 = self.extract_vec(arg, true)?;
1934
1935                fn float_normalize<F>(
1936                    e: &[F],
1937                ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
1938                where
1939                    F: core::ops::Mul<F>,
1940                    F: num_traits::Float + iter::Sum,
1941                {
1942                    let len = match float_length(e) {
1943                        Some(len) if !len.is_zero() => Ok(len),
1944                        Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
1945                            "normalize".into(),
1946                        )),
1947                        None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
1948                    }?;
1949
1950                    let mut out = ArrayVec::new();
1951                    for &ei in e {
1952                        out.push(ei / len);
1953                    }
1954                    Ok(out)
1955                }
1956
1957                let result = match_literal_vector!(match e1 => LiteralVector {
1958                    Float => |e1| { float_normalize(e1)? },
1959                })?;
1960                result.register_as_evaluated_expr(self, span)
1961            }
1962
1963            // unimplemented
1964            crate::MathFunction::Modf
1965            | crate::MathFunction::Frexp
1966            | crate::MathFunction::Ldexp
1967            | crate::MathFunction::Outer
1968            | crate::MathFunction::FaceForward
1969            | crate::MathFunction::Reflect
1970            | crate::MathFunction::Refract
1971            | crate::MathFunction::Mix
1972            | crate::MathFunction::SmoothStep
1973            | crate::MathFunction::Inverse
1974            | crate::MathFunction::Transpose
1975            | crate::MathFunction::Determinant
1976            | crate::MathFunction::QuantizeToF16
1977            | crate::MathFunction::ExtractBits
1978            | crate::MathFunction::InsertBits
1979            | crate::MathFunction::Pack4x8snorm
1980            | crate::MathFunction::Pack4x8unorm
1981            | crate::MathFunction::Pack2x16snorm
1982            | crate::MathFunction::Pack2x16unorm
1983            | crate::MathFunction::Pack2x16float
1984            | crate::MathFunction::Pack4xI8
1985            | crate::MathFunction::Pack4xU8
1986            | crate::MathFunction::Pack4xI8Clamp
1987            | crate::MathFunction::Pack4xU8Clamp
1988            | crate::MathFunction::Unpack4x8snorm
1989            | crate::MathFunction::Unpack4x8unorm
1990            | crate::MathFunction::Unpack2x16snorm
1991            | crate::MathFunction::Unpack2x16unorm
1992            | crate::MathFunction::Unpack2x16float
1993            | crate::MathFunction::Unpack4xI8
1994            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1995                format!("{fun:?} built-in function"),
1996            )),
1997        }
1998    }
1999
2000    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
2001    fn packed_dot_product(
2002        &mut self,
2003        a: Handle<Expression>,
2004        b: Handle<Expression>,
2005        span: Span,
2006        signed: bool,
2007    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2008        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
2009            return Err(ConstantEvaluatorError::InvalidMathArg);
2010        };
2011        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2012            return Err(ConstantEvaluatorError::InvalidMathArg);
2013        };
2014
2015        let result = if signed {
2016            Literal::I32(
2017                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2018                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2019                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2020                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2021            )
2022        } else {
2023            Literal::U32(
2024                (a & 0xFF) * (b & 0xFF)
2025                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2026                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2027                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2028            )
2029        };
2030
2031        self.register_evaluated_expr(Expression::Literal(result), span)
2032    }
2033
2034    /// Vector cross product.
2035    fn cross_product(
2036        &mut self,
2037        a: Handle<Expression>,
2038        b: Handle<Expression>,
2039        span: Span,
2040    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2041        use Literal as Li;
2042
2043        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2044        let (b, _) = self.extract_vec_with_size::<3>(b)?;
2045
2046        let product = match (a, b) {
2047            (
2048                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2049                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2050            ) => {
2051                // `cross` has no overload for AbstractInt, so AbstractInt
2052                // arguments are automatically converted to AbstractFloat. Since
2053                // `f64` has a much wider range than `i64`, there's no danger of
2054                // overflow here.
2055                let p = cross_product(
2056                    [a0 as f64, a1 as f64, a2 as f64],
2057                    [b0 as f64, b1 as f64, b2 as f64],
2058                );
2059                [
2060                    Li::AbstractFloat(p[0]),
2061                    Li::AbstractFloat(p[1]),
2062                    Li::AbstractFloat(p[2]),
2063                ]
2064            }
2065            (
2066                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2067                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2068            ) => {
2069                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2070                [
2071                    Li::AbstractFloat(p[0]),
2072                    Li::AbstractFloat(p[1]),
2073                    Li::AbstractFloat(p[2]),
2074                ]
2075            }
2076            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2077                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2078                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2079            }
2080            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2081                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2082                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2083            }
2084            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2085                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2086                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2087            }
2088            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2089        };
2090
2091        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2092        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2093        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2094
2095        self.register_evaluated_expr(
2096            Expression::Compose {
2097                ty,
2098                components: vec![p0, p1, p2],
2099            },
2100            span,
2101        )
2102    }
2103
2104    /// Extract the values of a `vecN` from `expr`.
2105    ///
2106    /// Return the value of `expr`, whose type is `vecN<S>` for some
2107    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2108    /// values.
2109    ///
2110    /// Also return the type handle from the `Compose` expression.
2111    fn extract_vec_with_size<const N: usize>(
2112        &mut self,
2113        expr: Handle<Expression>,
2114    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2115        let span = self.expressions.get_span(expr);
2116        let expr = self.eval_zero_value_and_splat(expr, span)?;
2117        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2118            return Err(ConstantEvaluatorError::InvalidMathArg);
2119        };
2120
2121        let mut value = [Literal::Bool(false); N];
2122        for (component, elt) in
2123            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2124                .zip(value.iter_mut())
2125        {
2126            let Expression::Literal(literal) = self.expressions[component] else {
2127                return Err(ConstantEvaluatorError::InvalidMathArg);
2128            };
2129            *elt = literal;
2130        }
2131
2132        Ok((value, ty))
2133    }
2134
2135    /// Extract the values of a `vecN` from `expr`.
2136    ///
2137    /// Return the value of `expr`, whose type is `vecN<S>` for some
2138    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2139    /// values.
2140    ///
2141    /// Also return the type handle from the `Compose` expression.
2142    fn extract_vec(
2143        &mut self,
2144        expr: Handle<Expression>,
2145        allow_single: bool,
2146    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2147        let span = self.expressions.get_span(expr);
2148        let expr = self.eval_zero_value_and_splat(expr, span)?;
2149
2150        match self.expressions[expr] {
2151            Expression::Literal(literal) if allow_single => {
2152                Ok(LiteralVector::from_literal(literal))
2153            }
2154            Expression::Compose { ty, ref components } => {
2155                let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2156                for expr in
2157                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2158                {
2159                    match self.expressions[expr] {
2160                        Expression::Literal(l) => components_out.push(l),
2161                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2162                    }
2163                }
2164                LiteralVector::from_literal_vec(components_out)
2165            }
2166            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2167        }
2168    }
2169
2170    fn array_length(
2171        &mut self,
2172        array: Handle<Expression>,
2173        span: Span,
2174    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2175        match self.expressions[array] {
2176            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2177                match self.types[ty].inner {
2178                    TypeInner::Array { size, .. } => match size {
2179                        ArraySize::Constant(len) => {
2180                            let expr = Expression::Literal(Literal::U32(len.get()));
2181                            self.register_evaluated_expr(expr, span)
2182                        }
2183                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2184                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2185                    },
2186                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2187                }
2188            }
2189            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2190        }
2191    }
2192
2193    fn access(
2194        &mut self,
2195        base: Handle<Expression>,
2196        index: usize,
2197        span: Span,
2198    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2199        match self.expressions[base] {
2200            Expression::ZeroValue(ty) => {
2201                let ty_inner = &self.types[ty].inner;
2202                let components = ty_inner
2203                    .components()
2204                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2205
2206                if index >= components as usize {
2207                    Err(ConstantEvaluatorError::InvalidAccessBase)
2208                } else {
2209                    let ty_res = ty_inner
2210                        .component_type(index)
2211                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2212                    let ty = match ty_res {
2213                        crate::proc::TypeResolution::Handle(ty) => ty,
2214                        crate::proc::TypeResolution::Value(inner) => {
2215                            self.types.insert(Type { name: None, inner }, span)
2216                        }
2217                    };
2218                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2219                }
2220            }
2221            Expression::Splat { size, value } => {
2222                if index >= size as usize {
2223                    Err(ConstantEvaluatorError::InvalidAccessBase)
2224                } else {
2225                    Ok(value)
2226                }
2227            }
2228            Expression::Compose { ty, ref components } => {
2229                let _ = self.types[ty]
2230                    .inner
2231                    .components()
2232                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2233
2234                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2235                    .nth(index)
2236                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2237            }
2238            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2239        }
2240    }
2241
2242    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2243    ///
2244    /// [`ZeroValue`]: Expression::ZeroValue
2245    /// [`Splat`]: Expression::Splat
2246    /// [`Literal`]: Expression::Literal
2247    /// [`Compose`]: Expression::Compose
2248    fn eval_zero_value_and_splat(
2249        &mut self,
2250        mut expr: Handle<Expression>,
2251        span: Span,
2252    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2253        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2254        // each of its components.
2255        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2256            let components = components
2257                .clone()
2258                .iter()
2259                .map(|component| self.eval_zero_value_and_splat(*component, span))
2260                .collect::<Result<_, _>>()?;
2261            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2262        }
2263
2264        // The result of the splat() for a Splat of a scalar ZeroValue is a
2265        // vector ZeroValue, so we must call eval_zero_value_impl() after
2266        // splat() in order to ensure we have no ZeroValues remaining.
2267        if let Expression::Splat { size, value } = self.expressions[expr] {
2268            expr = self.splat(value, size, span)?;
2269        }
2270        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2271            expr = self.eval_zero_value_impl(ty, span)?;
2272        }
2273        Ok(expr)
2274    }
2275
2276    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2277    ///
2278    /// [`ZeroValue`]: Expression::ZeroValue
2279    /// [`Literal`]: Expression::Literal
2280    /// [`Compose`]: Expression::Compose
2281    fn eval_zero_value(
2282        &mut self,
2283        expr: Handle<Expression>,
2284        span: Span,
2285    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2286        match self.expressions[expr] {
2287            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2288            _ => Ok(expr),
2289        }
2290    }
2291
2292    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2293    ///
2294    /// [`ZeroValue`]: Expression::ZeroValue
2295    /// [`Literal`]: Expression::Literal
2296    /// [`Compose`]: Expression::Compose
2297    fn eval_zero_value_impl(
2298        &mut self,
2299        ty: Handle<Type>,
2300        span: Span,
2301    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2302        match self.types[ty].inner {
2303            TypeInner::Scalar(scalar) => {
2304                let expr = Expression::Literal(
2305                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2306                );
2307                self.register_evaluated_expr(expr, span)
2308            }
2309            TypeInner::Vector { size, scalar } => {
2310                let scalar_ty = self.types.insert(
2311                    Type {
2312                        name: None,
2313                        inner: TypeInner::Scalar(scalar),
2314                    },
2315                    span,
2316                );
2317                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2318                let expr = Expression::Compose {
2319                    ty,
2320                    components: vec![el; size as usize],
2321                };
2322                self.register_evaluated_expr(expr, span)
2323            }
2324            TypeInner::Matrix {
2325                columns,
2326                rows,
2327                scalar,
2328            } => {
2329                let vec_ty = self.types.insert(
2330                    Type {
2331                        name: None,
2332                        inner: TypeInner::Vector { size: rows, scalar },
2333                    },
2334                    span,
2335                );
2336                let el = self.eval_zero_value_impl(vec_ty, span)?;
2337                let expr = Expression::Compose {
2338                    ty,
2339                    components: vec![el; columns as usize],
2340                };
2341                self.register_evaluated_expr(expr, span)
2342            }
2343            TypeInner::Array {
2344                base,
2345                size: ArraySize::Constant(size),
2346                ..
2347            } => {
2348                let el = self.eval_zero_value_impl(base, span)?;
2349                let expr = Expression::Compose {
2350                    ty,
2351                    components: vec![el; size.get() as usize],
2352                };
2353                self.register_evaluated_expr(expr, span)
2354            }
2355            TypeInner::Struct { ref members, .. } => {
2356                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2357                let mut components = Vec::with_capacity(members.len());
2358                for ty in types {
2359                    components.push(self.eval_zero_value_impl(ty, span)?);
2360                }
2361                let expr = Expression::Compose { ty, components };
2362                self.register_evaluated_expr(expr, span)
2363            }
2364            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2365        }
2366    }
2367
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// A `ZeroValue` operand is first lowered with `eval_zero_value`. After that:
    ///
    /// - a `Literal` is converted according to the per-target table below;
    /// - a vector or matrix `Compose` has each component converted (by
    ///   recursing into this method) and is given a new type whose scalar is
    ///   `target`;
    /// - a `Splat` has its splatted value converted (also recursively).
    ///
    /// The following produce [`ConstantEvaluatorError::InvalidCastArg`]:
    ///
    /// - any other expression;
    /// - any other `Compose` type (arrays are handled by
    ///   [`ConstantEvaluator::cast_array`]);
    /// - any source/target pair the table rejects.
    ///
    /// Errors returned by `try_from_abstract` are propagated unchanged.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        // Lower a `ZeroValue` operand so it reaches the match below as a
        // `Literal` or `Compose`.
        let expr = self.eval_zero_value(expr, span)?;

        // Builds the error for unsupported conversions. `from` renders both the
        // handle and the expression it refers to.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                let literal = match target {
                    // NOTE(review): for the 16-bit integer targets, abstract
                    // sources use plain `as` casts (truncating for AbstractInt,
                    // saturating for AbstractFloat) instead of the checked
                    // `try_from_abstract` used by the 32/64-bit targets.
                    // Confirm this is intended.
                    Sc::I16 => Literal::I16(match literal {
                        Literal::I16(v) => v,
                        Literal::U16(v) => v as i16,
                        Literal::I32(v) => v as i16,
                        Literal::U32(v) => v as i16,
                        Literal::F32(v) => v.clamp(i16::MIN as f32, i16::MAX as f32) as i16,
                        Literal::F16(v) => {
                            f16::to_f32(v).clamp(i16::MIN as f32, i16::MAX as f32) as i16
                        }
                        Literal::Bool(v) => v as i16,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => v as i16,
                        Literal::AbstractFloat(v) => v as i16,
                    }),
                    Sc::U16 => Literal::U16(match literal {
                        Literal::I16(v) => v as u16,
                        Literal::U16(v) => v,
                        Literal::I32(v) => v as u16,
                        Literal::U32(v) => v as u16,
                        Literal::F32(v) => v.clamp(u16::MIN as f32, u16::MAX as f32) as u16,
                        // max(0) avoids None due to negative values; the largest
                        // finite f16 (65504) fits in u16, so this is presumably
                        // only None on NaN or Inf, like the other f16 unwraps here.
                        Literal::F16(v) => f16::to_u16(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u16,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => v as u16,
                        Literal::AbstractFloat(v) => v as u16,
                    }),
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I16(v) => v as i32,
                        Literal::U16(v) => v as i32,
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I16(v) => v as u32,
                        Literal::U16(v) => v as u32,
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I16(v) => v as i64,
                        Literal::U16(v) => v as i64,
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I16(v) => v as u64,
                        Literal::U16(v) => v as u64,
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    // NOTE(review): the `unwrap`s below assume these
                    // `FromPrimitive` conversions into `f16` never return
                    // `None`; confirm against the `half` implementation.
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I16(v) => f16::from_f32(v as f32),
                        Literal::U16(v) => f16::from_f32(v as f32),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I16(v) => v as f32,
                        Literal::U16(v) => v as f32,
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        // All 64-bit sources, including F64, are rejected here.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I16(v) => v as f64,
                        Literal::U16(v) => v as f64,
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        // Only the 64-bit integer sources are rejected for F64.
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I16(v) => v != 0,
                        Literal::U16(v) => v != 0,
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Abstract targets only accept abstract sources.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    // Any target scalar not listed above is unsupported.
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices are converted component-wise; the
                // result gets a type with the same shape and `target` as scalar.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Clone the handles so `self` can be borrowed mutably below.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Convert the splatted value, keeping that value's own span.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
2608
2609    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2610    ///
2611    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2612    /// matrix, or nested arrays of such.
2613    ///
2614    /// This is basically the same as the [`cast`] method, except that that
2615    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2616    ///
2617    /// Treat `span` as the location of the resulting expression.
2618    ///
2619    /// [`cast`]: ConstantEvaluator::cast
2620    /// [`As`]: crate::Expression::As
2621    pub fn cast_array(
2622        &mut self,
2623        expr: Handle<Expression>,
2624        target: crate::Scalar,
2625        span: Span,
2626    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2627        let expr = self.check_and_get(expr)?;
2628
2629        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2630            return self.cast(expr, target, span);
2631        };
2632
2633        let TypeInner::Array {
2634            base: _,
2635            size,
2636            stride: _,
2637        } = self.types[ty].inner
2638        else {
2639            return self.cast(expr, target, span);
2640        };
2641
2642        let mut components = components.clone();
2643        for component in &mut components {
2644            *component = self.cast_array(*component, target, span)?;
2645        }
2646
2647        let first = components.first().unwrap();
2648        let new_base = match self.resolve_type(*first)? {
2649            crate::proc::TypeResolution::Handle(ty) => ty,
2650            crate::proc::TypeResolution::Value(inner) => {
2651                self.types.insert(Type { name: None, inner }, span)
2652            }
2653        };
2654        let mut layouter = core::mem::take(self.layouter);
2655        layouter.update(self.to_ctx()).unwrap();
2656        *self.layouter = layouter;
2657
2658        let new_base_stride = self.layouter[new_base].to_stride();
2659        let new_array_ty = self.types.insert(
2660            Type {
2661                name: None,
2662                inner: TypeInner::Array {
2663                    base: new_base,
2664                    size,
2665                    stride: new_base_stride,
2666                },
2667            },
2668            span,
2669        );
2670
2671        let compose = Expression::Compose {
2672            ty: new_array_ty,
2673            components,
2674        };
2675        self.register_evaluated_expr(compose, span)
2676    }
2677
2678    fn unary_op(
2679        &mut self,
2680        op: UnaryOperator,
2681        expr: Handle<Expression>,
2682        span: Span,
2683    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2684        let expr = self.eval_zero_value_and_splat(expr, span)?;
2685
2686        let expr = match self.expressions[expr] {
2687            Expression::Literal(value) => Expression::Literal(match op {
2688                UnaryOperator::Negate => match value {
2689                    Literal::I16(v) => Literal::I16(v.wrapping_neg()),
2690                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2691                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2692                    Literal::F32(v) => Literal::F32(-v),
2693                    Literal::F16(v) => Literal::F16(-v),
2694                    Literal::F64(v) => Literal::F64(-v),
2695                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2696                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2697                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2698                },
2699                UnaryOperator::LogicalNot => match value {
2700                    Literal::Bool(v) => Literal::Bool(!v),
2701                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2702                },
2703                UnaryOperator::BitwiseNot => match value {
2704                    Literal::I16(v) => Literal::I16(!v),
2705                    Literal::I32(v) => Literal::I32(!v),
2706                    Literal::I64(v) => Literal::I64(!v),
2707                    Literal::U16(v) => Literal::U16(!v),
2708                    Literal::U32(v) => Literal::U32(!v),
2709                    Literal::U64(v) => Literal::U64(!v),
2710                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2711                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2712                },
2713            }),
2714            Expression::Compose {
2715                ty,
2716                components: ref src_components,
2717            } => {
2718                match self.types[ty].inner {
2719                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2720                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2721                }
2722
2723                let mut components = src_components.clone();
2724                for component in &mut components {
2725                    *component = self.unary_op(op, *component, span)?;
2726                }
2727
2728                Expression::Compose { ty, components }
2729            }
2730            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2731        };
2732
2733        self.register_evaluated_expr(expr, span)
2734    }
2735
2736    fn binary_op(
2737        &mut self,
2738        op: BinaryOperator,
2739        left: Handle<Expression>,
2740        right: Handle<Expression>,
2741        span: Span,
2742    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2743        let left = self.eval_zero_value_and_splat(left, span)?;
2744        let right = self.eval_zero_value_and_splat(right, span)?;
2745
2746        // Note: in most cases constant evaluation checks for overflow, but for
2747        // i32/u32, it uses wrapping arithmetic. See
2748        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.
2749
2750        let expr = match (&self.expressions[left], &self.expressions[right]) {
2751            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2752                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2753                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2754                {
2755                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2756                }
2757
2758                if matches!(
2759                    (left_value, op),
2760                    (
2761                        Literal::Bool(_),
2762                        BinaryOperator::Less
2763                            | BinaryOperator::LessEqual
2764                            | BinaryOperator::Greater
2765                            | BinaryOperator::GreaterEqual
2766                    )
2767                ) {
2768                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2769                }
2770
2771                let literal = match op {
2772                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2773                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2774                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2775                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2776                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2777                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2778
2779                    _ => match (left_value, right_value) {
2780                        (Literal::I16(a), Literal::I16(b)) => Literal::I16(match op {
2781                            BinaryOperator::Add => a.wrapping_add(b),
2782                            BinaryOperator::Subtract => a.wrapping_sub(b),
2783                            BinaryOperator::Multiply => a.wrapping_mul(b),
2784                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2785                                if b == 0 {
2786                                    ConstantEvaluatorError::DivisionByZero
2787                                } else {
2788                                    ConstantEvaluatorError::Overflow("division".into())
2789                                }
2790                            })?,
2791                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2792                                if b == 0 {
2793                                    ConstantEvaluatorError::RemainderByZero
2794                                } else {
2795                                    ConstantEvaluatorError::Overflow("remainder".into())
2796                                }
2797                            })?,
2798                            BinaryOperator::And => a & b,
2799                            BinaryOperator::ExclusiveOr => a ^ b,
2800                            BinaryOperator::InclusiveOr => a | b,
2801                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2802                        }),
2803                        (Literal::I16(a), Literal::U32(b)) => Literal::I16(match op {
2804                            BinaryOperator::ShiftLeft => {
2805                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2806                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2807                                }
2808                                a.checked_shl(b)
2809                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2810                            }
2811                            BinaryOperator::ShiftRight => a
2812                                .checked_shr(b)
2813                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2814                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2815                        }),
2816                        (Literal::U16(a), Literal::U16(b)) => Literal::U16(match op {
2817                            BinaryOperator::Add => a.wrapping_add(b),
2818                            BinaryOperator::Subtract => a.wrapping_sub(b),
2819                            BinaryOperator::Multiply => a.wrapping_mul(b),
2820                            BinaryOperator::Divide => a
2821                                .checked_div(b)
2822                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2823                            BinaryOperator::Modulo => a
2824                                .checked_rem(b)
2825                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2826                            BinaryOperator::And => a & b,
2827                            BinaryOperator::ExclusiveOr => a ^ b,
2828                            BinaryOperator::InclusiveOr => a | b,
2829                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2830                        }),
2831                        (Literal::U16(a), Literal::U32(b)) => Literal::U16(match op {
2832                            BinaryOperator::ShiftLeft => a
2833                                .checked_mul(
2834                                    1u16.checked_shl(b)
2835                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2836                                )
2837                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2838                            BinaryOperator::ShiftRight => a
2839                                .checked_shr(b)
2840                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2841                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2842                        }),
2843                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2844                            BinaryOperator::Add => a.wrapping_add(b),
2845                            BinaryOperator::Subtract => a.wrapping_sub(b),
2846                            BinaryOperator::Multiply => a.wrapping_mul(b),
2847                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2848                                if b == 0 {
2849                                    ConstantEvaluatorError::DivisionByZero
2850                                } else {
2851                                    ConstantEvaluatorError::Overflow("division".into())
2852                                }
2853                            })?,
2854                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2855                                if b == 0 {
2856                                    ConstantEvaluatorError::RemainderByZero
2857                                } else {
2858                                    ConstantEvaluatorError::Overflow("remainder".into())
2859                                }
2860                            })?,
2861                            BinaryOperator::And => a & b,
2862                            BinaryOperator::ExclusiveOr => a ^ b,
2863                            BinaryOperator::InclusiveOr => a | b,
2864                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2865                        }),
2866                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2867                            BinaryOperator::ShiftLeft => {
2868                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2869                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2870                                }
2871                                a.checked_shl(b)
2872                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2873                            }
2874                            BinaryOperator::ShiftRight => a
2875                                .checked_shr(b)
2876                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2877                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2878                        }),
2879                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2880                            BinaryOperator::Add => a.wrapping_add(b),
2881                            BinaryOperator::Subtract => a.wrapping_sub(b),
2882                            BinaryOperator::Multiply => a.wrapping_mul(b),
2883                            BinaryOperator::Divide => a
2884                                .checked_div(b)
2885                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2886                            BinaryOperator::Modulo => a
2887                                .checked_rem(b)
2888                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2889                            BinaryOperator::And => a & b,
2890                            BinaryOperator::ExclusiveOr => a ^ b,
2891                            BinaryOperator::InclusiveOr => a | b,
2892                            BinaryOperator::ShiftLeft => a
2893                                .checked_mul(
2894                                    1u32.checked_shl(b)
2895                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2896                                )
2897                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2898                            BinaryOperator::ShiftRight => a
2899                                .checked_shr(b)
2900                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2901                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2902                        }),
2903                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2904                            BinaryOperator::Add => a + b,
2905                            BinaryOperator::Subtract => a - b,
2906                            BinaryOperator::Multiply => a * b,
2907                            BinaryOperator::Divide => a / b,
2908                            BinaryOperator::Modulo => a % b,
2909                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2910                        }),
2911                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2912                            Literal::AbstractInt(match op {
2913                                BinaryOperator::ShiftLeft => {
2914                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2915                                        return Err(ConstantEvaluatorError::Overflow(
2916                                            "<<".to_string(),
2917                                        ));
2918                                    }
2919                                    a.checked_shl(b).unwrap_or(0)
2920                                }
2921                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2922                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2923                            })
2924                        }
2925                        (Literal::F16(a), Literal::F16(b)) => {
2926                            let result = match op {
2927                                BinaryOperator::Add => a + b,
2928                                BinaryOperator::Subtract => a - b,
2929                                BinaryOperator::Multiply => a * b,
2930                                BinaryOperator::Divide => a / b,
2931                                BinaryOperator::Modulo => a % b,
2932                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2933                            };
2934                            if !result.is_finite() {
2935                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2936                            }
2937                            Literal::F16(result)
2938                        }
2939                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2940                            Literal::AbstractInt(match op {
2941                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2942                                    ConstantEvaluatorError::Overflow("addition".into())
2943                                })?,
2944                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2945                                    ConstantEvaluatorError::Overflow("subtraction".into())
2946                                })?,
2947                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2948                                    ConstantEvaluatorError::Overflow("multiplication".into())
2949                                })?,
2950                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2951                                    if b == 0 {
2952                                        ConstantEvaluatorError::DivisionByZero
2953                                    } else {
2954                                        ConstantEvaluatorError::Overflow("division".into())
2955                                    }
2956                                })?,
2957                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2958                                    if b == 0 {
2959                                        ConstantEvaluatorError::RemainderByZero
2960                                    } else {
2961                                        ConstantEvaluatorError::Overflow("remainder".into())
2962                                    }
2963                                })?,
2964                                BinaryOperator::And => a & b,
2965                                BinaryOperator::ExclusiveOr => a ^ b,
2966                                BinaryOperator::InclusiveOr => a | b,
2967                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2968                            })
2969                        }
2970                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2971                            let result = match op {
2972                                BinaryOperator::Add => a + b,
2973                                BinaryOperator::Subtract => a - b,
2974                                BinaryOperator::Multiply => a * b,
2975                                BinaryOperator::Divide => a / b,
2976                                BinaryOperator::Modulo => a % b,
2977                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2978                            };
2979                            if !result.is_finite() {
2980                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2981                            }
2982                            Literal::AbstractFloat(result)
2983                        }
2984                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2985                            BinaryOperator::LogicalAnd => a && b,
2986                            BinaryOperator::LogicalOr => a || b,
2987                            BinaryOperator::And => a & b,
2988                            BinaryOperator::InclusiveOr => a | b,
2989                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2990                        }),
2991                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2992                    },
2993                };
2994                Expression::Literal(literal)
2995            }
2996            (
2997                &Expression::Compose {
2998                    components: ref src_components,
2999                    ty,
3000                },
3001                &Expression::Literal(_),
3002            ) => {
3003                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
3004                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3005                }
3006                let mut components = src_components.clone();
3007                for component in &mut components {
3008                    *component = self.binary_op(op, *component, right, span)?;
3009                }
3010                Expression::Compose { ty, components }
3011            }
3012            (
3013                &Expression::Literal(_),
3014                &Expression::Compose {
3015                    components: ref src_components,
3016                    ty,
3017                },
3018            ) => {
3019                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
3020                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3021                }
3022                let mut components = src_components.clone();
3023                for component in &mut components {
3024                    *component = self.binary_op(op, left, *component, span)?;
3025                }
3026                Expression::Compose { ty, components }
3027            }
3028            (
3029                &Expression::Compose {
3030                    components: ref left_components,
3031                    ty: left_ty,
3032                },
3033                &Expression::Compose {
3034                    components: ref right_components,
3035                    ty: right_ty,
3036                },
3037            ) => {
3038                // We have to make a copy of the component lists, because the
3039                // call to `binary_op_vector` needs `&mut self`, but `self` owns
3040                // the component lists.
3041                let left_flattened = crate::proc::flatten_compose(
3042                    left_ty,
3043                    left_components,
3044                    self.expressions,
3045                    self.types,
3046                )
3047                .collect::<Vec<_>>();
3048                let right_flattened = crate::proc::flatten_compose(
3049                    right_ty,
3050                    right_components,
3051                    self.expressions,
3052                    self.types,
3053                )
3054                .collect::<Vec<_>>();
3055
3056                self.binary_op_compose(
3057                    op,
3058                    &left_flattened,
3059                    &right_flattened,
3060                    left_ty,
3061                    right_ty,
3062                    span,
3063                )?
3064            }
3065            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3066        };
3067
3068        return self.register_evaluated_expr(expr, span);
3069
3070        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
3071            let is_numeric_vec = matches!(
3072                compose_ty, TypeInner::Vector { scalar, .. }
3073                if scalar.kind != ScalarKind::Bool
3074            );
3075            let is_allowed_vec_scalar_op = matches!(
3076                op,
3077                BinaryOperator::Add
3078                    | BinaryOperator::Subtract
3079                    | BinaryOperator::Multiply
3080                    | BinaryOperator::Divide
3081                    | BinaryOperator::Modulo
3082            );
3083            let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
3084            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
3085            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
3086        }
3087    }
3088
3089    fn binary_op_compose(
3090        &mut self,
3091        op: BinaryOperator,
3092        left_components: &[Handle<Expression>],
3093        right_components: &[Handle<Expression>],
3094        left_ty: Handle<Type>,
3095        right_ty: Handle<Type>,
3096        span: Span,
3097    ) -> Result<Expression, ConstantEvaluatorError> {
3098        match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
3099            // Binary operation on vector-vector
3100            (
3101                &TypeInner::Vector {
3102                    size: left_size, ..
3103                },
3104                &TypeInner::Vector {
3105                    size: right_size, ..
3106                },
3107            ) if left_size == right_size => self.binary_op_vector(
3108                op,
3109                left_size,
3110                left_components,
3111                right_components,
3112                left_ty,
3113                span,
3114            ),
3115            // Binary operation on vector-matrix
3116            (
3117                &TypeInner::Vector { size, .. },
3118                &TypeInner::Matrix {
3119                    columns,
3120                    rows,
3121                    scalar,
3122                },
3123            ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3124                left_components,
3125                right_components,
3126                columns,
3127                scalar,
3128                span,
3129            ),
3130            // Binary operation on matrix-vector
3131            (
3132                &TypeInner::Matrix {
3133                    columns,
3134                    rows,
3135                    scalar,
3136                },
3137                &TypeInner::Vector { size, .. },
3138            ) if op == BinaryOperator::Multiply && size == columns => {
3139                self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3140            }
3141            // Binary operation on matrix-matrix
3142            (
3143                &TypeInner::Matrix {
3144                    columns: left_columns,
3145                    rows: left_rows,
3146                    scalar,
3147                },
3148                &TypeInner::Matrix {
3149                    columns: right_columns,
3150                    rows: right_rows,
3151                    ..
3152                },
3153            ) => match op {
3154                BinaryOperator::Add | BinaryOperator::Subtract
3155                    if left_columns == right_columns && left_rows == right_rows =>
3156                {
3157                    let components = left_components
3158                        .iter()
3159                        .zip(right_components)
3160                        .map(|(&left, &right)| self.binary_op(op, left, right, span))
3161                        .collect::<Result<Vec<_>, _>>()?;
3162                    Ok(Expression::Compose {
3163                        ty: left_ty,
3164                        components,
3165                    })
3166                }
3167                BinaryOperator::Multiply if left_columns == right_rows => self
3168                    .multiply_matrix_matrix(
3169                        left_components,
3170                        right_components,
3171                        left_rows,
3172                        right_columns,
3173                        scalar,
3174                        span,
3175                    ),
3176                _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3177            },
3178            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3179        }
3180    }
3181
3182    fn binary_op_vector(
3183        &mut self,
3184        op: BinaryOperator,
3185        size: crate::VectorSize,
3186        left_components: &[Handle<Expression>],
3187        right_components: &[Handle<Expression>],
3188        left_ty: Handle<Type>,
3189        span: Span,
3190    ) -> Result<Expression, ConstantEvaluatorError> {
3191        let ty = match op {
3192            // Relational operators produce vectors of booleans.
3193            BinaryOperator::Equal
3194            | BinaryOperator::NotEqual
3195            | BinaryOperator::Less
3196            | BinaryOperator::LessEqual
3197            | BinaryOperator::Greater
3198            | BinaryOperator::GreaterEqual => self.types.insert(
3199                Type {
3200                    name: None,
3201                    inner: TypeInner::Vector {
3202                        size,
3203                        scalar: crate::Scalar::BOOL,
3204                    },
3205                },
3206                span,
3207            ),
3208
3209            // Other operators produce the same type as their left
3210            // operand.
3211            BinaryOperator::Add
3212            | BinaryOperator::Subtract
3213            | BinaryOperator::Multiply
3214            | BinaryOperator::Divide
3215            | BinaryOperator::Modulo
3216            | BinaryOperator::And
3217            | BinaryOperator::ExclusiveOr
3218            | BinaryOperator::InclusiveOr
3219            | BinaryOperator::ShiftLeft
3220            | BinaryOperator::ShiftRight => left_ty,
3221
3222            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3223                // Not supported on vectors
3224                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3225            }
3226        };
3227
3228        let components = left_components
3229            .iter()
3230            .zip(right_components)
3231            .map(|(&left, &right)| self.binary_op(op, left, right, span))
3232            .collect::<Result<Vec<_>, _>>()?;
3233
3234        Ok(Expression::Compose { ty, components })
3235    }
3236
3237    fn multiply_vector_matrix(
3238        &mut self,
3239        vec_components: &[Handle<Expression>],
3240        mat_components: &[Handle<Expression>],
3241        mat_columns: crate::VectorSize,
3242        scalar: crate::Scalar,
3243        span: Span,
3244    ) -> Result<Expression, ConstantEvaluatorError> {
3245        let ty = self.types.insert(
3246            Type {
3247                name: None,
3248                inner: TypeInner::Vector {
3249                    size: mat_columns,
3250                    scalar,
3251                },
3252            },
3253            span,
3254        );
3255        let components = mat_components
3256            .iter()
3257            .map(|&column| {
3258                let Expression::Compose { ref components, .. } = self.expressions[column] else {
3259                    unreachable!()
3260                };
3261                self.dot_exprs(
3262                    vec_components.iter().cloned(),
3263                    components.clone().into_iter(),
3264                    span,
3265                )
3266            })
3267            .collect::<Result<Vec<_>, _>>()?;
3268        Ok(Expression::Compose { ty, components })
3269    }
3270
3271    fn multiply_matrix_vector(
3272        &mut self,
3273        mat_components: &[Handle<Expression>],
3274        vec_components: &[Handle<Expression>],
3275        mat_rows: crate::VectorSize,
3276        scalar: crate::Scalar,
3277        span: Span,
3278    ) -> Result<Expression, ConstantEvaluatorError> {
3279        let ty = self.types.insert(
3280            Type {
3281                name: None,
3282                inner: TypeInner::Vector {
3283                    size: mat_rows,
3284                    scalar,
3285                },
3286            },
3287            span,
3288        );
3289
3290        let flatten = self.flatten_matrix(mat_components);
3291        let nr = mat_rows as usize;
3292        let components = (0..nr)
3293            .map(|r| {
3294                let row = flatten.iter().skip(r).step_by(nr).cloned();
3295                self.dot_exprs(row, vec_components.iter().cloned(), span)
3296            })
3297            .collect::<Result<Vec<_>, _>>()?;
3298        Ok(Expression::Compose { ty, components })
3299    }
3300
3301    fn multiply_matrix_matrix(
3302        &mut self,
3303        left_components: &[Handle<Expression>],
3304        right_components: &[Handle<Expression>],
3305        left_rows: crate::VectorSize,
3306        right_columns: crate::VectorSize,
3307        scalar: crate::Scalar,
3308        span: Span,
3309    ) -> Result<Expression, ConstantEvaluatorError> {
3310        let left_nc = left_components.len();
3311        let left_nr = left_rows as usize;
3312        let right_nc = right_columns as usize;
3313        let right_nr = left_nc;
3314
3315        let mut result = Vec::with_capacity(right_nc);
3316        let result_ty = self.types.insert(
3317            Type {
3318                name: None,
3319                inner: TypeInner::Matrix {
3320                    columns: right_columns,
3321                    rows: left_rows,
3322                    scalar,
3323                },
3324            },
3325            span,
3326        );
3327        let result_column_ty = self.types.insert(
3328            Type {
3329                name: None,
3330                inner: TypeInner::Vector {
3331                    size: left_rows,
3332                    scalar,
3333                },
3334            },
3335            span,
3336        );
3337
3338        let left_flattened = self.flatten_matrix(left_components);
3339        let right_flattened = self.flatten_matrix(right_components);
3340        for c in 0..right_nc {
3341            let result_column = (0..left_nr)
3342                .map(|r| {
3343                    let row = left_flattened.iter().skip(r).step_by(left_nr);
3344                    let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3345                    self.dot_exprs(row.cloned(), column.cloned(), span)
3346                })
3347                .collect::<Result<Vec<_>, _>>()?;
3348            let expr = Expression::Compose {
3349                ty: result_column_ty,
3350                components: result_column,
3351            };
3352            let handle = self.register_evaluated_expr(expr, span)?;
3353            result.push(handle);
3354        }
3355        Ok(Expression::Compose {
3356            ty: result_ty,
3357            components: result,
3358        })
3359    }
3360
3361    fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3362        let mut flattened = ArrayVec::<_, 16>::new();
3363        for &column in columns {
3364            let Expression::Compose { ref components, .. } = self.expressions[column] else {
3365                unreachable!()
3366            };
3367            flattened.extend(components.iter().cloned());
3368        }
3369        flattened
3370    }
3371
3372    fn dot_exprs(
3373        &mut self,
3374        left: impl Iterator<Item = Handle<Expression>>,
3375        right: impl Iterator<Item = Handle<Expression>>,
3376        span: Span,
3377    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3378        let mut acc = None;
3379        for (l, r) in left.zip(right) {
3380            let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3381            match acc.as_mut() {
3382                Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3383                None => acc = Some(result),
3384            }
3385        }
3386        Ok(acc.unwrap())
3387    }
3388
3389    fn relational(
3390        &mut self,
3391        fun: RelationalFunction,
3392        arg: Handle<Expression>,
3393        span: Span,
3394    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3395        let arg = self.eval_zero_value_and_splat(arg, span)?;
3396        match fun {
3397            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3398                Expression::Literal(Literal::Bool(_)) => Ok(arg),
3399                Expression::Compose { ty, ref components }
3400                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3401                {
3402                    let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3403                    for component in
3404                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3405                    {
3406                        match self.expressions[component] {
3407                            Expression::Literal(Literal::Bool(val)) => {
3408                                bool_components.push(val);
3409                            }
3410                            _ => {
3411                                return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3412                            }
3413                        }
3414                    }
3415                    let components = bool_components;
3416                    let result = match fun {
3417                        RelationalFunction::All => components.iter().all(|c| *c),
3418                        RelationalFunction::Any => components.iter().any(|c| *c),
3419                        _ => unreachable!(),
3420                    };
3421                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3422                }
3423                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3424            },
3425            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3426                "{fun:?} built-in function"
3427            ))),
3428        }
3429    }
3430
3431    /// Deep copy `expr` from `expressions` into `self.expressions`.
3432    ///
3433    /// Return the root of the new copy.
3434    ///
3435    /// This is used when we're evaluating expressions in a function's
3436    /// expression arena that refer to a constant: we need to copy the
3437    /// constant's value into the function's arena so we can operate on it.
3438    fn copy_from(
3439        &mut self,
3440        expr: Handle<Expression>,
3441        expressions: &Arena<Expression>,
3442    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3443        let span = expressions.get_span(expr);
3444        match expressions[expr] {
3445            ref expr @ (Expression::Literal(_)
3446            | Expression::Constant(_)
3447            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3448            Expression::Compose { ty, ref components } => {
3449                let mut components = components.clone();
3450                for component in &mut components {
3451                    *component = self.copy_from(*component, expressions)?;
3452                }
3453                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3454            }
3455            Expression::Splat { size, value } => {
3456                let value = self.copy_from(value, expressions)?;
3457                self.register_evaluated_expr(Expression::Splat { size, value }, span)
3458            }
3459            _ => {
3460                log::debug!("copy_from: SubexpressionsAreNotConstant");
3461                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3462            }
3463        }
3464    }
3465
3466    /// Returns the total number of components, after flattening, of a vector compose expression.
3467    fn vector_compose_flattened_size(
3468        &self,
3469        components: &[Handle<Expression>],
3470    ) -> Result<usize, ConstantEvaluatorError> {
3471        components
3472            .iter()
3473            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3474                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3475                    TypeInner::Scalar(_) => 1,
3476                    // We trust that the vector size of `component` is correct,
3477                    // as it will have already been validated when `component`
3478                    // was registered.
3479                    TypeInner::Vector { size, .. } => size as usize,
3480                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3481                };
3482                Ok(acc + size)
3483            })
3484    }
3485
3486    fn register_evaluated_expr(
3487        &mut self,
3488        expr: Expression,
3489        span: Span,
3490    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3491        // It suffices to only check_literal_value() for `Literal` expressions,
3492        // since we only register one expression at a time, `Compose`
3493        // expressions can only refer to other expressions, and `ZeroValue`
3494        // expressions are always okay.
3495        if let Expression::Literal(literal) = expr {
3496            crate::valid::check_literal_value(literal)?;
3497        }
3498
3499        // Ensure vector composes contain the correct number of components. We
3500        // do so here when each compose is registered to avoid having to deal
3501        // with the mess each time the compose is used in another expression.
3502        if let Expression::Compose { ty, ref components } = expr {
3503            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3504                let expected = size as usize;
3505                let actual = self.vector_compose_flattened_size(components)?;
3506                if expected != actual {
3507                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3508                        expected,
3509                        actual,
3510                    });
3511                }
3512            }
3513        }
3514
3515        Ok(self.append_expr(expr, span, ExpressionKind::Const))
3516    }
3517
3518    fn append_expr(
3519        &mut self,
3520        expr: Expression,
3521        span: Span,
3522        expr_type: ExpressionKind,
3523    ) -> Handle<Expression> {
3524        let h = match self.behavior {
3525            Behavior::Wgsl(
3526                WgslRestrictions::Runtime(ref mut function_local_data)
3527                | WgslRestrictions::Const(Some(ref mut function_local_data)),
3528            )
3529            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3530                let is_running = function_local_data.emitter.is_running();
3531                let needs_pre_emit = expr.needs_pre_emit();
3532                if is_running && needs_pre_emit {
3533                    function_local_data
3534                        .block
3535                        .extend(function_local_data.emitter.finish(self.expressions));
3536                    let h = self.expressions.append(expr, span);
3537                    function_local_data.emitter.start(self.expressions);
3538                    h
3539                } else {
3540                    self.expressions.append(expr, span)
3541                }
3542            }
3543            _ => self.expressions.append(expr, span),
3544        };
3545        self.expression_kind_tracker.insert(h, expr_type);
3546        h
3547    }
3548
3549    /// Resolve the type of `expr` if it is a constant expression.
3550    ///
3551    /// If `expr` was evaluated to a constant, returns its type.
3552    /// Otherwise, returns an error.
3553    fn resolve_type(
3554        &self,
3555        expr: Handle<Expression>,
3556    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3557        use crate::proc::TypeResolution as Tr;
3558        use crate::Expression as Ex;
3559        let resolution = match self.expressions[expr] {
3560            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3561            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3562            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3563            Ex::Splat { size, value } => {
3564                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3565                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3566                };
3567                Tr::Value(TypeInner::Vector { scalar, size })
3568            }
3569            _ => {
3570                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3571                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3572            }
3573        };
3574
3575        Ok(resolution)
3576    }
3577
    /// Constant-evaluate `select(reject, accept, condition)`.
    ///
    /// All three operands are first passed through `eval_zero_value_and_splat`.
    /// This presumably expands `ZeroValue`/`Splat` into concrete `Literal`/`Compose`
    /// expressions — confirm. Two operand shapes are then supported:
    ///
    /// - Scalar: `reject` and `accept` are both `Literal`s and `condition` is a
    ///   `Literal::Bool`. The handle of the chosen operand is returned.
    /// - Vector: `reject` and `accept` are both `Compose`s whose types have the
    ///   same vector size. `condition` is either a single `Literal::Bool` (applied
    ///   to every component) or a bool `Compose` of that same size. A new
    ///   `Compose` of `reject`'s type, holding the per-component choices, is
    ///   registered via `register_evaluated_expr`.
    ///
    /// In both shapes, `accept` (or each of its components) is cast to the scalar
    /// type of the corresponding `reject` value before the choice is made. The
    /// cast therefore runs, and can fail, even when `reject` is selected.
    ///
    /// # Errors
    ///
    /// - `SelectScalarConditionNotABool`: scalar operands with a condition that
    ///   is not a `Literal::Bool`.
    /// - `SelectVecRejectAcceptSizeMismatch`: vector operands of different sizes.
    /// - `SelectConditionNotAVecBool`: vector operands with a condition that is
    ///   neither a `Literal::Bool` nor a `Compose` of bool scalar kind.
    /// - `SelectConditionVecSizeMismatch`: the condition vector's size differs
    ///   from the operands' size.
    /// - `SelectAcceptRejectTypeMismatch`: `reject` and `accept` are not both
    ///   `Literal`s or both `Compose`s.
    ///
    /// Errors from `eval_zero_value_and_splat`, `cast` and
    /// `register_evaluated_expr` are propagated as well.
    fn select(
        &mut self,
        reject: Handle<Expression>,
        accept: Handle<Expression>,
        condition: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // Normalize every operand before inspecting its shape below.
        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

        let reject = arg(reject)?;
        let accept = arg(accept)?;
        let condition = arg(condition)?;

        // Cast `accept` to `reject_scalar`, then return whichever handle
        // `condition` picks. The cast happens regardless of `condition`.
        let select_single_component =
            |this: &mut Self, reject_scalar, reject, accept, condition| {
                let accept = this.cast(accept, reject_scalar, span)?;
                if condition {
                    Ok(accept)
                } else {
                    Ok(reject)
                }
            };

        match (&self.expressions[reject], &self.expressions[accept]) {
            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
                let reject_scalar = reject_lit.scalar();
                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
                else {
                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
                };
                select_single_component(self, reject_scalar, reject, accept, condition)
            }
            (
                &Expression::Compose {
                    ty: reject_ty,
                    components: ref reject_components,
                },
                &Expression::Compose {
                    ty: accept_ty,
                    components: ref accept_components,
                },
            ) => {
                // Returns (vector size, scalar) for a `Compose` type.
                // NOTE(review): panics via `unwrap` if the type has no vector size
                // (e.g. a matrix, array or struct compose); presumably earlier
                // validation excludes these here — confirm.
                let ty_deets = |ty| {
                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                    (size.unwrap(), scalar)
                };

                let expected_vec_size = {
                    let [(reject_vec_size, _), (accept_vec_size, _)] =
                        [reject_ty, accept_ty].map(ty_deets);

                    if reject_vec_size != accept_vec_size {
                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                            reject: reject_vec_size,
                            accept: accept_vec_size,
                        });
                    }
                    reject_vec_size
                };

                // One bool per component: a scalar condition is repeated for every
                // component; a vector condition must be a bool vector of equal size.
                let condition_components = match self.expressions[condition] {
                    Expression::Literal(Literal::Bool(condition)) => {
                        vec![condition; (expected_vec_size as u8).into()]
                    }
                    Expression::Compose {
                        ty: condition_ty,
                        components: ref condition_components,
                    } => {
                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                        if condition_scalar.kind != ScalarKind::Bool {
                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                        }
                        if condition_vec_size != expected_vec_size {
                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                        }
                        // NOTE(review): assumes every component is already a bool
                        // `Literal`; any other component expression (e.g. a nested
                        // `Compose`) panics below — confirm upstream guarantees this.
                        condition_components
                            .iter()
                            .copied()
                            .map(|component| match &self.expressions[component] {
                                &Expression::Literal(Literal::Bool(condition)) => condition,
                                _ => unreachable!(),
                            })
                            .collect()
                    }

                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
                };

                // NOTE(review): `zip` pairs components positionally and stops at the
                // shortest list. This assumes both operands hold one `Literal` per
                // component. A nested compose such as `vec4(vec2(..), vec2(..))`
                // would mis-pair, or hit `unreachable!()` below — confirm.
                let evaluated = Expression::Compose {
                    ty: reject_ty,
                    components: reject_components
                        .clone()
                        .into_iter()
                        .zip(accept_components.clone().into_iter())
                        .zip(condition_components.into_iter())
                        .map(|((reject, accept), condition)| {
                            let reject_scalar = match &self.expressions[reject] {
                                &Expression::Literal(lit) => lit.scalar(),
                                _ => unreachable!(),
                            };
                            select_single_component(self, reject_scalar, reject, accept, condition)
                        })
                        .collect::<Result<_, _>>()?,
                };
                self.register_evaluated_expr(evaluated, span)
            }
            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
        }
    }
3687}
3688
3689fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3690    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3691    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3692    // return a right-to-left bit index of 0.
3693    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3694        match e {
3695            idx @ 0..=31 => idx,
3696            32 => u32::MAX,
3697            _ => unreachable!(),
3698        }
3699    };
3700    match concrete_int {
3701        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3702        ConcreteInt::I32([e]) => {
3703            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3704        }
3705    }
3706}
3707
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) pairs for signed inputs.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for &(input, expected) in signed_cases.iter() {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Every single-bit value reports its own bit position.
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }

    // (input, expected) pairs for unsigned inputs.
    let unsigned_cases = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for &(input, expected) in unsigned_cases.iter() {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
3768
3769fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3770    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3771    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3772    // index of 0.
3773    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3774        match e {
3775            idx @ 0..=31 => 31 - idx,
3776            32 => u32::MAX,
3777            _ => unreachable!(),
3778        }
3779    };
3780    match concrete_int {
3781        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3782            let rtl_bit_index = if e.is_negative() {
3783                e.leading_ones()
3784            } else {
3785                e.leading_zeros()
3786            };
3787            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3788        }]),
3789        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3790    }
3791}
3792
#[test]
fn first_leading_bit_smoke() {
    // (input, expected) pairs for signed inputs.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for &(input, expected) in signed_cases.iter() {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; the sign bit is covered by the cases above.
    for bit in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }
    // Negated powers of two: the first bit differing from the sign bit sits one lower.
    for bit in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1])
        );
    }

    // (input, expected) pairs for unsigned inputs.
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for &(input, expected) in unsigned_cases.iter() {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
3856
3857/// Trait for conversions of abstract values to concrete types.
3858trait TryFromAbstract<T>: Sized {
3859    /// Convert an abstract literal `value` to `Self`.
3860    ///
3861    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
3862    /// WGSL, we follow WGSL's conversion rules here:
3863    ///
3864    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
3865    ///   from [`AbstractInt`] to an integer type are either lossless or an
3866    ///   error.
3867    ///
3868    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
3869    ///   to floating point in constant expressions and override
3870    ///   expressions are errors if the value is out of range for the
3871    ///   destination type, but rounding is okay.
3872    ///
3873    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
3874    ///   other floating point type, following the scalar floating point to
3875    ///   integral conversion algorithm (§15.7.6). There is no automatic
3876    ///   conversion from AbstractFloat to integer types.
3877    ///
3878    /// [`AbstractInt`]: crate::Literal::AbstractInt
3879    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
3880    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3881}
3882
3883impl TryFromAbstract<i64> for i32 {
3884    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3885        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3886            value: format!("{value:?}"),
3887            to_type: "i32",
3888        })
3889    }
3890}
3891
3892impl TryFromAbstract<i64> for u32 {
3893    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3894        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3895            value: format!("{value:?}"),
3896            to_type: "u32",
3897        })
3898    }
3899}
3900
3901impl TryFromAbstract<i64> for u64 {
3902    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3903        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3904            value: format!("{value:?}"),
3905            to_type: "u64",
3906        })
3907    }
3908}
3909
3910impl TryFromAbstract<i64> for i64 {
3911    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3912        Ok(value)
3913    }
3914}
3915
3916impl TryFromAbstract<i64> for f32 {
3917    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3918        let f = value as f32;
3919        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3920        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3921        // overflow here.
3922        Ok(f)
3923    }
3924}
3925
3926impl TryFromAbstract<f64> for f32 {
3927    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3928        let f = value as f32;
3929        if f.is_infinite() {
3930            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3931                value: format!("{value:?}"),
3932                to_type: "f32",
3933            });
3934        }
3935        Ok(f)
3936    }
3937}
3938
3939impl TryFromAbstract<i64> for f64 {
3940    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3941        let f = value as f64;
3942        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3943        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3944        // overflow here.
3945        Ok(f)
3946    }
3947}
3948
3949impl TryFromAbstract<f64> for f64 {
3950    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3951        Ok(value)
3952    }
3953}
3954
3955impl TryFromAbstract<f64> for i32 {
3956    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3957        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3958        // To convert a floating point scalar value X to an integer scalar type T:
3959        // * If X is a NaN, the result is an indeterminate value in T.
3960        // * If X is exactly representable in the target type T, then the
3961        //   result is that value.
3962        // * Otherwise, the result is the value in T closest to truncate(X) and
3963        //   also exactly representable in the original floating point type.
3964        //
3965        // A rust cast satisfies these requirements apart from "the result
3966        // is... exactly representable in the original floating point type".
3967        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3968        // we're all good.
3969        Ok(value as i32)
3970    }
3971}
3972
3973impl TryFromAbstract<f64> for u32 {
3974    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3975        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3976        // so a simple rust cast is sufficient.
3977        Ok(value as u32)
3978    }
3979}
3980
3981impl TryFromAbstract<f64> for i64 {
3982    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3983        // As above, except we clamp to the minimum and maximum values
3984        // representable by both f64 and i64.
3985        use crate::proc::type_methods::IntFloatLimits;
3986        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3987    }
3988}
3989
3990impl TryFromAbstract<f64> for u64 {
3991    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3992        // As above, this time clamping to the minimum and maximum values
3993        // representable by both f64 and u64.
3994        use crate::proc::type_methods::IntFloatLimits;
3995        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3996    }
3997}
3998
3999impl TryFromAbstract<f64> for f16 {
4000    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
4001        let f = f16::from_f64(value);
4002        if f.is_infinite() {
4003            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4004                value: format!("{value:?}"),
4005                to_type: "f16",
4006            });
4007        }
4008        Ok(f)
4009    }
4010}
4011
4012impl TryFromAbstract<i64> for f16 {
4013    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
4014        let f = f16::from_i64(value);
4015        if f.is_none() {
4016            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4017                value: format!("{value:?}"),
4018                to_type: "f16",
4019            });
4020        }
4021        Ok(f.unwrap())
4022    }
4023}
4024
/// Compute the 3-component cross product `a × b`.
///
/// Works for any copyable element type supporting `*` and `-`. Each component is
/// evaluated as `p * q - r * s`, in that operand order.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
4037
4038#[cfg(test)]
4039mod tests {
4040    use alloc::{vec, vec::Vec};
4041
4042    use crate::{
4043        Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
4044        Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
4045    };
4046
4047    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
4048
4049    #[test]
4050    fn unary_op() {
4051        let mut types = UniqueArena::new();
4052        let mut constants = Arena::new();
4053        let overrides = Arena::new();
4054        let mut global_expressions = Arena::new();
4055
4056        let scalar_ty = types.insert(
4057            Type {
4058                name: None,
4059                inner: TypeInner::Scalar(crate::Scalar::I32),
4060            },
4061            Default::default(),
4062        );
4063
4064        let vec_ty = types.insert(
4065            Type {
4066                name: None,
4067                inner: TypeInner::Vector {
4068                    size: VectorSize::Bi,
4069                    scalar: crate::Scalar::I32,
4070                },
4071            },
4072            Default::default(),
4073        );
4074
4075        let h = constants.append(
4076            Constant {
4077                name: None,
4078                ty: scalar_ty,
4079                init: global_expressions
4080                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4081            },
4082            Default::default(),
4083        );
4084
4085        let h1 = constants.append(
4086            Constant {
4087                name: None,
4088                ty: scalar_ty,
4089                init: global_expressions
4090                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
4091            },
4092            Default::default(),
4093        );
4094
4095        let vec_h = constants.append(
4096            Constant {
4097                name: None,
4098                ty: vec_ty,
4099                init: global_expressions.append(
4100                    Expression::Compose {
4101                        ty: vec_ty,
4102                        components: vec![constants[h].init, constants[h1].init],
4103                    },
4104                    Default::default(),
4105                ),
4106            },
4107            Default::default(),
4108        );
4109
4110        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4111        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
4112
4113        let expr2 = Expression::Unary {
4114            op: UnaryOperator::Negate,
4115            expr,
4116        };
4117
4118        let expr3 = Expression::Unary {
4119            op: UnaryOperator::BitwiseNot,
4120            expr,
4121        };
4122
4123        let expr4 = Expression::Unary {
4124            op: UnaryOperator::BitwiseNot,
4125            expr: expr1,
4126        };
4127
4128        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4129        let mut solver = ConstantEvaluator {
4130            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4131            types: &mut types,
4132            constants: &constants,
4133            overrides: &overrides,
4134            expressions: &mut global_expressions,
4135            expression_kind_tracker,
4136            layouter: &mut crate::proc::Layouter::default(),
4137        };
4138
4139        let res1 = solver
4140            .try_eval_and_append(expr2, Default::default())
4141            .unwrap();
4142        let res2 = solver
4143            .try_eval_and_append(expr3, Default::default())
4144            .unwrap();
4145        let res3 = solver
4146            .try_eval_and_append(expr4, Default::default())
4147            .unwrap();
4148
4149        assert_eq!(
4150            global_expressions[res1],
4151            Expression::Literal(Literal::I32(-4))
4152        );
4153
4154        assert_eq!(
4155            global_expressions[res2],
4156            Expression::Literal(Literal::I32(!4))
4157        );
4158
4159        let res3_inner = &global_expressions[res3];
4160
4161        match *res3_inner {
4162            Expression::Compose {
4163                ref ty,
4164                ref components,
4165            } => {
4166                assert_eq!(*ty, vec_ty);
4167                let mut components_iter = components.iter().copied();
4168                assert_eq!(
4169                    global_expressions[components_iter.next().unwrap()],
4170                    Expression::Literal(Literal::I32(!4))
4171                );
4172                assert_eq!(
4173                    global_expressions[components_iter.next().unwrap()],
4174                    Expression::Literal(Literal::I32(!8))
4175                );
4176                assert!(components_iter.next().is_none());
4177            }
4178            _ => panic!("Expected vector"),
4179        }
4180    }
4181
4182    #[test]
4183    fn matrix_op() {
4184        let mut helper = MatrixTestHelper::new();
4185
4186        for nc in 2..=4 {
4187            for nr in 2..=4 {
4188                // Validates multiplication on vector-matrix.
4189                // vecR(0, 1, .., r) * matCxR(0, 1, .., nc * nr)
4190                let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4191                let expected = (0..nc)
4192                    .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4193                    .collect::<Vec<f32>>();
4194                assert_eq!(evaluated, expected);
4195
4196                // Validates multiplication on matrix-vector.
4197                // matCxR(0, 1, .., nc * nr) * vecC(0, 1, .., nc)
4198                let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4199                let expected = (0..nr)
4200                    .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4201                    .collect::<Vec<f32>>();
4202                assert_eq!(evaluated, expected);
4203
4204                for k in 2..=4 {
4205                    // Validates multiplication on matrix-matrix.
4206                    // matKxR(0, 1, .., k * nr) * matCxK(0, 1, .., nc * k)
4207                    let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4208                    let expected = (0..nc)
4209                        .flat_map(|c| {
4210                            (0..nr).map(move |r| {
4211                                (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4212                            })
4213                        })
4214                        .collect::<Vec<f32>>();
4215                    assert_eq!(evaluated, expected);
4216                }
4217            }
4218        }
4219    }
4220
4221    /// Test fixture providing pre-built f32 vector and matrix constant
4222    /// expressions with sequential element values, used to evaluate and verify
4223    /// matrix operations.
4224    struct MatrixTestHelper {
4225        types: UniqueArena<Type>,
4226        expressions: Arena<Expression>,
4227        /// Vector expressions from [0, 1] to [0, 1, 2, 3].
4228        vec_exprs: FastHashMap<usize, Handle<Expression>>,
4229        /// Matrix expressions from [0, .., 3] to [0, .., 15].
4230        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
4231    }
4232
4233    impl MatrixTestHelper {
4234        fn new() -> Self {
4235            let mut types = UniqueArena::new();
4236            let mut expressions = Arena::new();
4237            let span = crate::Span::default();
4238
4239            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4240            for c in 2..=4 {
4241                let vec_ty = types.insert(
4242                    Type {
4243                        name: None,
4244                        inner: TypeInner::Vector {
4245                            size: Self::int_to_vector_size(c),
4246                            scalar: crate::Scalar::F32,
4247                        },
4248                    },
4249                    span,
4250                );
4251                vec_tys.insert(c, vec_ty);
4252                for r in 2..=4 {
4253                    let mat_ty = types.insert(
4254                        Type {
4255                            name: None,
4256                            inner: TypeInner::Matrix {
4257                                columns: Self::int_to_vector_size(c),
4258                                rows: Self::int_to_vector_size(r),
4259                                scalar: crate::Scalar::F32,
4260                            },
4261                        },
4262                        span,
4263                    );
4264                    mat_tys.insert((c, r), mat_ty);
4265                }
4266            }
4267
4268            let mut lit_exprs = FastHashMap::default();
4269            for i in 0..16 {
4270                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4271                lit_exprs.insert(i, expr);
4272            }
4273
4274            let mut vec_exprs = FastHashMap::default();
4275            for c in 2..=4 {
4276                let expr = expressions.append(
4277                    Expression::Compose {
4278                        ty: *vec_tys.get(&c).unwrap(),
4279                        components: (0..c)
4280                            .map(|i| *lit_exprs.get(&i).unwrap())
4281                            .collect::<Vec<_>>(),
4282                    },
4283                    span,
4284                );
4285                vec_exprs.insert(c, expr);
4286            }
4287
4288            let mut mat_exprs = FastHashMap::default();
4289            for c in 2..=4 {
4290                for r in 2..=4 {
4291                    let mut columns = Vec::with_capacity(c);
4292                    for cc in 0..c {
4293                        let start = cc * r;
4294                        let expr = expressions.append(
4295                            Expression::Compose {
4296                                ty: *vec_tys.get(&r).unwrap(),
4297                                components: (start..start + r)
4298                                    .map(|i| *lit_exprs.get(&i).unwrap())
4299                                    .collect::<Vec<_>>(),
4300                            },
4301                            span,
4302                        );
4303                        columns.push(expr);
4304                    }
4305
4306                    let expr = expressions.append(
4307                        Expression::Compose {
4308                            ty: *mat_tys.get(&(c, r)).unwrap(),
4309                            components: columns,
4310                        },
4311                        span,
4312                    );
4313                    mat_exprs.insert((c, r), expr);
4314                }
4315            }
4316
4317            Self {
4318                types,
4319                expressions,
4320                vec_exprs,
4321                mat_exprs,
4322            }
4323        }
4324
4325        /// Evaluates vec[0..nr] * mat[0..nc*nr] and returns the result as f32s.
4326        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4327            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4328            let mut solver = ConstantEvaluator {
4329                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4330                types: &mut self.types,
4331                constants: &Arena::new(),
4332                overrides: &Arena::new(),
4333                expressions: &mut self.expressions,
4334                expression_kind_tracker,
4335                layouter: &mut crate::proc::Layouter::default(),
4336            };
4337
4338            let result = solver
4339                .try_eval_and_append(
4340                    Expression::Binary {
4341                        op: BinaryOperator::Multiply,
4342                        left: *self.vec_exprs.get(&nr).unwrap(),
4343                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4344                    },
4345                    Default::default(),
4346                )
4347                .unwrap();
4348            self.flatten(result)
4349        }
4350
4351        /// Evaluates mat[0..nc*nr] * vec[0..nc] and returns the result as f32s.
4352        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4353            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4354            let mut solver = ConstantEvaluator {
4355                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4356                types: &mut self.types,
4357                constants: &Arena::new(),
4358                overrides: &Arena::new(),
4359                expressions: &mut self.expressions,
4360                expression_kind_tracker,
4361                layouter: &mut crate::proc::Layouter::default(),
4362            };
4363
4364            let result = solver
4365                .try_eval_and_append(
4366                    Expression::Binary {
4367                        op: BinaryOperator::Multiply,
4368                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4369                        right: *self.vec_exprs.get(&nc).unwrap(),
4370                    },
4371                    Default::default(),
4372                )
4373                .unwrap();
4374            self.flatten(result)
4375        }
4376
4377        /// Evaluates mat[0..k*l_nr] * mat[0..r_nc*k] and returns the result as
4378        /// f32s.
4379        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4380            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4381            let mut solver = ConstantEvaluator {
4382                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4383                types: &mut self.types,
4384                constants: &Arena::new(),
4385                overrides: &Arena::new(),
4386                expressions: &mut self.expressions,
4387                expression_kind_tracker,
4388                layouter: &mut crate::proc::Layouter::default(),
4389            };
4390
4391            let result = solver
4392                .try_eval_and_append(
4393                    Expression::Binary {
4394                        op: BinaryOperator::Multiply,
4395                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4396                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4397                    },
4398                    Default::default(),
4399                )
4400                .unwrap();
4401            self.flatten(result)
4402        }
4403
4404        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4405            let Expression::Compose {
4406                ref components,
4407                ref ty,
4408            } = self.expressions[expr]
4409            else {
4410                unreachable!()
4411            };
4412
4413            match self.types[*ty].inner {
4414                TypeInner::Vector { .. } => components
4415                    .iter()
4416                    .map(|&comp| {
4417                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4418                            unreachable!()
4419                        };
4420                        v
4421                    })
4422                    .collect(),
4423                TypeInner::Matrix { .. } => components
4424                    .iter()
4425                    .flat_map(|&comp| self.flatten(comp))
4426                    .collect(),
4427                _ => unreachable!(),
4428            }
4429        }
4430
4431        fn int_to_vector_size(int: usize) -> VectorSize {
4432            match int {
4433                2 => VectorSize::Bi,
4434                3 => VectorSize::Tri,
4435                4 => VectorSize::Quad,
4436                _ => unreachable!(),
4437            }
4438        }
4439    }
4440
4441    #[test]
4442    fn cast() {
4443        let mut types = UniqueArena::new();
4444        let mut constants = Arena::new();
4445        let overrides = Arena::new();
4446        let mut global_expressions = Arena::new();
4447
4448        let scalar_ty = types.insert(
4449            Type {
4450                name: None,
4451                inner: TypeInner::Scalar(crate::Scalar::I32),
4452            },
4453            Default::default(),
4454        );
4455
4456        let h = constants.append(
4457            Constant {
4458                name: None,
4459                ty: scalar_ty,
4460                init: global_expressions
4461                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4462            },
4463            Default::default(),
4464        );
4465
4466        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4467
4468        let root = Expression::As {
4469            expr,
4470            kind: ScalarKind::Bool,
4471            convert: Some(crate::BOOL_WIDTH),
4472        };
4473
4474        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4475        let mut solver = ConstantEvaluator {
4476            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4477            types: &mut types,
4478            constants: &constants,
4479            overrides: &overrides,
4480            expressions: &mut global_expressions,
4481            expression_kind_tracker,
4482            layouter: &mut crate::proc::Layouter::default(),
4483        };
4484
4485        let res = solver
4486            .try_eval_and_append(root, Default::default())
4487            .unwrap();
4488
4489        assert_eq!(
4490            global_expressions[res],
4491            Expression::Literal(Literal::Bool(true))
4492        );
4493    }
4494
4495    #[test]
4496    fn access() {
4497        let mut types = UniqueArena::new();
4498        let mut constants = Arena::new();
4499        let overrides = Arena::new();
4500        let mut global_expressions = Arena::new();
4501
4502        let matrix_ty = types.insert(
4503            Type {
4504                name: None,
4505                inner: TypeInner::Matrix {
4506                    columns: VectorSize::Bi,
4507                    rows: VectorSize::Tri,
4508                    scalar: crate::Scalar::F32,
4509                },
4510            },
4511            Default::default(),
4512        );
4513
4514        let vec_ty = types.insert(
4515            Type {
4516                name: None,
4517                inner: TypeInner::Vector {
4518                    size: VectorSize::Tri,
4519                    scalar: crate::Scalar::F32,
4520                },
4521            },
4522            Default::default(),
4523        );
4524
4525        let mut vec1_components = Vec::with_capacity(3);
4526        let mut vec2_components = Vec::with_capacity(3);
4527
4528        for i in 0..3 {
4529            let h = global_expressions.append(
4530                Expression::Literal(Literal::F32(i as f32)),
4531                Default::default(),
4532            );
4533
4534            vec1_components.push(h)
4535        }
4536
4537        for i in 3..6 {
4538            let h = global_expressions.append(
4539                Expression::Literal(Literal::F32(i as f32)),
4540                Default::default(),
4541            );
4542
4543            vec2_components.push(h)
4544        }
4545
4546        let vec1 = constants.append(
4547            Constant {
4548                name: None,
4549                ty: vec_ty,
4550                init: global_expressions.append(
4551                    Expression::Compose {
4552                        ty: vec_ty,
4553                        components: vec1_components,
4554                    },
4555                    Default::default(),
4556                ),
4557            },
4558            Default::default(),
4559        );
4560
4561        let vec2 = constants.append(
4562            Constant {
4563                name: None,
4564                ty: vec_ty,
4565                init: global_expressions.append(
4566                    Expression::Compose {
4567                        ty: vec_ty,
4568                        components: vec2_components,
4569                    },
4570                    Default::default(),
4571                ),
4572            },
4573            Default::default(),
4574        );
4575
4576        let h = constants.append(
4577            Constant {
4578                name: None,
4579                ty: matrix_ty,
4580                init: global_expressions.append(
4581                    Expression::Compose {
4582                        ty: matrix_ty,
4583                        components: vec![constants[vec1].init, constants[vec2].init],
4584                    },
4585                    Default::default(),
4586                ),
4587            },
4588            Default::default(),
4589        );
4590
4591        let base = global_expressions.append(Expression::Constant(h), Default::default());
4592
4593        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4594        let mut solver = ConstantEvaluator {
4595            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4596            types: &mut types,
4597            constants: &constants,
4598            overrides: &overrides,
4599            expressions: &mut global_expressions,
4600            expression_kind_tracker,
4601            layouter: &mut crate::proc::Layouter::default(),
4602        };
4603
4604        let root1 = Expression::AccessIndex { base, index: 1 };
4605
4606        let res1 = solver
4607            .try_eval_and_append(root1, Default::default())
4608            .unwrap();
4609
4610        let root2 = Expression::AccessIndex {
4611            base: res1,
4612            index: 2,
4613        };
4614
4615        let res2 = solver
4616            .try_eval_and_append(root2, Default::default())
4617            .unwrap();
4618
4619        match global_expressions[res1] {
4620            Expression::Compose {
4621                ref ty,
4622                ref components,
4623            } => {
4624                assert_eq!(*ty, vec_ty);
4625                let mut components_iter = components.iter().copied();
4626                assert_eq!(
4627                    global_expressions[components_iter.next().unwrap()],
4628                    Expression::Literal(Literal::F32(3.))
4629                );
4630                assert_eq!(
4631                    global_expressions[components_iter.next().unwrap()],
4632                    Expression::Literal(Literal::F32(4.))
4633                );
4634                assert_eq!(
4635                    global_expressions[components_iter.next().unwrap()],
4636                    Expression::Literal(Literal::F32(5.))
4637                );
4638                assert!(components_iter.next().is_none());
4639            }
4640            _ => panic!("Expected vector"),
4641        }
4642
4643        assert_eq!(
4644            global_expressions[res2],
4645            Expression::Literal(Literal::F32(5.))
4646        );
4647    }
4648
4649    #[test]
4650    fn compose_of_constants() {
4651        let mut types = UniqueArena::new();
4652        let mut constants = Arena::new();
4653        let overrides = Arena::new();
4654        let mut global_expressions = Arena::new();
4655
4656        let i32_ty = types.insert(
4657            Type {
4658                name: None,
4659                inner: TypeInner::Scalar(crate::Scalar::I32),
4660            },
4661            Default::default(),
4662        );
4663
4664        let vec2_i32_ty = types.insert(
4665            Type {
4666                name: None,
4667                inner: TypeInner::Vector {
4668                    size: VectorSize::Bi,
4669                    scalar: crate::Scalar::I32,
4670                },
4671            },
4672            Default::default(),
4673        );
4674
4675        let h = constants.append(
4676            Constant {
4677                name: None,
4678                ty: i32_ty,
4679                init: global_expressions
4680                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4681            },
4682            Default::default(),
4683        );
4684
4685        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4686
4687        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4688        let mut solver = ConstantEvaluator {
4689            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4690            types: &mut types,
4691            constants: &constants,
4692            overrides: &overrides,
4693            expressions: &mut global_expressions,
4694            expression_kind_tracker,
4695            layouter: &mut crate::proc::Layouter::default(),
4696        };
4697
4698        let solved_compose = solver
4699            .try_eval_and_append(
4700                Expression::Compose {
4701                    ty: vec2_i32_ty,
4702                    components: vec![h_expr, h_expr],
4703                },
4704                Default::default(),
4705            )
4706            .unwrap();
4707        let solved_negate = solver
4708            .try_eval_and_append(
4709                Expression::Unary {
4710                    op: UnaryOperator::Negate,
4711                    expr: solved_compose,
4712                },
4713                Default::default(),
4714            )
4715            .unwrap();
4716
4717        let pass = match global_expressions[solved_negate] {
4718            Expression::Compose { ty, ref components } => {
4719                ty == vec2_i32_ty
4720                    && components.iter().all(|&component| {
4721                        let component = &global_expressions[component];
4722                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4723                    })
4724            }
4725            _ => false,
4726        };
4727        if !pass {
4728            panic!("unexpected evaluation result")
4729        }
4730    }
4731
4732    #[test]
4733    fn splat_of_constant() {
4734        let mut types = UniqueArena::new();
4735        let mut constants = Arena::new();
4736        let overrides = Arena::new();
4737        let mut global_expressions = Arena::new();
4738
4739        let i32_ty = types.insert(
4740            Type {
4741                name: None,
4742                inner: TypeInner::Scalar(crate::Scalar::I32),
4743            },
4744            Default::default(),
4745        );
4746
4747        let vec2_i32_ty = types.insert(
4748            Type {
4749                name: None,
4750                inner: TypeInner::Vector {
4751                    size: VectorSize::Bi,
4752                    scalar: crate::Scalar::I32,
4753                },
4754            },
4755            Default::default(),
4756        );
4757
4758        let h = constants.append(
4759            Constant {
4760                name: None,
4761                ty: i32_ty,
4762                init: global_expressions
4763                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4764            },
4765            Default::default(),
4766        );
4767
4768        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4769
4770        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4771        let mut solver = ConstantEvaluator {
4772            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4773            types: &mut types,
4774            constants: &constants,
4775            overrides: &overrides,
4776            expressions: &mut global_expressions,
4777            expression_kind_tracker,
4778            layouter: &mut crate::proc::Layouter::default(),
4779        };
4780
4781        let solved_compose = solver
4782            .try_eval_and_append(
4783                Expression::Splat {
4784                    size: VectorSize::Bi,
4785                    value: h_expr,
4786                },
4787                Default::default(),
4788            )
4789            .unwrap();
4790        let solved_negate = solver
4791            .try_eval_and_append(
4792                Expression::Unary {
4793                    op: UnaryOperator::Negate,
4794                    expr: solved_compose,
4795                },
4796                Default::default(),
4797            )
4798            .unwrap();
4799
4800        let pass = match global_expressions[solved_negate] {
4801            Expression::Compose { ty, ref components } => {
4802                ty == vec2_i32_ty
4803                    && components.iter().all(|&component| {
4804                        let component = &global_expressions[component];
4805                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4806                    })
4807            }
4808            _ => false,
4809        };
4810        if !pass {
4811            panic!("unexpected evaluation result")
4812        }
4813    }
4814
4815    #[test]
4816    fn splat_of_zero_value() {
4817        let mut types = UniqueArena::new();
4818        let constants = Arena::new();
4819        let overrides = Arena::new();
4820        let mut global_expressions = Arena::new();
4821
4822        let f32_ty = types.insert(
4823            Type {
4824                name: None,
4825                inner: TypeInner::Scalar(crate::Scalar::F32),
4826            },
4827            Default::default(),
4828        );
4829
4830        let vec2_f32_ty = types.insert(
4831            Type {
4832                name: None,
4833                inner: TypeInner::Vector {
4834                    size: VectorSize::Bi,
4835                    scalar: crate::Scalar::F32,
4836                },
4837            },
4838            Default::default(),
4839        );
4840
4841        let five =
4842            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4843        let five_splat = global_expressions.append(
4844            Expression::Splat {
4845                size: VectorSize::Bi,
4846                value: five,
4847            },
4848            Default::default(),
4849        );
4850        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4851        let zero_splat = global_expressions.append(
4852            Expression::Splat {
4853                size: VectorSize::Bi,
4854                value: zero,
4855            },
4856            Default::default(),
4857        );
4858
4859        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4860        let mut solver = ConstantEvaluator {
4861            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4862            types: &mut types,
4863            constants: &constants,
4864            overrides: &overrides,
4865            expressions: &mut global_expressions,
4866            expression_kind_tracker,
4867            layouter: &mut crate::proc::Layouter::default(),
4868        };
4869
4870        let solved_add = solver
4871            .try_eval_and_append(
4872                Expression::Binary {
4873                    op: BinaryOperator::Add,
4874                    left: zero_splat,
4875                    right: five_splat,
4876                },
4877                Default::default(),
4878            )
4879            .unwrap();
4880
4881        let pass = match global_expressions[solved_add] {
4882            Expression::Compose { ty, ref components } => {
4883                ty == vec2_f32_ty
4884                    && components.iter().all(|&component| {
4885                        let component = &global_expressions[component];
4886                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4887                    })
4888            }
4889            _ => false,
4890        };
4891        if !pass {
4892            panic!("unexpected evaluation result")
4893        }
4894    }
4895}