naga/proc/
constant_evaluator.rs

1// Code in this file intentionally uses `for` loops and `.push()` rather than
2// `ArrayVec::from_iter`, because the latter is monomorphized by all three of
3// the item type, the capacity, and the iterator type, which can easily bloat
4// the compiled executable (by ~260 KiB, when it was removed).
5
6use alloc::{
7    format,
8    string::{String, ToString},
9    vec,
10    vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19    arena::{Arena, Handle, HandleVec, UniqueArena},
20    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21    ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Lets a macro expansion contain a literal dollar sign (`$`).
///
/// The input is a single macro rule of the form `($d:tt) => { ... }`. It is
/// installed as a throwaway macro which is then immediately invoked with `$` as
/// its argument, so inside the rule `$d` stands for `$`. This is what allows a
/// `macro_rules!` item to emit further `macro_rules!` items that declare their
/// own metavariables and repetitions.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($rule:tt)*) => {
        macro_rules! __with_dollar_sign {
            $($rule)*
        }
        __with_dollar_sign!($);
    };
}
38
/// Generates a component-wise extractor over a subset of [`Literal`] variants.
///
/// Invoked as
/// `gen_component_wise_extractor! { ident -> Target, literals: [...], scalar_kinds: [...], }`,
/// where each `literals` entry reads `LiteralVariant => TargetVariant: element_type`.
/// It emits:
///
/// - `enum Target<const N: usize>`, one variant per `literals` entry, each holding
///   `[element_type; N]`;
/// - `impl From<Target<1>> for Expression`, producing an [`Expression::Literal`];
/// - `fn ident(eval, span, exprs, handler)`, which gathers the arguments' literal
///   components and passes them to `handler` (once per lane for vector arguments);
/// - a convenience `macro_rules! ident` that applies one closure body to every
///   `Target` variant.
///
/// `scalar_kinds` lists the vector element [`ScalarKind`]s accepted for `Compose`
/// arguments. A vector that passes that check but whose components are not one of
/// the `literals` variants is still rejected with
/// [`ConstantEvaluatorError::InvalidMathArg`] during the per-lane recursion.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same type (same size and scalar), `handler` is
        /// called once per lane with the corresponding component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. For vector arguments, a new
        /// vector expression of the first argument's type is registered, composed of each
        /// component emitted by `handler`.
        ///
        /// # Errors
        ///
        /// Returns [`ConstantEvaluatorError::InvalidMathArg`] if the arguments are not all
        /// literals of the same supported variant, or not all vectors of the same supported
        /// type. Errors from `handler` are propagated.
        ///
        /// # Panics
        ///
        /// Panics if `N` is zero.
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an argument through `eval_zero_value_and_splat` and borrows the
            // resulting expression from the arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides whether all arguments are treated as scalar
            // literals or as vectors.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    &Expression::Literal(Literal::$literal(x)) => {
                        // Every remaining argument must be a literal of this same variant.
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    // Vector arguments: only element kinds listed in `scalar_kinds` are accepted.
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            // Later arguments must be `Compose`s of exactly the same type.
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each lane, recurse with that lane taken from every argument.
                            // A group with fewer than `size` components yields `InvalidMathArg`.
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            // The result reuses the first argument's vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                // Any other expression (non-vector `Compose`, unsupported literal, ...) is rejected.
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // `$d` is a literal `$`, letting the generated macro declare its own
        // metavariables and repetitions.
        with_dollar_sign! {
            ($d:tt) => {
                // NOTE(review): presumably allowed because not every generated
                // convenience macro is invoked.
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        // The body must evaluate to a `Result` of an array; it is mapped back
                        // into the same variant it was matched from.
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
219gen_component_wise_extractor! {
220    component_wise_scalar -> Scalar,
221    literals: [
222        AbstractFloat => AbstractFloat: f64,
223        F32 => F32: f32,
224        F16 => F16: f16,
225        AbstractInt => AbstractInt: i64,
226        U32 => U32: u32,
227        I32 => I32: i32,
228        U64 => U64: u64,
229        I64 => I64: i64,
230    ],
231    scalar_kinds: [
232        Float,
233        AbstractFloat,
234        Sint,
235        Uint,
236        AbstractInt,
237    ],
238}
239
240gen_component_wise_extractor! {
241    component_wise_float -> Float,
242    literals: [
243        AbstractFloat => Abstract: f64,
244        F32 => F32: f32,
245        F16 => F16: f16,
246    ],
247    scalar_kinds: [
248        Float,
249        AbstractFloat,
250    ],
251}
252
253gen_component_wise_extractor! {
254    component_wise_concrete_int -> ConcreteInt,
255    literals: [
256        U32 => U32: u32,
257        I32 => I32: i32,
258    ],
259    scalar_kinds: [
260        Sint,
261        Uint,
262    ],
263}
264
265gen_component_wise_extractor! {
266    component_wise_signed -> Signed,
267    literals: [
268        AbstractFloat => AbstractFloat: f64,
269        AbstractInt => AbstractInt: i64,
270        F32 => F32: f32,
271        F16 => F16: f16,
272        I32 => I32: i32,
273    ],
274    scalar_kinds: [
275        Sint,
276        AbstractInt,
277        Float,
278        AbstractFloat,
279    ],
280}
281
282/// Vectors with a concrete element type.
283#[derive(Debug)]
284enum LiteralVector {
285    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
286    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
287    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
288    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
289    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
290    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
291    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
292    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
293    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
294    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
295}
296
297impl LiteralVector {
298    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
299    fn len(&self) -> usize {
300        match *self {
301            LiteralVector::F64(ref v) => v.len(),
302            LiteralVector::F32(ref v) => v.len(),
303            LiteralVector::F16(ref v) => v.len(),
304            LiteralVector::U32(ref v) => v.len(),
305            LiteralVector::I32(ref v) => v.len(),
306            LiteralVector::U64(ref v) => v.len(),
307            LiteralVector::I64(ref v) => v.len(),
308            LiteralVector::Bool(ref v) => v.len(),
309            LiteralVector::AbstractInt(ref v) => v.len(),
310            LiteralVector::AbstractFloat(ref v) => v.len(),
311        }
312    }
313
314    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
315    fn from_literal(literal: Literal) -> Self {
316        fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
317            let mut v = ArrayVec::new();
318            v.push(val);
319            v
320        }
321        match literal {
322            Literal::F64(e) => Self::F64(arrayvec_of(e)),
323            Literal::F32(e) => Self::F32(arrayvec_of(e)),
324            Literal::U32(e) => Self::U32(arrayvec_of(e)),
325            Literal::I32(e) => Self::I32(arrayvec_of(e)),
326            Literal::U64(e) => Self::U64(arrayvec_of(e)),
327            Literal::I64(e) => Self::I64(arrayvec_of(e)),
328            Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
329            Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
330            Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
331            Literal::F16(e) => Self::F16(arrayvec_of(e)),
332        }
333    }
334
335    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
336    /// Returns error if components types do not match.
337    /// # Panics
338    /// Panics if vector is empty
339    fn from_literal_vec(
340        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
341    ) -> Result<Self, ConstantEvaluatorError> {
342        assert!(!components.is_empty());
343        // TODO: should a vector of i32 be constructible from abstract int?
344        macro_rules! compose_literals {
345            ($components:expr, $variant:ident, $self_variant:ident) => {{
346                let mut out = ArrayVec::new();
347                for l in &$components {
348                    match l {
349                        &Literal::$variant(v) => out.push(v),
350                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
351                    }
352                }
353                Self::$self_variant(out)
354            }};
355        }
356        Ok(match components[0] {
357            Literal::I32(_) => compose_literals!(components, I32, I32),
358            Literal::U32(_) => compose_literals!(components, U32, U32),
359            Literal::I64(_) => compose_literals!(components, I64, I64),
360            Literal::U64(_) => compose_literals!(components, U64, U64),
361            Literal::F32(_) => compose_literals!(components, F32, F32),
362            Literal::F64(_) => compose_literals!(components, F64, F64),
363            Literal::Bool(_) => compose_literals!(components, Bool, Bool),
364            Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
365            Literal::AbstractFloat(_) => {
366                compose_literals!(components, AbstractFloat, AbstractFloat)
367            }
368            Literal::F16(_) => compose_literals!(components, F16, F16),
369        })
370    }
371
372    #[allow(dead_code)]
373    /// Returns [`ArrayVec`] of [`Literal`]s
374    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
375        macro_rules! decompose_literals {
376            ($v:expr, $variant:ident) => {{
377                let mut out = ArrayVec::new();
378                for e in $v {
379                    out.push(Literal::$variant(*e));
380                }
381                out
382            }};
383        }
384        match *self {
385            LiteralVector::F64(ref v) => decompose_literals!(v, F64),
386            LiteralVector::F32(ref v) => decompose_literals!(v, F32),
387            LiteralVector::F16(ref v) => decompose_literals!(v, F16),
388            LiteralVector::U32(ref v) => decompose_literals!(v, U32),
389            LiteralVector::I32(ref v) => decompose_literals!(v, I32),
390            LiteralVector::U64(ref v) => decompose_literals!(v, U64),
391            LiteralVector::I64(ref v) => decompose_literals!(v, I64),
392            LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
393            LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
394            LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
395        }
396    }
397
398    #[allow(dead_code)]
399    /// Puts self into eval's expressions arena and returns handle to it
400    fn register_as_evaluated_expr(
401        &self,
402        eval: &mut ConstantEvaluator<'_>,
403        span: Span,
404    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
405        let lit_vec = self.to_literal_vec();
406        assert!(!lit_vec.is_empty());
407        let expr = if lit_vec.len() == 1 {
408            Expression::Literal(lit_vec[0])
409        } else {
410            Expression::Compose {
411                ty: eval.types.insert(
412                    Type {
413                        name: None,
414                        inner: TypeInner::Vector {
415                            size: match lit_vec.len() {
416                                2 => crate::VectorSize::Bi,
417                                3 => crate::VectorSize::Tri,
418                                4 => crate::VectorSize::Quad,
419                                _ => unreachable!(),
420                            },
421                            scalar: lit_vec[0].scalar(),
422                        },
423                    },
424                    Span::UNDEFINED,
425                ),
426                components: lit_vec
427                    .iter()
428                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
429                    .collect::<Result<_, _>>()?,
430            }
431        };
432        eval.register_evaluated_expr(expr, span)
433    }
434}
435
/// A macro for matching on [`LiteralVector`] variants.
///
/// It expands to a `match` on the given expression whose value is a
/// `Result<$out, ConstantEvaluatorError>`. Each listed variant produces
/// `Ok($out::Variant(body))`. Any unlisted variant, or a tuple whose elements are
/// not all the same variant, produces `Err(ConstantEvaluatorError::InvalidMathArg)`.
/// The closure-like parameters are bound by reference (`ref`) to the variant's
/// `ArrayVec`.
///
/// The group names expand as follows:
/// - `Float` expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// - `Integer` expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// `$out` names the output enum. Both [`Literal`] (reductions) and
/// [`LiteralVector`] (element-wise maps) are supported. `-> Variant` overrides
/// which `$out` variant wraps the body.
///
/// Arms are emitted in the order given. A concrete variant listed after a group
/// that already covers it (e.g. `U32` after `Integer`) is therefore unreachable.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     Float => |v| {v.iter().sum()},
///     I32 => |v| {v.iter().sum()},
///     U32 => |v| -> I32 {v.iter().sum::<u32>() as i32}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     Float => |e1, e2| {e1+e2},
///     I32 => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Entry point: split the user's arms into a list of type names and a parallel
    // list of `{ vars ; ret? ; body }` groups.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start with an empty accumulator (left of `<>`). The right of `<>` holds the
    // bodies not yet paired with a concrete variant.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next type name and pass it as a leading token so the arms below can
    // dispatch on it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer`: prepend its five concrete variants to the remaining types and
    // duplicate the head body once for each. This arm (and `Float`) must precede the
    // concrete-variant arms, whose `$ty:ident` would otherwise accept `Integer`.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float`: same as `Integer`, for its four concrete variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // A concrete variant with more types remaining. Move it and its body into the
    // accumulator, then continue with the next type as the leading token.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // The last concrete variant: add it to the accumulator and emit the `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the final `match`. With one variable the pattern is a parenthesized
    // single pattern, hence `unused_parens`. With several variables it is a tuple
    // pattern requiring every element to be the same variant.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the body in the `$out` variant with the arm's own name...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...or in the `-> $ret` override when one was given.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
611
612fn float_length<F>(e: &[F]) -> Option<F>
613where
614    F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
615{
616    if e.len() == 1 {
617        // Avoids possible overflow in squaring
618        Some(e[0].abs())
619    } else {
620        let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
621        result.is_finite().then_some(result)
622    }
623}
624
625#[derive(Debug)]
626enum Behavior<'a> {
627    Wgsl(WgslRestrictions<'a>),
628    Glsl(GlslRestrictions<'a>),
629}
630
631impl Behavior<'_> {
632    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
633    const fn has_runtime_restrictions(&self) -> bool {
634        matches!(
635            self,
636            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
637                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
638        )
639    }
640}
641
642/// A context for evaluating constant expressions.
643///
644/// A `ConstantEvaluator` points at an expression arena to which it can append
645/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
646/// of Naga [`Expression`] you like, and if its value can be computed at compile
647/// time, `try_eval_and_append` appends an expression representing the computed
648/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
649/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
650///
651/// A `ConstantEvaluator` also holds whatever information we need to carry out
652/// that evaluation: types, other constants, and so on.
653///
654/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
655/// [`Compose`]: Expression::Compose
656/// [`ZeroValue`]: Expression::ZeroValue
657/// [`Literal`]: Expression::Literal
658/// [`Swizzle`]: Expression::Swizzle
659#[derive(Debug)]
660pub struct ConstantEvaluator<'a> {
661    /// Which language's evaluation rules we should follow.
662    behavior: Behavior<'a>,
663
664    /// The module's type arena.
665    ///
666    /// Because expressions like [`Splat`] contain type handles, we need to be
667    /// able to add new types to produce those expressions.
668    ///
669    /// [`Splat`]: Expression::Splat
670    types: &'a mut UniqueArena<Type>,
671
672    /// The module's constant arena.
673    constants: &'a Arena<Constant>,
674
675    /// The module's override arena.
676    overrides: &'a Arena<Override>,
677
678    /// The arena to which we are contributing expressions.
679    expressions: &'a mut Arena<Expression>,
680
681    /// Tracks the constness of expressions residing in [`Self::expressions`]
682    expression_kind_tracker: &'a mut ExpressionKindTracker,
683
684    layouter: &'a mut crate::proc::Layouter,
685}
686
687#[derive(Debug)]
688enum WgslRestrictions<'a> {
689    /// - const-expressions will be evaluated and inserted in the arena
690    Const(Option<FunctionLocalData<'a>>),
691    /// - const-expressions will be evaluated and inserted in the arena
692    /// - override-expressions will be inserted in the arena
693    Override,
694    /// - const-expressions will be evaluated and inserted in the arena
695    /// - override-expressions will be inserted in the arena
696    /// - runtime-expressions will be inserted in the arena
697    Runtime(FunctionLocalData<'a>),
698}
699
700#[derive(Debug)]
701enum GlslRestrictions<'a> {
702    /// - const-expressions will be evaluated and inserted in the arena
703    Const,
704    /// - const-expressions will be evaluated and inserted in the arena
705    /// - override-expressions will be inserted in the arena
706    /// - runtime-expressions will be inserted in the arena
707    Runtime(FunctionLocalData<'a>),
708}
709
710#[derive(Debug)]
711struct FunctionLocalData<'a> {
712    /// Global constant expressions
713    global_expressions: &'a Arena<Expression>,
714    emitter: &'a mut super::Emitter,
715    block: &'a mut crate::Block,
716}
717
/// How constant an expression is.
///
/// Variants are declared from most to least constant. The derived [`Ord`]
/// therefore lets operand kinds be combined with `max`: an operator-like
/// expression is only as constant as its least-constant operand. Keep this
/// order when adding variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression.
    Const,
    /// An override-expression.
    Override,
    /// A runtime expression.
    Runtime,
}
724
725#[derive(Debug)]
726pub struct ExpressionKindTracker {
727    inner: HandleVec<Expression, ExpressionKind>,
728}
729
730impl ExpressionKindTracker {
731    pub const fn new() -> Self {
732        Self {
733            inner: HandleVec::new(),
734        }
735    }
736
737    /// Forces the the expression to not be const
738    pub fn force_non_const(&mut self, value: Handle<Expression>) {
739        self.inner[value] = ExpressionKind::Runtime;
740    }
741
742    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
743        self.inner.insert(value, expr_type);
744    }
745
746    pub fn is_const(&self, h: Handle<Expression>) -> bool {
747        matches!(self.type_of(h), ExpressionKind::Const)
748    }
749
750    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
751        matches!(
752            self.type_of(h),
753            ExpressionKind::Const | ExpressionKind::Override
754        )
755    }
756
757    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
758        self.inner[value]
759    }
760
761    pub fn from_arena(arena: &Arena<Expression>) -> Self {
762        let mut tracker = Self {
763            inner: HandleVec::with_capacity(arena.len()),
764        };
765        for (handle, expr) in arena.iter() {
766            tracker
767                .inner
768                .insert(handle, tracker.type_of_with_expr(expr));
769        }
770        tracker
771    }
772
773    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
774        match *expr {
775            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
776                ExpressionKind::Const
777            }
778            Expression::Override(_) => ExpressionKind::Override,
779            Expression::Compose { ref components, .. } => {
780                let mut expr_type = ExpressionKind::Const;
781                for component in components {
782                    expr_type = expr_type.max(self.type_of(*component))
783                }
784                expr_type
785            }
786            Expression::Splat { value, .. } => self.type_of(value),
787            Expression::AccessIndex { base, .. } => self.type_of(base),
788            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
789            Expression::Swizzle { vector, .. } => self.type_of(vector),
790            Expression::Unary { expr, .. } => self.type_of(expr),
791            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
792            Expression::Math {
793                arg,
794                arg1,
795                arg2,
796                arg3,
797                ..
798            } => self
799                .type_of(arg)
800                .max(
801                    arg1.map(|arg| self.type_of(arg))
802                        .unwrap_or(ExpressionKind::Const),
803                )
804                .max(
805                    arg2.map(|arg| self.type_of(arg))
806                        .unwrap_or(ExpressionKind::Const),
807                )
808                .max(
809                    arg3.map(|arg| self.type_of(arg))
810                        .unwrap_or(ExpressionKind::Const),
811                ),
812            Expression::As { expr, .. } => self.type_of(expr),
813            Expression::Select {
814                condition,
815                accept,
816                reject,
817            } => self
818                .type_of(condition)
819                .max(self.type_of(accept))
820                .max(self.type_of(reject)),
821            Expression::Relational { argument, .. } => self.type_of(argument),
822            Expression::ArrayLength(expr) => self.type_of(expr),
823            _ => ExpressionKind::Runtime,
824        }
825    }
826}
827
828#[derive(Clone, Debug, thiserror::Error)]
829#[cfg_attr(test, derive(PartialEq))]
830pub enum ConstantEvaluatorError {
831    #[error("Constants cannot access function arguments")]
832    FunctionArg,
833    #[error("Constants cannot access global variables")]
834    GlobalVariable,
835    #[error("Constants cannot access local variables")]
836    LocalVariable,
837    #[error("Cannot get the array length of a non array type")]
838    InvalidArrayLengthArg,
839    #[error("Constants cannot get the array length of a dynamically sized array")]
840    ArrayLengthDynamic,
841    #[error("Cannot call arrayLength on array sized by override-expression")]
842    ArrayLengthOverridden,
843    #[error("Constants cannot call functions")]
844    Call,
845    #[error("Constants don't support workGroupUniformLoad")]
846    WorkGroupUniformLoadResult,
847    #[error("Constants don't support atomic functions")]
848    Atomic,
849    #[error("Constants don't support derivative functions")]
850    Derivative,
851    #[error("Constants don't support load expressions")]
852    Load,
853    #[error("Constants don't support image expressions")]
854    ImageExpression,
855    #[error("Constants don't support ray query expressions")]
856    RayQueryExpression,
857    #[error("Constants don't support subgroup expressions")]
858    SubgroupExpression,
859    #[error("Cannot access the type")]
860    InvalidAccessBase,
861    #[error("Cannot access at the index")]
862    InvalidAccessIndex,
863    #[error("Cannot access with index of type")]
864    InvalidAccessIndexTy,
865    #[error("Constants don't support array length expressions")]
866    ArrayLength,
867    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
868    InvalidCastArg { from: String, to: String },
869    #[error("Cannot apply the unary op to the argument")]
870    InvalidUnaryOpArg,
871    #[error("Cannot apply the binary op to the arguments")]
872    InvalidBinaryOpArgs,
873    #[error("Cannot apply math function to type")]
874    InvalidMathArg,
875    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
876    InvalidMathArgCount(crate::MathFunction, usize, usize),
877    #[error("{0} built-in function argument is out of valid range")]
878    InvalidMathArgValue(String),
879    #[error("Cannot apply relational function to type")]
880    InvalidRelationalArg(RelationalFunction),
881    #[error("value of `low` is greater than `high` for clamp built-in function")]
882    InvalidClamp,
883    #[error("Constructor expects {expected} components, found {actual}")]
884    InvalidVectorComposeLength { expected: usize, actual: usize },
885    #[error("Constructor must only contain vector or scalar arguments")]
886    InvalidVectorComposeComponent,
887    #[error("Splat is defined only on scalar values")]
888    SplatScalarOnly,
889    #[error("Can only swizzle vector constants")]
890    SwizzleVectorOnly,
891    #[error("swizzle component not present in source expression")]
892    SwizzleOutOfBounds,
893    #[error("Type is not constructible")]
894    TypeNotConstructible,
895    #[error("Subexpression(s) are not constant")]
896    SubexpressionsAreNotConstant,
897    #[error("Not implemented as constant expression: {0}")]
898    NotImplemented(String),
899    #[error("{0} operation overflowed")]
900    Overflow(String),
901    #[error(
902        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
903    )]
904    AutomaticConversionLossy {
905        value: String,
906        to_type: &'static str,
907    },
908    #[error("Division by zero")]
909    DivisionByZero,
910    #[error("Remainder by zero")]
911    RemainderByZero,
912    #[error("RHS of shift operation is greater than or equal to 32")]
913    ShiftedMoreThan32Bits,
914    #[error(transparent)]
915    Literal(#[from] crate::valid::LiteralError),
916    #[error("Can't use pipeline-overridable constants in const-expressions")]
917    Override,
918    #[error("Unexpected runtime-expression")]
919    RuntimeExpr,
920    #[error("Unexpected override-expression")]
921    OverrideExpr,
922    #[error("Expected boolean expression for condition argument of `select`, got something else")]
923    SelectScalarConditionNotABool,
924    #[error(
925        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
926        reject,
927        accept
928    )]
929    SelectVecRejectAcceptSizeMismatch {
930        reject: crate::VectorSize,
931        accept: crate::VectorSize,
932    },
933    #[error("Expected boolean vector for condition arg., got something else")]
934    SelectConditionNotAVecBool,
935    #[error(
936        "Expected same number of vector components between condition, accept, and reject args., got something else",
937    )]
938    SelectConditionVecSizeMismatch,
939    #[error(
940        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
941    )]
942    SelectAcceptRejectTypeMismatch,
943    #[error("Cooperative operations can't be constant")]
944    CooperativeOperation,
945}
946
947impl<'a> ConstantEvaluator<'a> {
948    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
949    /// constant expression arena.
950    ///
951    /// Report errors according to WGSL's rules for constant evaluation.
952    pub const fn for_wgsl_module(
953        module: &'a mut crate::Module,
954        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
955        layouter: &'a mut crate::proc::Layouter,
956        in_override_ctx: bool,
957    ) -> Self {
958        Self::for_module(
959            Behavior::Wgsl(if in_override_ctx {
960                WgslRestrictions::Override
961            } else {
962                WgslRestrictions::Const(None)
963            }),
964            module,
965            global_expression_kind_tracker,
966            layouter,
967        )
968    }
969
970    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
971    /// constant expression arena.
972    ///
973    /// Report errors according to GLSL's rules for constant evaluation.
974    pub const fn for_glsl_module(
975        module: &'a mut crate::Module,
976        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
977        layouter: &'a mut crate::proc::Layouter,
978    ) -> Self {
979        Self::for_module(
980            Behavior::Glsl(GlslRestrictions::Const),
981            module,
982            global_expression_kind_tracker,
983            layouter,
984        )
985    }
986
987    const fn for_module(
988        behavior: Behavior<'a>,
989        module: &'a mut crate::Module,
990        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
991        layouter: &'a mut crate::proc::Layouter,
992    ) -> Self {
993        Self {
994            behavior,
995            types: &mut module.types,
996            constants: &module.constants,
997            overrides: &module.overrides,
998            expressions: &mut module.global_expressions,
999            expression_kind_tracker: global_expression_kind_tracker,
1000            layouter,
1001        }
1002    }
1003
1004    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1005    /// expression arena.
1006    ///
1007    /// Report errors according to WGSL's rules for constant evaluation.
1008    pub const fn for_wgsl_function(
1009        module: &'a mut crate::Module,
1010        expressions: &'a mut Arena<Expression>,
1011        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1012        layouter: &'a mut crate::proc::Layouter,
1013        emitter: &'a mut super::Emitter,
1014        block: &'a mut crate::Block,
1015        is_const: bool,
1016    ) -> Self {
1017        let local_data = FunctionLocalData {
1018            global_expressions: &module.global_expressions,
1019            emitter,
1020            block,
1021        };
1022        Self {
1023            behavior: Behavior::Wgsl(if is_const {
1024                WgslRestrictions::Const(Some(local_data))
1025            } else {
1026                WgslRestrictions::Runtime(local_data)
1027            }),
1028            types: &mut module.types,
1029            constants: &module.constants,
1030            overrides: &module.overrides,
1031            expressions,
1032            expression_kind_tracker: local_expression_kind_tracker,
1033            layouter,
1034        }
1035    }
1036
1037    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1038    /// expression arena.
1039    ///
1040    /// Report errors according to GLSL's rules for constant evaluation.
1041    pub const fn for_glsl_function(
1042        module: &'a mut crate::Module,
1043        expressions: &'a mut Arena<Expression>,
1044        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1045        layouter: &'a mut crate::proc::Layouter,
1046        emitter: &'a mut super::Emitter,
1047        block: &'a mut crate::Block,
1048    ) -> Self {
1049        Self {
1050            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1051                global_expressions: &module.global_expressions,
1052                emitter,
1053                block,
1054            })),
1055            types: &mut module.types,
1056            constants: &module.constants,
1057            overrides: &module.overrides,
1058            expressions,
1059            expression_kind_tracker: local_expression_kind_tracker,
1060            layouter,
1061        }
1062    }
1063
1064    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1065        crate::proc::GlobalCtx {
1066            types: self.types,
1067            constants: self.constants,
1068            overrides: self.overrides,
1069            global_expressions: match self.function_local_data() {
1070                Some(data) => data.global_expressions,
1071                None => self.expressions,
1072            },
1073        }
1074    }
1075
1076    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1077        if !self.expression_kind_tracker.is_const(expr) {
1078            log::debug!("check: SubexpressionsAreNotConstant");
1079            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1080        }
1081        Ok(())
1082    }
1083
1084    fn check_and_get(
1085        &mut self,
1086        expr: Handle<Expression>,
1087    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1088        match self.expressions[expr] {
1089            Expression::Constant(c) => {
1090                // Are we working in a function's expression arena, or the
1091                // module's constant expression arena?
1092                if let Some(function_local_data) = self.function_local_data() {
1093                    // Deep-copy the constant's value into our arena.
1094                    self.copy_from(
1095                        self.constants[c].init,
1096                        function_local_data.global_expressions,
1097                    )
1098                } else {
1099                    // "See through" the constant and use its initializer.
1100                    Ok(self.constants[c].init)
1101                }
1102            }
1103            _ => {
1104                self.check(expr)?;
1105                Ok(expr)
1106            }
1107        }
1108    }
1109
1110    /// Try to evaluate `expr` at compile time.
1111    ///
1112    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1113    /// we can determine its value at compile time, we append an expression
1114    /// representing its value - a tree of [`Literal`], [`Compose`],
1115    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1116    /// `self` contributes to.
1117    ///
1118    /// If `expr`'s value cannot be determined at compile time, and `self` is
1119    /// contributing to some function's expression arena, then append `expr` to
1120    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1121    /// contributing to the module's constant expression arena; since `expr`'s
1122    /// value is not a constant, return an error.
1123    ///
1124    /// We only consider `expr` itself, without recursing into its operands. Its
1125    /// operands must all have been produced by prior calls to
1126    /// `try_eval_and_append`, to ensure that they have already been reduced to
1127    /// an evaluated form if possible.
1128    ///
1129    /// [`Literal`]: Expression::Literal
1130    /// [`Compose`]: Expression::Compose
1131    /// [`ZeroValue`]: Expression::ZeroValue
1132    /// [`Swizzle`]: Expression::Swizzle
1133    pub fn try_eval_and_append(
1134        &mut self,
1135        expr: Expression,
1136        span: Span,
1137    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1138        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1139            ExpressionKind::Const => {
1140                let eval_result = self.try_eval_and_append_impl(&expr, span);
1141                // We should be able to evaluate `Const` expressions at this
1142                // point. If we failed to, then that probably means we just
1143                // haven't implemented that part of constant evaluation. Work
1144                // around this by simply emitting it as a run-time expression.
1145                if self.behavior.has_runtime_restrictions()
1146                    && matches!(
1147                        eval_result,
1148                        Err(ConstantEvaluatorError::NotImplemented(_)
1149                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1150                    )
1151                {
1152                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1153                } else {
1154                    eval_result
1155                }
1156            }
1157            ExpressionKind::Override => match self.behavior {
1158                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1159                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1160                }
1161                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1162                    Err(ConstantEvaluatorError::OverrideExpr)
1163                }
1164
1165                // GLSL specialization constants (constant_id) become Override expressions
1166                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1167                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1168                }
1169                Behavior::Glsl(GlslRestrictions::Const) => {
1170                    Err(ConstantEvaluatorError::OverrideExpr)
1171                }
1172            },
1173            ExpressionKind::Runtime => {
1174                if self.behavior.has_runtime_restrictions() {
1175                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1176                } else {
1177                    Err(ConstantEvaluatorError::RuntimeExpr)
1178                }
1179            }
1180        }
1181    }
1182
1183    /// Is the [`Self::expressions`] arena the global module expression arena?
1184    const fn is_global_arena(&self) -> bool {
1185        matches!(
1186            self.behavior,
1187            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1188                | Behavior::Glsl(GlslRestrictions::Const)
1189        )
1190    }
1191
1192    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1193        match self.behavior {
1194            Behavior::Wgsl(
1195                WgslRestrictions::Runtime(ref function_local_data)
1196                | WgslRestrictions::Const(Some(ref function_local_data)),
1197            )
1198            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1199                Some(function_local_data)
1200            }
1201            _ => None,
1202        }
1203    }
1204
1205    fn try_eval_and_append_impl(
1206        &mut self,
1207        expr: &Expression,
1208        span: Span,
1209    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1210        log::trace!("try_eval_and_append: {expr:?}");
1211        match *expr {
1212            Expression::Constant(c) if self.is_global_arena() => {
1213                // "See through" the constant and use its initializer.
1214                // This is mainly done to avoid having constants pointing to other constants.
1215                Ok(self.constants[c].init)
1216            }
1217            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1218            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1219                self.register_evaluated_expr(expr.clone(), span)
1220            }
1221            Expression::Compose { ty, ref components } => {
1222                let components = components
1223                    .iter()
1224                    .map(|component| self.check_and_get(*component))
1225                    .collect::<Result<Vec<_>, _>>()?;
1226                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1227            }
1228            Expression::Splat { size, value } => {
1229                let value = self.check_and_get(value)?;
1230                self.register_evaluated_expr(Expression::Splat { size, value }, span)
1231            }
1232            Expression::AccessIndex { base, index } => {
1233                let base = self.check_and_get(base)?;
1234
1235                self.access(base, index as usize, span)
1236            }
1237            Expression::Access { base, index } => {
1238                let base = self.check_and_get(base)?;
1239                let index = self.check_and_get(index)?;
1240
1241                let index_val: u32 = self
1242                    .to_ctx()
1243                    .get_const_val_from(index, self.expressions)
1244                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1245                self.access(base, index_val as usize, span)
1246            }
1247            Expression::Swizzle {
1248                size,
1249                vector,
1250                pattern,
1251            } => {
1252                let vector = self.check_and_get(vector)?;
1253
1254                self.swizzle(size, span, vector, pattern)
1255            }
1256            Expression::Unary { expr, op } => {
1257                let expr = self.check_and_get(expr)?;
1258
1259                self.unary_op(op, expr, span)
1260            }
1261            Expression::Binary { left, right, op } => {
1262                let left = self.check_and_get(left)?;
1263                let right = self.check_and_get(right)?;
1264
1265                self.binary_op(op, left, right, span)
1266            }
1267            Expression::Math {
1268                fun,
1269                arg,
1270                arg1,
1271                arg2,
1272                arg3,
1273            } => {
1274                let arg = self.check_and_get(arg)?;
1275                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1276                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1277                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1278
1279                self.math(arg, arg1, arg2, arg3, fun, span)
1280            }
1281            Expression::As {
1282                convert,
1283                expr,
1284                kind,
1285            } => {
1286                let expr = self.check_and_get(expr)?;
1287
1288                match convert {
1289                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1290                    None => Err(ConstantEvaluatorError::NotImplemented(
1291                        "bitcast built-in function".into(),
1292                    )),
1293                }
1294            }
1295            Expression::Select {
1296                reject,
1297                accept,
1298                condition,
1299            } => {
1300                let mut arg = |expr| self.check_and_get(expr);
1301
1302                let reject = arg(reject)?;
1303                let accept = arg(accept)?;
1304                let condition = arg(condition)?;
1305
1306                self.select(reject, accept, condition, span)
1307            }
1308            Expression::Relational { fun, argument } => {
1309                let argument = self.check_and_get(argument)?;
1310                self.relational(fun, argument, span)
1311            }
1312            Expression::ArrayLength(expr) => match self.behavior {
1313                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1314                Behavior::Glsl(_) => {
1315                    let expr = self.check_and_get(expr)?;
1316                    self.array_length(expr, span)
1317                }
1318            },
1319            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1320            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1321            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1322            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1323            Expression::WorkGroupUniformLoadResult { .. } => {
1324                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1325            }
1326            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1327            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1328            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1329            Expression::ImageSample { .. }
1330            | Expression::ImageLoad { .. }
1331            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1332            Expression::RayQueryProceedResult
1333            | Expression::RayQueryGetIntersection { .. }
1334            | Expression::RayQueryVertexPositions { .. } => {
1335                Err(ConstantEvaluatorError::RayQueryExpression)
1336            }
1337            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1338            Expression::SubgroupOperationResult { .. } => {
1339                Err(ConstantEvaluatorError::SubgroupExpression)
1340            }
1341            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1342                Err(ConstantEvaluatorError::CooperativeOperation)
1343            }
1344        }
1345    }
1346
1347    /// Splat `value` to `size`, without using [`Splat`] expressions.
1348    ///
1349    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1350    /// build a vector with the given `size` whose components are all
1351    /// `value`.
1352    ///
1353    /// Use `span` as the span of the inserted expressions and
1354    /// resulting types.
1355    ///
1356    /// [`Splat`]: Expression::Splat
1357    /// [`Compose`]: Expression::Compose
1358    /// [`ZeroValue`]: Expression::ZeroValue
1359    fn splat(
1360        &mut self,
1361        value: Handle<Expression>,
1362        size: crate::VectorSize,
1363        span: Span,
1364    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1365        match self.expressions[value] {
1366            Expression::Literal(literal) => {
1367                let scalar = literal.scalar();
1368                let ty = self.types.insert(
1369                    Type {
1370                        name: None,
1371                        inner: TypeInner::Vector { size, scalar },
1372                    },
1373                    span,
1374                );
1375                let expr = Expression::Compose {
1376                    ty,
1377                    components: vec![value; size as usize],
1378                };
1379                self.register_evaluated_expr(expr, span)
1380            }
1381            Expression::ZeroValue(ty) => {
1382                let inner = match self.types[ty].inner {
1383                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1384                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1385                };
1386                let res_ty = self.types.insert(Type { name: None, inner }, span);
1387                let expr = Expression::ZeroValue(res_ty);
1388                self.register_evaluated_expr(expr, span)
1389            }
1390            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1391        }
1392    }
1393
1394    fn swizzle(
1395        &mut self,
1396        size: crate::VectorSize,
1397        span: Span,
1398        src_constant: Handle<Expression>,
1399        pattern: [crate::SwizzleComponent; 4],
1400    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1401        let mut get_dst_ty = |ty| match self.types[ty].inner {
1402            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1403                Type {
1404                    name: None,
1405                    inner: TypeInner::Vector { size, scalar },
1406                },
1407                span,
1408            )),
1409            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1410        };
1411
1412        match self.expressions[src_constant] {
1413            Expression::ZeroValue(ty) => {
1414                let dst_ty = get_dst_ty(ty)?;
1415                let expr = Expression::ZeroValue(dst_ty);
1416                self.register_evaluated_expr(expr, span)
1417            }
1418            Expression::Splat { value, .. } => {
1419                let expr = Expression::Splat { size, value };
1420                self.register_evaluated_expr(expr, span)
1421            }
1422            Expression::Compose { ty, ref components } => {
1423                let dst_ty = get_dst_ty(ty)?;
1424
1425                let mut flattened = [src_constant; 4]; // dummy value
1426                let len =
1427                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1428                        .zip(flattened.iter_mut())
1429                        .map(|(component, elt)| *elt = component)
1430                        .count();
1431                let flattened = &flattened[..len];
1432
1433                let swizzled_components = pattern[..size as usize]
1434                    .iter()
1435                    .map(|&sc| {
1436                        let sc = sc as usize;
1437                        if let Some(elt) = flattened.get(sc) {
1438                            Ok(*elt)
1439                        } else {
1440                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1441                        }
1442                    })
1443                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1444                let expr = Expression::Compose {
1445                    ty: dst_ty,
1446                    components: swizzled_components,
1447                };
1448                self.register_evaluated_expr(expr, span)
1449            }
1450            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1451        }
1452    }
1453
1454    fn math(
1455        &mut self,
1456        arg: Handle<Expression>,
1457        arg1: Option<Handle<Expression>>,
1458        arg2: Option<Handle<Expression>>,
1459        arg3: Option<Handle<Expression>>,
1460        fun: crate::MathFunction,
1461        span: Span,
1462    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1463        let expected = fun.argument_count();
1464        let given = Some(arg)
1465            .into_iter()
1466            .chain(arg1)
1467            .chain(arg2)
1468            .chain(arg3)
1469            .count();
1470        if expected != given {
1471            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1472                fun, expected, given,
1473            ));
1474        }
1475
1476        // NOTE: We try to match the declaration order of `MathFunction` here.
1477        match fun {
1478            // comparison
1479            crate::MathFunction::Abs => {
1480                component_wise_scalar(self, span, [arg], |args| match args {
1481                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1482                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1483                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1484                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1485                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1486                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1487                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1488                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1489                })
1490            }
1491            crate::MathFunction::Min => {
1492                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1493                    Ok([e1.min(e2)])
1494                })
1495            }
1496            crate::MathFunction::Max => {
1497                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1498                    Ok([e1.max(e2)])
1499                })
1500            }
1501            crate::MathFunction::Clamp => {
1502                component_wise_scalar!(
1503                    self,
1504                    span,
1505                    [arg, arg1.unwrap(), arg2.unwrap()],
1506                    |e, low, high| {
1507                        if low > high {
1508                            Err(ConstantEvaluatorError::InvalidClamp)
1509                        } else {
1510                            Ok([e.clamp(low, high)])
1511                        }
1512                    }
1513                )
1514            }
1515            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1516                Float::F16([e]) => Ok(Float::F16(
1517                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1518                )),
1519                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1520                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1521            }),
1522
1523            // trigonometry
1524            crate::MathFunction::Cos => {
1525                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1526            }
1527            crate::MathFunction::Cosh => {
1528                component_wise_float!(self, span, [arg], |e| {
1529                    let result = e.cosh();
1530                    if result.is_finite() {
1531                        Ok([result])
1532                    } else {
1533                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
1534                    }
1535                })
1536            }
1537            crate::MathFunction::Sin => {
1538                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1539            }
1540            crate::MathFunction::Sinh => {
1541                component_wise_float!(self, span, [arg], |e| {
1542                    let result = e.sinh();
1543                    if result.is_finite() {
1544                        Ok([result])
1545                    } else {
1546                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
1547                    }
1548                })
1549            }
1550            crate::MathFunction::Tan => {
1551                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1552            }
1553            crate::MathFunction::Tanh => {
1554                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1555            }
1556            crate::MathFunction::Acos => {
1557                component_wise_float!(self, span, [arg], |e| {
1558                    if e.abs() <= One::one() {
1559                        Ok([e.acos()])
1560                    } else {
1561                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1562                    }
1563                })
1564            }
1565            crate::MathFunction::Asin => {
1566                component_wise_float!(self, span, [arg], |e| {
1567                    if e.abs() <= One::one() {
1568                        Ok([e.asin()])
1569                    } else {
1570                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1571                    }
1572                })
1573            }
1574            crate::MathFunction::Atan => {
1575                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1576            }
1577            crate::MathFunction::Atan2 => {
1578                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1579                    Ok([y.atan2(x)])
1580                })
1581            }
1582            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1583                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1584                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1585                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1586            }),
1587            crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
1588                Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
1589                Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
1590                Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
1591                _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
1592            }),
1593            crate::MathFunction::Atanh => {
1594                component_wise_float!(self, span, [arg], |e| {
1595                    if e.abs() < One::one() {
1596                        Ok([e.atanh()])
1597                    } else {
1598                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1599                    }
1600                })
1601            }
1602            crate::MathFunction::Radians => {
1603                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1604            }
1605            crate::MathFunction::Degrees => {
1606                component_wise_float!(self, span, [arg], |e| {
1607                    let result = e.to_degrees();
1608                    if result.is_finite() {
1609                        Ok([result])
1610                    } else {
1611                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
1612                    }
1613                })
1614            }
1615
1616            // decomposition
1617            crate::MathFunction::Ceil => {
1618                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1619            }
1620            crate::MathFunction::Floor => {
1621                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1622            }
1623            crate::MathFunction::Round => {
1624                component_wise_float(self, span, [arg], |e| match e {
1625                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1626                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1627                    Float::F16([e]) => {
1628                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1629                        //
1630                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1631                        // which has licensing compatible with ours. See also
1632                        // <https://github.com/rust-lang/rust/issues/96710>.
1633                        //
1634                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1635                        fn round_ties_even(x: f64) -> f64 {
1636                            let i = x as i64;
1637                            let f = (x - i as f64).abs();
1638                            if f == 0.5 {
1639                                if i & 1 == 1 {
1640                                    // -1.5, 1.5, 3.5, ...
1641                                    (x.abs() + 0.5).copysign(x)
1642                                } else {
1643                                    (x.abs() - 0.5).copysign(x)
1644                                }
1645                            } else {
1646                                x.round()
1647                            }
1648                        }
1649
1650                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1651                    }
1652                })
1653            }
1654            crate::MathFunction::Fract => {
1655                component_wise_float!(self, span, [arg], |e| {
1656                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1657                    // here.
1658                    Ok([e - e.floor()])
1659                })
1660            }
1661            crate::MathFunction::Trunc => {
1662                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1663            }
1664
1665            // exponent
1666            crate::MathFunction::Exp => {
1667                component_wise_float!(self, span, [arg], |e| {
1668                    let result = e.exp();
1669                    if result.is_finite() {
1670                        Ok([result])
1671                    } else {
1672                        Err(ConstantEvaluatorError::Overflow("exp".into()))
1673                    }
1674                })
1675            }
1676            crate::MathFunction::Exp2 => {
1677                component_wise_float!(self, span, [arg], |e| {
1678                    let result = e.exp2();
1679                    if result.is_finite() {
1680                        Ok([result])
1681                    } else {
1682                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
1683                    }
1684                })
1685            }
1686            crate::MathFunction::Log => {
1687                component_wise_float!(self, span, [arg], |e| {
1688                    if e > Zero::zero() {
1689                        Ok([e.ln()])
1690                    } else {
1691                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1692                    }
1693                })
1694            }
1695            crate::MathFunction::Log2 => {
1696                component_wise_float!(self, span, [arg], |e| {
1697                    if e > Zero::zero() {
1698                        Ok([e.log2()])
1699                    } else {
1700                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1701                    }
1702                })
1703            }
1704            crate::MathFunction::Pow => {
1705                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1706                    // 0.pow(0) is an error since exp2(0 * log2(0)) is NaN.
1707                    // https://www.w3.org/TR/WGSL/#pow-builtin
1708                    if e1 < Zero::zero()
1709                        || e1.is_one() && e2.is_infinite()
1710                        || e1.is_infinite() && e2.is_zero()
1711                        || e1.is_zero() && e2.is_zero()
1712                    {
1713                        Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
1714                    } else {
1715                        let result = e1.powf(e2);
1716                        if result.is_finite() {
1717                            Ok([result])
1718                        } else {
1719                            Err(ConstantEvaluatorError::Overflow("pow".into()))
1720                        }
1721                    }
1722                })
1723            }
1724
1725            // computational
1726            crate::MathFunction::Sign => {
1727                component_wise_signed!(self, span, [arg], |e| {
1728                    Ok([if e.is_zero() {
1729                        Zero::zero()
1730                    } else {
1731                        e.signum()
1732                    }])
1733                })
1734            }
1735            crate::MathFunction::Fma => {
1736                component_wise_float!(
1737                    self,
1738                    span,
1739                    [arg, arg1.unwrap(), arg2.unwrap()],
1740                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1741                )
1742            }
1743            crate::MathFunction::Step => {
1744                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1745                    Float::Abstract([edge, x]) => {
1746                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1747                    }
1748                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1749                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1750                        f16::one()
1751                    } else {
1752                        f16::zero()
1753                    }])),
1754                })
1755            }
1756            crate::MathFunction::Sqrt => {
1757                component_wise_float!(self, span, [arg], |e| {
1758                    if e >= Zero::zero() {
1759                        Ok([e.sqrt()])
1760                    } else {
1761                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1762                    }
1763                })
1764            }
1765            crate::MathFunction::InverseSqrt => {
1766                component_wise_float(self, span, [arg], |e| match e {
1767                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1768                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1769                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1770                })
1771            }
1772
1773            // bits
1774            crate::MathFunction::CountTrailingZeros => {
1775                component_wise_concrete_int!(self, span, [arg], |e| {
1776                    #[allow(clippy::useless_conversion)]
1777                    Ok([e
1778                        .trailing_zeros()
1779                        .try_into()
1780                        .expect("bit count overflowed 32 bits, somehow!?")])
1781                })
1782            }
1783            crate::MathFunction::CountLeadingZeros => {
1784                component_wise_concrete_int!(self, span, [arg], |e| {
1785                    #[allow(clippy::useless_conversion)]
1786                    Ok([e
1787                        .leading_zeros()
1788                        .try_into()
1789                        .expect("bit count overflowed 32 bits, somehow!?")])
1790                })
1791            }
1792            crate::MathFunction::CountOneBits => {
1793                component_wise_concrete_int!(self, span, [arg], |e| {
1794                    #[allow(clippy::useless_conversion)]
1795                    Ok([e
1796                        .count_ones()
1797                        .try_into()
1798                        .expect("bit count overflowed 32 bits, somehow!?")])
1799                })
1800            }
1801            crate::MathFunction::ReverseBits => {
1802                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1803            }
1804            crate::MathFunction::FirstTrailingBit => {
1805                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1806            }
1807            crate::MathFunction::FirstLeadingBit => {
1808                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1809            }
1810
1811            // vector
1812            crate::MathFunction::Dot4I8Packed => {
1813                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1814            }
1815            crate::MathFunction::Dot4U8Packed => {
1816                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1817            }
1818            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1819            crate::MathFunction::Dot => {
1820                // https://www.w3.org/TR/WGSL/#dot-builtin
1821                let e1 = self.extract_vec(arg, false)?;
1822                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1823                if e1.len() != e2.len() {
1824                    return Err(ConstantEvaluatorError::InvalidMathArg);
1825                }
1826
1827                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1828                where
1829                    P: num_traits::Float,
1830                {
1831                    let result = a
1832                        .iter()
1833                        .zip(b.iter())
1834                        .map(|(&aa, &bb)| aa * bb)
1835                        .fold(P::zero(), |acc, x| acc + x);
1836                    if result.is_finite() {
1837                        Ok(result)
1838                    } else {
1839                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1840                    }
1841                }
1842
1843                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1844                where
1845                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1846                {
1847                    a.iter()
1848                        .zip(b.iter())
1849                        .map(|(&aa, bb)| aa.checked_mul(bb))
1850                        .try_fold(P::zero(), |acc, x| {
1851                            if let Some(x) = x {
1852                                acc.checked_add(&x)
1853                            } else {
1854                                None
1855                            }
1856                        })
1857                        .ok_or(ConstantEvaluatorError::Overflow(
1858                            "in dot built-in".to_string(),
1859                        ))
1860                }
1861
1862                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1863                where
1864                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1865                {
1866                    a.iter()
1867                        .zip(b.iter())
1868                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
1869                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1870                }
1871
1872                let result = match_literal_vector!(match (e1, e2) => Literal {
1873                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
1874                    AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1875                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1876                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1877                })?;
1878                self.register_evaluated_expr(Expression::Literal(result), span)
1879            }
1880            crate::MathFunction::Length => {
1881                // https://www.w3.org/TR/WGSL/#length-builtin
1882                let e1 = self.extract_vec(arg, true)?;
1883
1884                let result = match_literal_vector!(match e1 => Literal {
1885                    Float => |e1| {
1886                        float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
1887                    },
1888                })?;
1889                self.register_evaluated_expr(Expression::Literal(result), span)
1890            }
1891            crate::MathFunction::Distance => {
1892                // https://www.w3.org/TR/WGSL/#distance-builtin
1893                let e1 = self.extract_vec(arg, true)?;
1894                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1895                if e1.len() != e2.len() {
1896                    return Err(ConstantEvaluatorError::InvalidMathArg);
1897                }
1898
1899                fn float_distance<F>(a: &[F], b: &[F]) -> F
1900                where
1901                    F: core::ops::Mul<F>,
1902                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1903                {
1904                    if a.len() == 1 {
1905                        // Avoids possible overflow in squaring
1906                        (a[0] - b[0]).abs()
1907                    } else {
1908                        a.iter()
1909                            .zip(b.iter())
1910                            .map(|(&aa, &bb)| aa - bb)
1911                            .map(|ei| ei * ei)
1912                            .sum::<F>()
1913                            .sqrt()
1914                    }
1915                }
1916                let result = match_literal_vector!(match (e1, e2) => Literal {
1917                    Float => |e1, e2| { float_distance(e1, e2) },
1918                })?;
1919                self.register_evaluated_expr(Expression::Literal(result), span)
1920            }
1921            crate::MathFunction::Normalize => {
1922                // https://www.w3.org/TR/WGSL/#normalize-builtin
1923                let e1 = self.extract_vec(arg, true)?;
1924
1925                fn float_normalize<F>(
1926                    e: &[F],
1927                ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
1928                where
1929                    F: core::ops::Mul<F>,
1930                    F: num_traits::Float + iter::Sum,
1931                {
1932                    let len = match float_length(e) {
1933                        Some(len) if !len.is_zero() => Ok(len),
1934                        Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
1935                            "normalize".into(),
1936                        )),
1937                        None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
1938                    }?;
1939
1940                    let mut out = ArrayVec::new();
1941                    for &ei in e {
1942                        out.push(ei / len);
1943                    }
1944                    Ok(out)
1945                }
1946
1947                let result = match_literal_vector!(match e1 => LiteralVector {
1948                    Float => |e1| { float_normalize(e1)? },
1949                })?;
1950                result.register_as_evaluated_expr(self, span)
1951            }
1952
1953            // unimplemented
1954            crate::MathFunction::Modf
1955            | crate::MathFunction::Frexp
1956            | crate::MathFunction::Ldexp
1957            | crate::MathFunction::Outer
1958            | crate::MathFunction::FaceForward
1959            | crate::MathFunction::Reflect
1960            | crate::MathFunction::Refract
1961            | crate::MathFunction::Mix
1962            | crate::MathFunction::SmoothStep
1963            | crate::MathFunction::Inverse
1964            | crate::MathFunction::Transpose
1965            | crate::MathFunction::Determinant
1966            | crate::MathFunction::QuantizeToF16
1967            | crate::MathFunction::ExtractBits
1968            | crate::MathFunction::InsertBits
1969            | crate::MathFunction::Pack4x8snorm
1970            | crate::MathFunction::Pack4x8unorm
1971            | crate::MathFunction::Pack2x16snorm
1972            | crate::MathFunction::Pack2x16unorm
1973            | crate::MathFunction::Pack2x16float
1974            | crate::MathFunction::Pack4xI8
1975            | crate::MathFunction::Pack4xU8
1976            | crate::MathFunction::Pack4xI8Clamp
1977            | crate::MathFunction::Pack4xU8Clamp
1978            | crate::MathFunction::Unpack4x8snorm
1979            | crate::MathFunction::Unpack4x8unorm
1980            | crate::MathFunction::Unpack2x16snorm
1981            | crate::MathFunction::Unpack2x16unorm
1982            | crate::MathFunction::Unpack2x16float
1983            | crate::MathFunction::Unpack4xI8
1984            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1985                format!("{fun:?} built-in function"),
1986            )),
1987        }
1988    }
1989
1990    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1991    fn packed_dot_product(
1992        &mut self,
1993        a: Handle<Expression>,
1994        b: Handle<Expression>,
1995        span: Span,
1996        signed: bool,
1997    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1998        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1999            return Err(ConstantEvaluatorError::InvalidMathArg);
2000        };
2001        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2002            return Err(ConstantEvaluatorError::InvalidMathArg);
2003        };
2004
2005        let result = if signed {
2006            Literal::I32(
2007                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2008                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2009                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2010                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2011            )
2012        } else {
2013            Literal::U32(
2014                (a & 0xFF) * (b & 0xFF)
2015                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2016                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2017                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2018            )
2019        };
2020
2021        self.register_evaluated_expr(Expression::Literal(result), span)
2022    }
2023
2024    /// Vector cross product.
2025    fn cross_product(
2026        &mut self,
2027        a: Handle<Expression>,
2028        b: Handle<Expression>,
2029        span: Span,
2030    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2031        use Literal as Li;
2032
2033        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2034        let (b, _) = self.extract_vec_with_size::<3>(b)?;
2035
2036        let product = match (a, b) {
2037            (
2038                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2039                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2040            ) => {
2041                // `cross` has no overload for AbstractInt, so AbstractInt
2042                // arguments are automatically converted to AbstractFloat. Since
2043                // `f64` has a much wider range than `i64`, there's no danger of
2044                // overflow here.
2045                let p = cross_product(
2046                    [a0 as f64, a1 as f64, a2 as f64],
2047                    [b0 as f64, b1 as f64, b2 as f64],
2048                );
2049                [
2050                    Li::AbstractFloat(p[0]),
2051                    Li::AbstractFloat(p[1]),
2052                    Li::AbstractFloat(p[2]),
2053                ]
2054            }
2055            (
2056                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2057                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2058            ) => {
2059                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2060                [
2061                    Li::AbstractFloat(p[0]),
2062                    Li::AbstractFloat(p[1]),
2063                    Li::AbstractFloat(p[2]),
2064                ]
2065            }
2066            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2067                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2068                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2069            }
2070            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2071                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2072                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2073            }
2074            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2075                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2076                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2077            }
2078            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2079        };
2080
2081        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2082        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2083        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2084
2085        self.register_evaluated_expr(
2086            Expression::Compose {
2087                ty,
2088                components: vec![p0, p1, p2],
2089            },
2090            span,
2091        )
2092    }
2093
2094    /// Extract the values of a `vecN` from `expr`.
2095    ///
2096    /// Return the value of `expr`, whose type is `vecN<S>` for some
2097    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2098    /// values.
2099    ///
2100    /// Also return the type handle from the `Compose` expression.
2101    fn extract_vec_with_size<const N: usize>(
2102        &mut self,
2103        expr: Handle<Expression>,
2104    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2105        let span = self.expressions.get_span(expr);
2106        let expr = self.eval_zero_value_and_splat(expr, span)?;
2107        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2108            return Err(ConstantEvaluatorError::InvalidMathArg);
2109        };
2110
2111        let mut value = [Literal::Bool(false); N];
2112        for (component, elt) in
2113            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2114                .zip(value.iter_mut())
2115        {
2116            let Expression::Literal(literal) = self.expressions[component] else {
2117                return Err(ConstantEvaluatorError::InvalidMathArg);
2118            };
2119            *elt = literal;
2120        }
2121
2122        Ok((value, ty))
2123    }
2124
2125    /// Extract the values of a `vecN` from `expr`.
2126    ///
2127    /// Return the value of `expr`, whose type is `vecN<S>` for some
2128    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2129    /// values.
2130    ///
2131    /// Also return the type handle from the `Compose` expression.
2132    fn extract_vec(
2133        &mut self,
2134        expr: Handle<Expression>,
2135        allow_single: bool,
2136    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2137        let span = self.expressions.get_span(expr);
2138        let expr = self.eval_zero_value_and_splat(expr, span)?;
2139
2140        match self.expressions[expr] {
2141            Expression::Literal(literal) if allow_single => {
2142                Ok(LiteralVector::from_literal(literal))
2143            }
2144            Expression::Compose { ty, ref components } => {
2145                let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2146                for expr in
2147                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2148                {
2149                    match self.expressions[expr] {
2150                        Expression::Literal(l) => components_out.push(l),
2151                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2152                    }
2153                }
2154                LiteralVector::from_literal_vec(components_out)
2155            }
2156            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2157        }
2158    }
2159
2160    fn array_length(
2161        &mut self,
2162        array: Handle<Expression>,
2163        span: Span,
2164    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2165        match self.expressions[array] {
2166            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2167                match self.types[ty].inner {
2168                    TypeInner::Array { size, .. } => match size {
2169                        ArraySize::Constant(len) => {
2170                            let expr = Expression::Literal(Literal::U32(len.get()));
2171                            self.register_evaluated_expr(expr, span)
2172                        }
2173                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2174                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2175                    },
2176                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2177                }
2178            }
2179            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2180        }
2181    }
2182
2183    fn access(
2184        &mut self,
2185        base: Handle<Expression>,
2186        index: usize,
2187        span: Span,
2188    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2189        match self.expressions[base] {
2190            Expression::ZeroValue(ty) => {
2191                let ty_inner = &self.types[ty].inner;
2192                let components = ty_inner
2193                    .components()
2194                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2195
2196                if index >= components as usize {
2197                    Err(ConstantEvaluatorError::InvalidAccessBase)
2198                } else {
2199                    let ty_res = ty_inner
2200                        .component_type(index)
2201                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2202                    let ty = match ty_res {
2203                        crate::proc::TypeResolution::Handle(ty) => ty,
2204                        crate::proc::TypeResolution::Value(inner) => {
2205                            self.types.insert(Type { name: None, inner }, span)
2206                        }
2207                    };
2208                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2209                }
2210            }
2211            Expression::Splat { size, value } => {
2212                if index >= size as usize {
2213                    Err(ConstantEvaluatorError::InvalidAccessBase)
2214                } else {
2215                    Ok(value)
2216                }
2217            }
2218            Expression::Compose { ty, ref components } => {
2219                let _ = self.types[ty]
2220                    .inner
2221                    .components()
2222                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2223
2224                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2225                    .nth(index)
2226                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2227            }
2228            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2229        }
2230    }
2231
2232    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2233    ///
2234    /// [`ZeroValue`]: Expression::ZeroValue
2235    /// [`Splat`]: Expression::Splat
2236    /// [`Literal`]: Expression::Literal
2237    /// [`Compose`]: Expression::Compose
2238    fn eval_zero_value_and_splat(
2239        &mut self,
2240        mut expr: Handle<Expression>,
2241        span: Span,
2242    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2243        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2244        // each of its components.
2245        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2246            let components = components
2247                .clone()
2248                .iter()
2249                .map(|component| self.eval_zero_value_and_splat(*component, span))
2250                .collect::<Result<_, _>>()?;
2251            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2252        }
2253
2254        // The result of the splat() for a Splat of a scalar ZeroValue is a
2255        // vector ZeroValue, so we must call eval_zero_value_impl() after
2256        // splat() in order to ensure we have no ZeroValues remaining.
2257        if let Expression::Splat { size, value } = self.expressions[expr] {
2258            expr = self.splat(value, size, span)?;
2259        }
2260        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2261            expr = self.eval_zero_value_impl(ty, span)?;
2262        }
2263        Ok(expr)
2264    }
2265
2266    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2267    ///
2268    /// [`ZeroValue`]: Expression::ZeroValue
2269    /// [`Literal`]: Expression::Literal
2270    /// [`Compose`]: Expression::Compose
2271    fn eval_zero_value(
2272        &mut self,
2273        expr: Handle<Expression>,
2274        span: Span,
2275    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2276        match self.expressions[expr] {
2277            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2278            _ => Ok(expr),
2279        }
2280    }
2281
2282    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2283    ///
2284    /// [`ZeroValue`]: Expression::ZeroValue
2285    /// [`Literal`]: Expression::Literal
2286    /// [`Compose`]: Expression::Compose
2287    fn eval_zero_value_impl(
2288        &mut self,
2289        ty: Handle<Type>,
2290        span: Span,
2291    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2292        match self.types[ty].inner {
2293            TypeInner::Scalar(scalar) => {
2294                let expr = Expression::Literal(
2295                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2296                );
2297                self.register_evaluated_expr(expr, span)
2298            }
2299            TypeInner::Vector { size, scalar } => {
2300                let scalar_ty = self.types.insert(
2301                    Type {
2302                        name: None,
2303                        inner: TypeInner::Scalar(scalar),
2304                    },
2305                    span,
2306                );
2307                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2308                let expr = Expression::Compose {
2309                    ty,
2310                    components: vec![el; size as usize],
2311                };
2312                self.register_evaluated_expr(expr, span)
2313            }
2314            TypeInner::Matrix {
2315                columns,
2316                rows,
2317                scalar,
2318            } => {
2319                let vec_ty = self.types.insert(
2320                    Type {
2321                        name: None,
2322                        inner: TypeInner::Vector { size: rows, scalar },
2323                    },
2324                    span,
2325                );
2326                let el = self.eval_zero_value_impl(vec_ty, span)?;
2327                let expr = Expression::Compose {
2328                    ty,
2329                    components: vec![el; columns as usize],
2330                };
2331                self.register_evaluated_expr(expr, span)
2332            }
2333            TypeInner::Array {
2334                base,
2335                size: ArraySize::Constant(size),
2336                ..
2337            } => {
2338                let el = self.eval_zero_value_impl(base, span)?;
2339                let expr = Expression::Compose {
2340                    ty,
2341                    components: vec![el; size.get() as usize],
2342                };
2343                self.register_evaluated_expr(expr, span)
2344            }
2345            TypeInner::Struct { ref members, .. } => {
2346                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2347                let mut components = Vec::with_capacity(members.len());
2348                for ty in types {
2349                    components.push(self.eval_zero_value_impl(ty, span)?);
2350                }
2351                let expr = Expression::Compose { ty, components };
2352                self.register_evaluated_expr(expr, span)
2353            }
2354            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2355        }
2356    }
2357
2358    /// Convert the scalar components of `expr` to `target`.
2359    ///
2360    /// Treat `span` as the location of the resulting expression.
2361    pub fn cast(
2362        &mut self,
2363        expr: Handle<Expression>,
2364        target: crate::Scalar,
2365        span: Span,
2366    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2367        use crate::Scalar as Sc;
2368
2369        let expr = self.eval_zero_value(expr, span)?;
2370
2371        let make_error = || -> Result<_, ConstantEvaluatorError> {
2372            let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2373
2374            #[cfg(feature = "wgsl-in")]
2375            let to = target.to_wgsl_for_diagnostics();
2376
2377            #[cfg(not(feature = "wgsl-in"))]
2378            let to = format!("{target:?}");
2379
2380            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2381        };
2382
2383        use crate::proc::type_methods::IntFloatLimits;
2384
2385        let expr = match self.expressions[expr] {
2386            Expression::Literal(literal) => {
2387                let literal = match target {
2388                    Sc::I32 => Literal::I32(match literal {
2389                        Literal::I32(v) => v,
2390                        Literal::U32(v) => v as i32,
2391                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2392                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
2393                        Literal::Bool(v) => v as i32,
2394                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2395                            return make_error();
2396                        }
2397                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2398                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2399                    }),
2400                    Sc::U32 => Literal::U32(match literal {
2401                        Literal::I32(v) => v as u32,
2402                        Literal::U32(v) => v,
2403                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2404                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2405                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2406                        Literal::Bool(v) => v as u32,
2407                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2408                            return make_error();
2409                        }
2410                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2411                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2412                    }),
2413                    Sc::I64 => Literal::I64(match literal {
2414                        Literal::I32(v) => v as i64,
2415                        Literal::U32(v) => v as i64,
2416                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2417                        Literal::Bool(v) => v as i64,
2418                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2419                        Literal::I64(v) => v,
2420                        Literal::U64(v) => v as i64,
2421                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
2422                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2423                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2424                    }),
2425                    Sc::U64 => Literal::U64(match literal {
2426                        Literal::I32(v) => v as u64,
2427                        Literal::U32(v) => v as u64,
2428                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2429                        Literal::Bool(v) => v as u64,
2430                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2431                        Literal::I64(v) => v as u64,
2432                        Literal::U64(v) => v,
2433                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2434                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2435                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2436                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2437                    }),
2438                    Sc::F16 => Literal::F16(match literal {
2439                        Literal::F16(v) => v,
2440                        Literal::F32(v) => f16::from_f32(v),
2441                        Literal::F64(v) => f16::from_f64(v),
2442                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2443                        Literal::I64(v) => f16::from_i64(v).unwrap(),
2444                        Literal::U64(v) => f16::from_u64(v).unwrap(),
2445                        Literal::I32(v) => f16::from_i32(v).unwrap(),
2446                        Literal::U32(v) => f16::from_u32(v).unwrap(),
2447                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2448                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2449                    }),
2450                    Sc::F32 => Literal::F32(match literal {
2451                        Literal::I32(v) => v as f32,
2452                        Literal::U32(v) => v as f32,
2453                        Literal::F32(v) => v,
2454                        Literal::Bool(v) => v as u32 as f32,
2455                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2456                            return make_error();
2457                        }
2458                        Literal::F16(v) => f16::to_f32(v),
2459                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2460                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2461                    }),
2462                    Sc::F64 => Literal::F64(match literal {
2463                        Literal::I32(v) => v as f64,
2464                        Literal::U32(v) => v as f64,
2465                        Literal::F16(v) => f16::to_f64(v),
2466                        Literal::F32(v) => v as f64,
2467                        Literal::F64(v) => v,
2468                        Literal::Bool(v) => v as u32 as f64,
2469                        Literal::I64(_) | Literal::U64(_) => return make_error(),
2470                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2471                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2472                    }),
2473                    Sc::BOOL => Literal::Bool(match literal {
2474                        Literal::I32(v) => v != 0,
2475                        Literal::U32(v) => v != 0,
2476                        Literal::F32(v) => v != 0.0,
2477                        Literal::F16(v) => v != f16::zero(),
2478                        Literal::Bool(v) => v,
2479                        Literal::AbstractInt(v) => v != 0,
2480                        Literal::AbstractFloat(v) => v != 0.0,
2481                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2482                            return make_error();
2483                        }
2484                    }),
2485                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2486                        Literal::AbstractInt(v) => {
2487                            // Overflow is forbidden, but inexact conversions
2488                            // are fine. The range of f64 is far larger than
2489                            // that of i64, so we don't have to check anything
2490                            // here.
2491                            v as f64
2492                        }
2493                        Literal::AbstractFloat(v) => v,
2494                        _ => return make_error(),
2495                    }),
2496                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2497                        Literal::AbstractInt(v) => v,
2498                        _ => return make_error(),
2499                    }),
2500                    _ => {
2501                        log::debug!("Constant evaluator refused to convert value to {target:?}");
2502                        return make_error();
2503                    }
2504                };
2505                Expression::Literal(literal)
2506            }
2507            Expression::Compose {
2508                ty,
2509                components: ref src_components,
2510            } => {
2511                let ty_inner = match self.types[ty].inner {
2512                    TypeInner::Vector { size, .. } => TypeInner::Vector {
2513                        size,
2514                        scalar: target,
2515                    },
2516                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2517                        columns,
2518                        rows,
2519                        scalar: target,
2520                    },
2521                    _ => return make_error(),
2522                };
2523
2524                let mut components = src_components.clone();
2525                for component in &mut components {
2526                    *component = self.cast(*component, target, span)?;
2527                }
2528
2529                let ty = self.types.insert(
2530                    Type {
2531                        name: None,
2532                        inner: ty_inner,
2533                    },
2534                    span,
2535                );
2536
2537                Expression::Compose { ty, components }
2538            }
2539            Expression::Splat { size, value } => {
2540                let value_span = self.expressions.get_span(value);
2541                let cast_value = self.cast(value, target, value_span)?;
2542                Expression::Splat {
2543                    size,
2544                    value: cast_value,
2545                }
2546            }
2547            _ => return make_error(),
2548        };
2549
2550        self.register_evaluated_expr(expr, span)
2551    }
2552
2553    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2554    ///
2555    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2556    /// matrix, or nested arrays of such.
2557    ///
2558    /// This is basically the same as the [`cast`] method, except that that
2559    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2560    ///
2561    /// Treat `span` as the location of the resulting expression.
2562    ///
2563    /// [`cast`]: ConstantEvaluator::cast
2564    /// [`As`]: crate::Expression::As
2565    pub fn cast_array(
2566        &mut self,
2567        expr: Handle<Expression>,
2568        target: crate::Scalar,
2569        span: Span,
2570    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2571        let expr = self.check_and_get(expr)?;
2572
2573        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2574            return self.cast(expr, target, span);
2575        };
2576
2577        let TypeInner::Array {
2578            base: _,
2579            size,
2580            stride: _,
2581        } = self.types[ty].inner
2582        else {
2583            return self.cast(expr, target, span);
2584        };
2585
2586        let mut components = components.clone();
2587        for component in &mut components {
2588            *component = self.cast_array(*component, target, span)?;
2589        }
2590
2591        let first = components.first().unwrap();
2592        let new_base = match self.resolve_type(*first)? {
2593            crate::proc::TypeResolution::Handle(ty) => ty,
2594            crate::proc::TypeResolution::Value(inner) => {
2595                self.types.insert(Type { name: None, inner }, span)
2596            }
2597        };
2598        let mut layouter = core::mem::take(self.layouter);
2599        layouter.update(self.to_ctx()).unwrap();
2600        *self.layouter = layouter;
2601
2602        let new_base_stride = self.layouter[new_base].to_stride();
2603        let new_array_ty = self.types.insert(
2604            Type {
2605                name: None,
2606                inner: TypeInner::Array {
2607                    base: new_base,
2608                    size,
2609                    stride: new_base_stride,
2610                },
2611            },
2612            span,
2613        );
2614
2615        let compose = Expression::Compose {
2616            ty: new_array_ty,
2617            components,
2618        };
2619        self.register_evaluated_expr(compose, span)
2620    }
2621
2622    fn unary_op(
2623        &mut self,
2624        op: UnaryOperator,
2625        expr: Handle<Expression>,
2626        span: Span,
2627    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2628        let expr = self.eval_zero_value_and_splat(expr, span)?;
2629
2630        let expr = match self.expressions[expr] {
2631            Expression::Literal(value) => Expression::Literal(match op {
2632                UnaryOperator::Negate => match value {
2633                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2634                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2635                    Literal::F32(v) => Literal::F32(-v),
2636                    Literal::F16(v) => Literal::F16(-v),
2637                    Literal::F64(v) => Literal::F64(-v),
2638                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2639                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2640                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2641                },
2642                UnaryOperator::LogicalNot => match value {
2643                    Literal::Bool(v) => Literal::Bool(!v),
2644                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2645                },
2646                UnaryOperator::BitwiseNot => match value {
2647                    Literal::I32(v) => Literal::I32(!v),
2648                    Literal::I64(v) => Literal::I64(!v),
2649                    Literal::U32(v) => Literal::U32(!v),
2650                    Literal::U64(v) => Literal::U64(!v),
2651                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2652                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2653                },
2654            }),
2655            Expression::Compose {
2656                ty,
2657                components: ref src_components,
2658            } => {
2659                match self.types[ty].inner {
2660                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2661                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2662                }
2663
2664                let mut components = src_components.clone();
2665                for component in &mut components {
2666                    *component = self.unary_op(op, *component, span)?;
2667                }
2668
2669                Expression::Compose { ty, components }
2670            }
2671            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2672        };
2673
2674        self.register_evaluated_expr(expr, span)
2675    }
2676
2677    fn binary_op(
2678        &mut self,
2679        op: BinaryOperator,
2680        left: Handle<Expression>,
2681        right: Handle<Expression>,
2682        span: Span,
2683    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2684        let left = self.eval_zero_value_and_splat(left, span)?;
2685        let right = self.eval_zero_value_and_splat(right, span)?;
2686
2687        // Note: in most cases constant evaluation checks for overflow, but for
2688        // i32/u32, it uses wrapping arithmetic. See
2689        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.
2690
2691        let expr = match (&self.expressions[left], &self.expressions[right]) {
2692            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2693                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2694                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2695                {
2696                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2697                }
2698
2699                if matches!(
2700                    (left_value, op),
2701                    (
2702                        Literal::Bool(_),
2703                        BinaryOperator::Less
2704                            | BinaryOperator::LessEqual
2705                            | BinaryOperator::Greater
2706                            | BinaryOperator::GreaterEqual
2707                    )
2708                ) {
2709                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2710                }
2711
2712                let literal = match op {
2713                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2714                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2715                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2716                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2717                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2718                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2719
2720                    _ => match (left_value, right_value) {
2721                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2722                            BinaryOperator::Add => a.wrapping_add(b),
2723                            BinaryOperator::Subtract => a.wrapping_sub(b),
2724                            BinaryOperator::Multiply => a.wrapping_mul(b),
2725                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2726                                if b == 0 {
2727                                    ConstantEvaluatorError::DivisionByZero
2728                                } else {
2729                                    ConstantEvaluatorError::Overflow("division".into())
2730                                }
2731                            })?,
2732                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2733                                if b == 0 {
2734                                    ConstantEvaluatorError::RemainderByZero
2735                                } else {
2736                                    ConstantEvaluatorError::Overflow("remainder".into())
2737                                }
2738                            })?,
2739                            BinaryOperator::And => a & b,
2740                            BinaryOperator::ExclusiveOr => a ^ b,
2741                            BinaryOperator::InclusiveOr => a | b,
2742                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2743                        }),
2744                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2745                            BinaryOperator::ShiftLeft => {
2746                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2747                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2748                                }
2749                                a.checked_shl(b)
2750                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2751                            }
2752                            BinaryOperator::ShiftRight => a
2753                                .checked_shr(b)
2754                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2755                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2756                        }),
2757                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2758                            BinaryOperator::Add => a.wrapping_add(b),
2759                            BinaryOperator::Subtract => a.wrapping_sub(b),
2760                            BinaryOperator::Multiply => a.wrapping_mul(b),
2761                            BinaryOperator::Divide => a
2762                                .checked_div(b)
2763                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2764                            BinaryOperator::Modulo => a
2765                                .checked_rem(b)
2766                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2767                            BinaryOperator::And => a & b,
2768                            BinaryOperator::ExclusiveOr => a ^ b,
2769                            BinaryOperator::InclusiveOr => a | b,
2770                            BinaryOperator::ShiftLeft => a
2771                                .checked_mul(
2772                                    1u32.checked_shl(b)
2773                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2774                                )
2775                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2776                            BinaryOperator::ShiftRight => a
2777                                .checked_shr(b)
2778                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2779                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2780                        }),
2781                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2782                            BinaryOperator::Add => a + b,
2783                            BinaryOperator::Subtract => a - b,
2784                            BinaryOperator::Multiply => a * b,
2785                            BinaryOperator::Divide => a / b,
2786                            BinaryOperator::Modulo => a % b,
2787                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2788                        }),
2789                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2790                            Literal::AbstractInt(match op {
2791                                BinaryOperator::ShiftLeft => {
2792                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2793                                        return Err(ConstantEvaluatorError::Overflow(
2794                                            "<<".to_string(),
2795                                        ));
2796                                    }
2797                                    a.checked_shl(b).unwrap_or(0)
2798                                }
2799                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2800                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2801                            })
2802                        }
2803                        (Literal::F16(a), Literal::F16(b)) => {
2804                            let result = match op {
2805                                BinaryOperator::Add => a + b,
2806                                BinaryOperator::Subtract => a - b,
2807                                BinaryOperator::Multiply => a * b,
2808                                BinaryOperator::Divide => a / b,
2809                                BinaryOperator::Modulo => a % b,
2810                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2811                            };
2812                            if !result.is_finite() {
2813                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2814                            }
2815                            Literal::F16(result)
2816                        }
2817                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2818                            Literal::AbstractInt(match op {
2819                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2820                                    ConstantEvaluatorError::Overflow("addition".into())
2821                                })?,
2822                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2823                                    ConstantEvaluatorError::Overflow("subtraction".into())
2824                                })?,
2825                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2826                                    ConstantEvaluatorError::Overflow("multiplication".into())
2827                                })?,
2828                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2829                                    if b == 0 {
2830                                        ConstantEvaluatorError::DivisionByZero
2831                                    } else {
2832                                        ConstantEvaluatorError::Overflow("division".into())
2833                                    }
2834                                })?,
2835                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2836                                    if b == 0 {
2837                                        ConstantEvaluatorError::RemainderByZero
2838                                    } else {
2839                                        ConstantEvaluatorError::Overflow("remainder".into())
2840                                    }
2841                                })?,
2842                                BinaryOperator::And => a & b,
2843                                BinaryOperator::ExclusiveOr => a ^ b,
2844                                BinaryOperator::InclusiveOr => a | b,
2845                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2846                            })
2847                        }
2848                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2849                            let result = match op {
2850                                BinaryOperator::Add => a + b,
2851                                BinaryOperator::Subtract => a - b,
2852                                BinaryOperator::Multiply => a * b,
2853                                BinaryOperator::Divide => a / b,
2854                                BinaryOperator::Modulo => a % b,
2855                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2856                            };
2857                            if !result.is_finite() {
2858                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2859                            }
2860                            Literal::AbstractFloat(result)
2861                        }
2862                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2863                            BinaryOperator::LogicalAnd => a && b,
2864                            BinaryOperator::LogicalOr => a || b,
2865                            BinaryOperator::And => a & b,
2866                            BinaryOperator::InclusiveOr => a | b,
2867                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2868                        }),
2869                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2870                    },
2871                };
2872                Expression::Literal(literal)
2873            }
2874            (
2875                &Expression::Compose {
2876                    components: ref src_components,
2877                    ty,
2878                },
2879                &Expression::Literal(_),
2880            ) => {
2881                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2882                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2883                }
2884                let mut components = src_components.clone();
2885                for component in &mut components {
2886                    *component = self.binary_op(op, *component, right, span)?;
2887                }
2888                Expression::Compose { ty, components }
2889            }
2890            (
2891                &Expression::Literal(_),
2892                &Expression::Compose {
2893                    components: ref src_components,
2894                    ty,
2895                },
2896            ) => {
2897                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2898                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2899                }
2900                let mut components = src_components.clone();
2901                for component in &mut components {
2902                    *component = self.binary_op(op, left, *component, span)?;
2903                }
2904                Expression::Compose { ty, components }
2905            }
2906            (
2907                &Expression::Compose {
2908                    components: ref left_components,
2909                    ty: left_ty,
2910                },
2911                &Expression::Compose {
2912                    components: ref right_components,
2913                    ty: right_ty,
2914                },
2915            ) => {
2916                // We have to make a copy of the component lists, because the
2917                // call to `binary_op_vector` needs `&mut self`, but `self` owns
2918                // the component lists.
2919                let left_flattened = crate::proc::flatten_compose(
2920                    left_ty,
2921                    left_components,
2922                    self.expressions,
2923                    self.types,
2924                )
2925                .collect::<Vec<_>>();
2926                let right_flattened = crate::proc::flatten_compose(
2927                    right_ty,
2928                    right_components,
2929                    self.expressions,
2930                    self.types,
2931                )
2932                .collect::<Vec<_>>();
2933
2934                self.binary_op_compose(
2935                    op,
2936                    &left_flattened,
2937                    &right_flattened,
2938                    left_ty,
2939                    right_ty,
2940                    span,
2941                )?
2942            }
2943            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2944        };
2945
2946        return self.register_evaluated_expr(expr, span);
2947
2948        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
2949            let is_numeric_vec = matches!(
2950                compose_ty, TypeInner::Vector { scalar, .. }
2951                if scalar.kind != ScalarKind::Bool
2952            );
2953            let is_allowed_vec_scalar_op = matches!(
2954                op,
2955                BinaryOperator::Add
2956                    | BinaryOperator::Subtract
2957                    | BinaryOperator::Multiply
2958                    | BinaryOperator::Divide
2959                    | BinaryOperator::Modulo
2960            );
2961            let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
2962            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
2963            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
2964        }
2965    }
2966
2967    fn binary_op_compose(
2968        &mut self,
2969        op: BinaryOperator,
2970        left_components: &[Handle<Expression>],
2971        right_components: &[Handle<Expression>],
2972        left_ty: Handle<Type>,
2973        right_ty: Handle<Type>,
2974        span: Span,
2975    ) -> Result<Expression, ConstantEvaluatorError> {
2976        match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2977            // Binary operation on vector-vector
2978            (
2979                &TypeInner::Vector {
2980                    size: left_size, ..
2981                },
2982                &TypeInner::Vector {
2983                    size: right_size, ..
2984                },
2985            ) if left_size == right_size => self.binary_op_vector(
2986                op,
2987                left_size,
2988                left_components,
2989                right_components,
2990                left_ty,
2991                span,
2992            ),
2993            // Binary operation on vector-matrix
2994            (
2995                &TypeInner::Vector { size, .. },
2996                &TypeInner::Matrix {
2997                    columns,
2998                    rows,
2999                    scalar,
3000                },
3001            ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3002                left_components,
3003                right_components,
3004                columns,
3005                scalar,
3006                span,
3007            ),
3008            // Binary operation on matrix-vector
3009            (
3010                &TypeInner::Matrix {
3011                    columns,
3012                    rows,
3013                    scalar,
3014                },
3015                &TypeInner::Vector { size, .. },
3016            ) if op == BinaryOperator::Multiply && size == columns => {
3017                self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3018            }
3019            // Binary operation on matrix-matrix
3020            (
3021                &TypeInner::Matrix {
3022                    columns: left_columns,
3023                    rows: left_rows,
3024                    scalar,
3025                },
3026                &TypeInner::Matrix {
3027                    columns: right_columns,
3028                    rows: right_rows,
3029                    ..
3030                },
3031            ) => match op {
3032                BinaryOperator::Add | BinaryOperator::Subtract
3033                    if left_columns == right_columns && left_rows == right_rows =>
3034                {
3035                    let components = left_components
3036                        .iter()
3037                        .zip(right_components)
3038                        .map(|(&left, &right)| self.binary_op(op, left, right, span))
3039                        .collect::<Result<Vec<_>, _>>()?;
3040                    Ok(Expression::Compose {
3041                        ty: left_ty,
3042                        components,
3043                    })
3044                }
3045                BinaryOperator::Multiply if left_columns == right_rows => self
3046                    .multiply_matrix_matrix(
3047                        left_components,
3048                        right_components,
3049                        left_rows,
3050                        right_columns,
3051                        scalar,
3052                        span,
3053                    ),
3054                _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3055            },
3056            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3057        }
3058    }
3059
3060    fn binary_op_vector(
3061        &mut self,
3062        op: BinaryOperator,
3063        size: crate::VectorSize,
3064        left_components: &[Handle<Expression>],
3065        right_components: &[Handle<Expression>],
3066        left_ty: Handle<Type>,
3067        span: Span,
3068    ) -> Result<Expression, ConstantEvaluatorError> {
3069        let ty = match op {
3070            // Relational operators produce vectors of booleans.
3071            BinaryOperator::Equal
3072            | BinaryOperator::NotEqual
3073            | BinaryOperator::Less
3074            | BinaryOperator::LessEqual
3075            | BinaryOperator::Greater
3076            | BinaryOperator::GreaterEqual => self.types.insert(
3077                Type {
3078                    name: None,
3079                    inner: TypeInner::Vector {
3080                        size,
3081                        scalar: crate::Scalar::BOOL,
3082                    },
3083                },
3084                span,
3085            ),
3086
3087            // Other operators produce the same type as their left
3088            // operand.
3089            BinaryOperator::Add
3090            | BinaryOperator::Subtract
3091            | BinaryOperator::Multiply
3092            | BinaryOperator::Divide
3093            | BinaryOperator::Modulo
3094            | BinaryOperator::And
3095            | BinaryOperator::ExclusiveOr
3096            | BinaryOperator::InclusiveOr
3097            | BinaryOperator::ShiftLeft
3098            | BinaryOperator::ShiftRight => left_ty,
3099
3100            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3101                // Not supported on vectors
3102                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3103            }
3104        };
3105
3106        let components = left_components
3107            .iter()
3108            .zip(right_components)
3109            .map(|(&left, &right)| self.binary_op(op, left, right, span))
3110            .collect::<Result<Vec<_>, _>>()?;
3111
3112        Ok(Expression::Compose { ty, components })
3113    }
3114
3115    fn multiply_vector_matrix(
3116        &mut self,
3117        vec_components: &[Handle<Expression>],
3118        mat_components: &[Handle<Expression>],
3119        mat_columns: crate::VectorSize,
3120        scalar: crate::Scalar,
3121        span: Span,
3122    ) -> Result<Expression, ConstantEvaluatorError> {
3123        let ty = self.types.insert(
3124            Type {
3125                name: None,
3126                inner: TypeInner::Vector {
3127                    size: mat_columns,
3128                    scalar,
3129                },
3130            },
3131            span,
3132        );
3133        let components = mat_components
3134            .iter()
3135            .map(|&column| {
3136                let Expression::Compose { ref components, .. } = self.expressions[column] else {
3137                    unreachable!()
3138                };
3139                self.dot_exprs(
3140                    vec_components.iter().cloned(),
3141                    components.clone().into_iter(),
3142                    span,
3143                )
3144            })
3145            .collect::<Result<Vec<_>, _>>()?;
3146        Ok(Expression::Compose { ty, components })
3147    }
3148
3149    fn multiply_matrix_vector(
3150        &mut self,
3151        mat_components: &[Handle<Expression>],
3152        vec_components: &[Handle<Expression>],
3153        mat_rows: crate::VectorSize,
3154        scalar: crate::Scalar,
3155        span: Span,
3156    ) -> Result<Expression, ConstantEvaluatorError> {
3157        let ty = self.types.insert(
3158            Type {
3159                name: None,
3160                inner: TypeInner::Vector {
3161                    size: mat_rows,
3162                    scalar,
3163                },
3164            },
3165            span,
3166        );
3167
3168        let flatten = self.flatten_matrix(mat_components);
3169        let nr = mat_rows as usize;
3170        let components = (0..nr)
3171            .map(|r| {
3172                let row = flatten.iter().skip(r).step_by(nr).cloned();
3173                self.dot_exprs(row, vec_components.iter().cloned(), span)
3174            })
3175            .collect::<Result<Vec<_>, _>>()?;
3176        Ok(Expression::Compose { ty, components })
3177    }
3178
3179    fn multiply_matrix_matrix(
3180        &mut self,
3181        left_components: &[Handle<Expression>],
3182        right_components: &[Handle<Expression>],
3183        left_rows: crate::VectorSize,
3184        right_columns: crate::VectorSize,
3185        scalar: crate::Scalar,
3186        span: Span,
3187    ) -> Result<Expression, ConstantEvaluatorError> {
3188        let left_nc = left_components.len();
3189        let left_nr = left_rows as usize;
3190        let right_nc = right_columns as usize;
3191        let right_nr = left_nc;
3192
3193        let mut result = Vec::with_capacity(right_nc);
3194        let result_ty = self.types.insert(
3195            Type {
3196                name: None,
3197                inner: TypeInner::Matrix {
3198                    columns: right_columns,
3199                    rows: left_rows,
3200                    scalar,
3201                },
3202            },
3203            span,
3204        );
3205        let result_column_ty = self.types.insert(
3206            Type {
3207                name: None,
3208                inner: TypeInner::Vector {
3209                    size: left_rows,
3210                    scalar,
3211                },
3212            },
3213            span,
3214        );
3215
3216        let left_flattened = self.flatten_matrix(left_components);
3217        let right_flattened = self.flatten_matrix(right_components);
3218        for c in 0..right_nc {
3219            let result_column = (0..left_nr)
3220                .map(|r| {
3221                    let row = left_flattened.iter().skip(r).step_by(left_nr);
3222                    let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3223                    self.dot_exprs(row.cloned(), column.cloned(), span)
3224                })
3225                .collect::<Result<Vec<_>, _>>()?;
3226            let expr = Expression::Compose {
3227                ty: result_column_ty,
3228                components: result_column,
3229            };
3230            let handle = self.register_evaluated_expr(expr, span)?;
3231            result.push(handle);
3232        }
3233        Ok(Expression::Compose {
3234            ty: result_ty,
3235            components: result,
3236        })
3237    }
3238
3239    fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3240        let mut flattened = ArrayVec::<_, 16>::new();
3241        for &column in columns {
3242            let Expression::Compose { ref components, .. } = self.expressions[column] else {
3243                unreachable!()
3244            };
3245            flattened.extend(components.iter().cloned());
3246        }
3247        flattened
3248    }
3249
3250    fn dot_exprs(
3251        &mut self,
3252        left: impl Iterator<Item = Handle<Expression>>,
3253        right: impl Iterator<Item = Handle<Expression>>,
3254        span: Span,
3255    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3256        let mut acc = None;
3257        for (l, r) in left.zip(right) {
3258            let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3259            match acc.as_mut() {
3260                Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3261                None => acc = Some(result),
3262            }
3263        }
3264        Ok(acc.unwrap())
3265    }
3266
3267    fn relational(
3268        &mut self,
3269        fun: RelationalFunction,
3270        arg: Handle<Expression>,
3271        span: Span,
3272    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3273        let arg = self.eval_zero_value_and_splat(arg, span)?;
3274        match fun {
3275            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3276                Expression::Literal(Literal::Bool(_)) => Ok(arg),
3277                Expression::Compose { ty, ref components }
3278                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3279                {
3280                    let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3281                    for component in
3282                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3283                    {
3284                        match self.expressions[component] {
3285                            Expression::Literal(Literal::Bool(val)) => {
3286                                bool_components.push(val);
3287                            }
3288                            _ => {
3289                                return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3290                            }
3291                        }
3292                    }
3293                    let components = bool_components;
3294                    let result = match fun {
3295                        RelationalFunction::All => components.iter().all(|c| *c),
3296                        RelationalFunction::Any => components.iter().any(|c| *c),
3297                        _ => unreachable!(),
3298                    };
3299                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3300                }
3301                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3302            },
3303            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3304                "{fun:?} built-in function"
3305            ))),
3306        }
3307    }
3308
3309    /// Deep copy `expr` from `expressions` into `self.expressions`.
3310    ///
3311    /// Return the root of the new copy.
3312    ///
3313    /// This is used when we're evaluating expressions in a function's
3314    /// expression arena that refer to a constant: we need to copy the
3315    /// constant's value into the function's arena so we can operate on it.
3316    fn copy_from(
3317        &mut self,
3318        expr: Handle<Expression>,
3319        expressions: &Arena<Expression>,
3320    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3321        let span = expressions.get_span(expr);
3322        match expressions[expr] {
3323            ref expr @ (Expression::Literal(_)
3324            | Expression::Constant(_)
3325            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3326            Expression::Compose { ty, ref components } => {
3327                let mut components = components.clone();
3328                for component in &mut components {
3329                    *component = self.copy_from(*component, expressions)?;
3330                }
3331                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3332            }
3333            Expression::Splat { size, value } => {
3334                let value = self.copy_from(value, expressions)?;
3335                self.register_evaluated_expr(Expression::Splat { size, value }, span)
3336            }
3337            _ => {
3338                log::debug!("copy_from: SubexpressionsAreNotConstant");
3339                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3340            }
3341        }
3342    }
3343
3344    /// Returns the total number of components, after flattening, of a vector compose expression.
3345    fn vector_compose_flattened_size(
3346        &self,
3347        components: &[Handle<Expression>],
3348    ) -> Result<usize, ConstantEvaluatorError> {
3349        components
3350            .iter()
3351            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3352                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3353                    TypeInner::Scalar(_) => 1,
3354                    // We trust that the vector size of `component` is correct,
3355                    // as it will have already been validated when `component`
3356                    // was registered.
3357                    TypeInner::Vector { size, .. } => size as usize,
3358                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3359                };
3360                Ok(acc + size)
3361            })
3362    }
3363
3364    fn register_evaluated_expr(
3365        &mut self,
3366        expr: Expression,
3367        span: Span,
3368    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3369        // It suffices to only check_literal_value() for `Literal` expressions,
3370        // since we only register one expression at a time, `Compose`
3371        // expressions can only refer to other expressions, and `ZeroValue`
3372        // expressions are always okay.
3373        if let Expression::Literal(literal) = expr {
3374            crate::valid::check_literal_value(literal)?;
3375        }
3376
3377        // Ensure vector composes contain the correct number of components. We
3378        // do so here when each compose is registered to avoid having to deal
3379        // with the mess each time the compose is used in another expression.
3380        if let Expression::Compose { ty, ref components } = expr {
3381            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3382                let expected = size as usize;
3383                let actual = self.vector_compose_flattened_size(components)?;
3384                if expected != actual {
3385                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3386                        expected,
3387                        actual,
3388                    });
3389                }
3390            }
3391        }
3392
3393        Ok(self.append_expr(expr, span, ExpressionKind::Const))
3394    }
3395
3396    fn append_expr(
3397        &mut self,
3398        expr: Expression,
3399        span: Span,
3400        expr_type: ExpressionKind,
3401    ) -> Handle<Expression> {
3402        let h = match self.behavior {
3403            Behavior::Wgsl(
3404                WgslRestrictions::Runtime(ref mut function_local_data)
3405                | WgslRestrictions::Const(Some(ref mut function_local_data)),
3406            )
3407            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3408                let is_running = function_local_data.emitter.is_running();
3409                let needs_pre_emit = expr.needs_pre_emit();
3410                if is_running && needs_pre_emit {
3411                    function_local_data
3412                        .block
3413                        .extend(function_local_data.emitter.finish(self.expressions));
3414                    let h = self.expressions.append(expr, span);
3415                    function_local_data.emitter.start(self.expressions);
3416                    h
3417                } else {
3418                    self.expressions.append(expr, span)
3419                }
3420            }
3421            _ => self.expressions.append(expr, span),
3422        };
3423        self.expression_kind_tracker.insert(h, expr_type);
3424        h
3425    }
3426
3427    /// Resolve the type of `expr` if it is a constant expression.
3428    ///
3429    /// If `expr` was evaluated to a constant, returns its type.
3430    /// Otherwise, returns an error.
3431    fn resolve_type(
3432        &self,
3433        expr: Handle<Expression>,
3434    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3435        use crate::proc::TypeResolution as Tr;
3436        use crate::Expression as Ex;
3437        let resolution = match self.expressions[expr] {
3438            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3439            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3440            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3441            Ex::Splat { size, value } => {
3442                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3443                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3444                };
3445                Tr::Value(TypeInner::Vector { scalar, size })
3446            }
3447            _ => {
3448                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3449                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3450            }
3451        };
3452
3453        Ok(resolution)
3454    }
3455
3456    fn select(
3457        &mut self,
3458        reject: Handle<Expression>,
3459        accept: Handle<Expression>,
3460        condition: Handle<Expression>,
3461        span: Span,
3462    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3463        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3464
3465        let reject = arg(reject)?;
3466        let accept = arg(accept)?;
3467        let condition = arg(condition)?;
3468
3469        let select_single_component =
3470            |this: &mut Self, reject_scalar, reject, accept, condition| {
3471                let accept = this.cast(accept, reject_scalar, span)?;
3472                if condition {
3473                    Ok(accept)
3474                } else {
3475                    Ok(reject)
3476                }
3477            };
3478
3479        match (&self.expressions[reject], &self.expressions[accept]) {
3480            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3481                let reject_scalar = reject_lit.scalar();
3482                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3483                else {
3484                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3485                };
3486                select_single_component(self, reject_scalar, reject, accept, condition)
3487            }
3488            (
3489                &Expression::Compose {
3490                    ty: reject_ty,
3491                    components: ref reject_components,
3492                },
3493                &Expression::Compose {
3494                    ty: accept_ty,
3495                    components: ref accept_components,
3496                },
3497            ) => {
3498                let ty_deets = |ty| {
3499                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3500                    (size.unwrap(), scalar)
3501                };
3502
3503                let expected_vec_size = {
3504                    let [(reject_vec_size, _), (accept_vec_size, _)] =
3505                        [reject_ty, accept_ty].map(ty_deets);
3506
3507                    if reject_vec_size != accept_vec_size {
3508                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3509                            reject: reject_vec_size,
3510                            accept: accept_vec_size,
3511                        });
3512                    }
3513                    reject_vec_size
3514                };
3515
3516                let condition_components = match self.expressions[condition] {
3517                    Expression::Literal(Literal::Bool(condition)) => {
3518                        vec![condition; (expected_vec_size as u8).into()]
3519                    }
3520                    Expression::Compose {
3521                        ty: condition_ty,
3522                        components: ref condition_components,
3523                    } => {
3524                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3525                        if condition_scalar.kind != ScalarKind::Bool {
3526                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3527                        }
3528                        if condition_vec_size != expected_vec_size {
3529                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3530                        }
3531                        condition_components
3532                            .iter()
3533                            .copied()
3534                            .map(|component| match &self.expressions[component] {
3535                                &Expression::Literal(Literal::Bool(condition)) => condition,
3536                                _ => unreachable!(),
3537                            })
3538                            .collect()
3539                    }
3540
3541                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3542                };
3543
3544                let evaluated = Expression::Compose {
3545                    ty: reject_ty,
3546                    components: reject_components
3547                        .clone()
3548                        .into_iter()
3549                        .zip(accept_components.clone().into_iter())
3550                        .zip(condition_components.into_iter())
3551                        .map(|((reject, accept), condition)| {
3552                            let reject_scalar = match &self.expressions[reject] {
3553                                &Expression::Literal(lit) => lit.scalar(),
3554                                _ => unreachable!(),
3555                            };
3556                            select_single_component(self, reject_scalar, reject, accept, condition)
3557                        })
3558                        .collect::<Result<_, _>>()?,
3559                };
3560                self.register_evaluated_expr(evaluated, span)
3561            }
3562            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3563        }
3564    }
3565}
3566
3567fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3568    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3569    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3570    // return a right-to-left bit index of 0.
3571    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3572        match e {
3573            idx @ 0..=31 => idx,
3574            32 => u32::MAX,
3575            _ => unreachable!(),
3576        }
3577    };
3578    match concrete_int {
3579        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3580        ConcreteInt::I32([e]) => {
3581            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3582        }
3583    }
3584}
3585
#[test]
fn first_trailing_bit_smoke() {
    // Signed inputs paired with the LSB-relative index of their lowest set bit.
    // Zero has no set bit, so it maps to -1.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // A lone set bit at each position must report exactly that position.
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }

    // Unsigned inputs; zero maps to `u32::MAX`.
    let unsigned_cases = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
3646
3647fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3648    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3649    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3650    // index of 0.
3651    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3652        match e {
3653            idx @ 0..=31 => 31 - idx,
3654            32 => u32::MAX,
3655            _ => unreachable!(),
3656        }
3657    };
3658    match concrete_int {
3659        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3660            let rtl_bit_index = if e.is_negative() {
3661                e.leading_ones()
3662            } else {
3663                e.leading_zeros()
3664            };
3665            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3666        }]),
3667        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3668    }
3669}
3670
#[test]
fn first_leading_bit_smoke() {
    // Signed inputs paired with the LSB-relative index of the highest bit that differs from
    // the sign bit; inputs made only of sign bits (-1 and 0) map to -1.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; the sign bit itself is covered by the table above.
    for bit in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }
    // Negated powers of two: `-(1 << bit)` is all ones down to `bit`, then zeros, so the
    // first bit differing from the sign is the one just below `bit`.
    for bit in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1])
        );
    }

    // Unsigned inputs; zero maps to `u32::MAX`.
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
3734
3735/// Trait for conversions of abstract values to concrete types.
3736trait TryFromAbstract<T>: Sized {
3737    /// Convert an abstract literal `value` to `Self`.
3738    ///
3739    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
3740    /// WGSL, we follow WGSL's conversion rules here:
3741    ///
3742    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
3743    ///   from [`AbstractInt`] to an integer type are either lossless or an
3744    ///   error.
3745    ///
3746    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
3747    ///   to floating point in constant expressions and override
3748    ///   expressions are errors if the value is out of range for the
3749    ///   destination type, but rounding is okay.
3750    ///
3751    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
3752    ///   other floating point type, following the scalar floating point to
3753    ///   integral conversion algorithm (§15.7.6). There is no automatic
3754    ///   conversion from AbstractFloat to integer types.
3755    ///
3756    /// [`AbstractInt`]: crate::Literal::AbstractInt
3757    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
3758    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3759}
3760
3761impl TryFromAbstract<i64> for i32 {
3762    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3763        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3764            value: format!("{value:?}"),
3765            to_type: "i32",
3766        })
3767    }
3768}
3769
3770impl TryFromAbstract<i64> for u32 {
3771    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3772        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3773            value: format!("{value:?}"),
3774            to_type: "u32",
3775        })
3776    }
3777}
3778
3779impl TryFromAbstract<i64> for u64 {
3780    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3781        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3782            value: format!("{value:?}"),
3783            to_type: "u64",
3784        })
3785    }
3786}
3787
3788impl TryFromAbstract<i64> for i64 {
3789    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3790        Ok(value)
3791    }
3792}
3793
3794impl TryFromAbstract<i64> for f32 {
3795    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3796        let f = value as f32;
3797        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3798        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3799        // overflow here.
3800        Ok(f)
3801    }
3802}
3803
3804impl TryFromAbstract<f64> for f32 {
3805    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3806        let f = value as f32;
3807        if f.is_infinite() {
3808            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3809                value: format!("{value:?}"),
3810                to_type: "f32",
3811            });
3812        }
3813        Ok(f)
3814    }
3815}
3816
3817impl TryFromAbstract<i64> for f64 {
3818    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3819        let f = value as f64;
3820        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3821        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3822        // overflow here.
3823        Ok(f)
3824    }
3825}
3826
3827impl TryFromAbstract<f64> for f64 {
3828    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3829        Ok(value)
3830    }
3831}
3832
3833impl TryFromAbstract<f64> for i32 {
3834    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3835        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3836        // To convert a floating point scalar value X to an integer scalar type T:
3837        // * If X is a NaN, the result is an indeterminate value in T.
3838        // * If X is exactly representable in the target type T, then the
3839        //   result is that value.
3840        // * Otherwise, the result is the value in T closest to truncate(X) and
3841        //   also exactly representable in the original floating point type.
3842        //
3843        // A rust cast satisfies these requirements apart from "the result
3844        // is... exactly representable in the original floating point type".
3845        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3846        // we're all good.
3847        Ok(value as i32)
3848    }
3849}
3850
3851impl TryFromAbstract<f64> for u32 {
3852    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3853        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3854        // so a simple rust cast is sufficient.
3855        Ok(value as u32)
3856    }
3857}
3858
3859impl TryFromAbstract<f64> for i64 {
3860    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3861        // As above, except we clamp to the minimum and maximum values
3862        // representable by both f64 and i64.
3863        use crate::proc::type_methods::IntFloatLimits;
3864        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3865    }
3866}
3867
3868impl TryFromAbstract<f64> for u64 {
3869    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3870        // As above, this time clamping to the minimum and maximum values
3871        // representable by both f64 and u64.
3872        use crate::proc::type_methods::IntFloatLimits;
3873        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3874    }
3875}
3876
3877impl TryFromAbstract<f64> for f16 {
3878    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3879        let f = f16::from_f64(value);
3880        if f.is_infinite() {
3881            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3882                value: format!("{value:?}"),
3883                to_type: "f16",
3884            });
3885        }
3886        Ok(f)
3887    }
3888}
3889
3890impl TryFromAbstract<i64> for f16 {
3891    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3892        let f = f16::from_i64(value);
3893        if f.is_none() {
3894            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3895                value: format!("{value:?}"),
3896                to_type: "f16",
3897            });
3898        }
3899        Ok(f.unwrap())
3900    }
3901}
3902
/// Computes the cross product `a × b` of two three-component vectors.
///
/// Works for any copyable element type supporting multiplication and subtraction.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
3915
3916#[cfg(test)]
3917mod tests {
3918    use alloc::{vec, vec::Vec};
3919
3920    use crate::{
3921        Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
3922        Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
3923    };
3924
3925    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3926
3927    #[test]
3928    fn unary_op() {
3929        let mut types = UniqueArena::new();
3930        let mut constants = Arena::new();
3931        let overrides = Arena::new();
3932        let mut global_expressions = Arena::new();
3933
3934        let scalar_ty = types.insert(
3935            Type {
3936                name: None,
3937                inner: TypeInner::Scalar(crate::Scalar::I32),
3938            },
3939            Default::default(),
3940        );
3941
3942        let vec_ty = types.insert(
3943            Type {
3944                name: None,
3945                inner: TypeInner::Vector {
3946                    size: VectorSize::Bi,
3947                    scalar: crate::Scalar::I32,
3948                },
3949            },
3950            Default::default(),
3951        );
3952
3953        let h = constants.append(
3954            Constant {
3955                name: None,
3956                ty: scalar_ty,
3957                init: global_expressions
3958                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3959            },
3960            Default::default(),
3961        );
3962
3963        let h1 = constants.append(
3964            Constant {
3965                name: None,
3966                ty: scalar_ty,
3967                init: global_expressions
3968                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
3969            },
3970            Default::default(),
3971        );
3972
3973        let vec_h = constants.append(
3974            Constant {
3975                name: None,
3976                ty: vec_ty,
3977                init: global_expressions.append(
3978                    Expression::Compose {
3979                        ty: vec_ty,
3980                        components: vec![constants[h].init, constants[h1].init],
3981                    },
3982                    Default::default(),
3983                ),
3984            },
3985            Default::default(),
3986        );
3987
3988        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3989        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3990
3991        let expr2 = Expression::Unary {
3992            op: UnaryOperator::Negate,
3993            expr,
3994        };
3995
3996        let expr3 = Expression::Unary {
3997            op: UnaryOperator::BitwiseNot,
3998            expr,
3999        };
4000
4001        let expr4 = Expression::Unary {
4002            op: UnaryOperator::BitwiseNot,
4003            expr: expr1,
4004        };
4005
4006        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4007        let mut solver = ConstantEvaluator {
4008            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4009            types: &mut types,
4010            constants: &constants,
4011            overrides: &overrides,
4012            expressions: &mut global_expressions,
4013            expression_kind_tracker,
4014            layouter: &mut crate::proc::Layouter::default(),
4015        };
4016
4017        let res1 = solver
4018            .try_eval_and_append(expr2, Default::default())
4019            .unwrap();
4020        let res2 = solver
4021            .try_eval_and_append(expr3, Default::default())
4022            .unwrap();
4023        let res3 = solver
4024            .try_eval_and_append(expr4, Default::default())
4025            .unwrap();
4026
4027        assert_eq!(
4028            global_expressions[res1],
4029            Expression::Literal(Literal::I32(-4))
4030        );
4031
4032        assert_eq!(
4033            global_expressions[res2],
4034            Expression::Literal(Literal::I32(!4))
4035        );
4036
4037        let res3_inner = &global_expressions[res3];
4038
4039        match *res3_inner {
4040            Expression::Compose {
4041                ref ty,
4042                ref components,
4043            } => {
4044                assert_eq!(*ty, vec_ty);
4045                let mut components_iter = components.iter().copied();
4046                assert_eq!(
4047                    global_expressions[components_iter.next().unwrap()],
4048                    Expression::Literal(Literal::I32(!4))
4049                );
4050                assert_eq!(
4051                    global_expressions[components_iter.next().unwrap()],
4052                    Expression::Literal(Literal::I32(!8))
4053                );
4054                assert!(components_iter.next().is_none());
4055            }
4056            _ => panic!("Expected vector"),
4057        }
4058    }
4059
4060    #[test]
4061    fn matrix_op() {
4062        let mut helper = MatrixTestHelper::new();
4063
4064        for nc in 2..=4 {
4065            for nr in 2..=4 {
4066                // Validates multiplication on vector-matrix.
4067                // vecR(0, 1, .., r) * matCxR(0, 1, .., nc * nr)
4068                let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4069                let expected = (0..nc)
4070                    .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4071                    .collect::<Vec<f32>>();
4072                assert_eq!(evaluated, expected);
4073
4074                // Validates multiplication on matrix-vector.
4075                // matCxR(0, 1, .., nc * nr) * vecC(0, 1, .., nc)
4076                let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4077                let expected = (0..nr)
4078                    .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4079                    .collect::<Vec<f32>>();
4080                assert_eq!(evaluated, expected);
4081
4082                for k in 2..=4 {
4083                    // Validates multiplication on matrix-matrix.
4084                    // matKxR(0, 1, .., k * nr) * matCxK(0, 1, .., nc * k)
4085                    let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4086                    let expected = (0..nc)
4087                        .flat_map(|c| {
4088                            (0..nr).map(move |r| {
4089                                (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4090                            })
4091                        })
4092                        .collect::<Vec<f32>>();
4093                    assert_eq!(evaluated, expected);
4094                }
4095            }
4096        }
4097    }
4098
4099    /// Test fixture providing pre-built f32 vector and matrix constant
4100    /// expressions with sequential element values, used to evaluate and verify
4101    /// matrix operations.
4102    struct MatrixTestHelper {
4103        types: UniqueArena<Type>,
4104        expressions: Arena<Expression>,
4105        /// Vector expressions from [0, 1] to [0, 1, 2, 3].
4106        vec_exprs: FastHashMap<usize, Handle<Expression>>,
4107        /// Matrix expressions from [0, .., 3] to [0, .., 15].
4108        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
4109    }
4110
4111    impl MatrixTestHelper {
4112        fn new() -> Self {
4113            let mut types = UniqueArena::new();
4114            let mut expressions = Arena::new();
4115            let span = crate::Span::default();
4116
4117            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4118            for c in 2..=4 {
4119                let vec_ty = types.insert(
4120                    Type {
4121                        name: None,
4122                        inner: TypeInner::Vector {
4123                            size: Self::int_to_vector_size(c),
4124                            scalar: crate::Scalar::F32,
4125                        },
4126                    },
4127                    span,
4128                );
4129                vec_tys.insert(c, vec_ty);
4130                for r in 2..=4 {
4131                    let mat_ty = types.insert(
4132                        Type {
4133                            name: None,
4134                            inner: TypeInner::Matrix {
4135                                columns: Self::int_to_vector_size(c),
4136                                rows: Self::int_to_vector_size(r),
4137                                scalar: crate::Scalar::F32,
4138                            },
4139                        },
4140                        span,
4141                    );
4142                    mat_tys.insert((c, r), mat_ty);
4143                }
4144            }
4145
4146            let mut lit_exprs = FastHashMap::default();
4147            for i in 0..16 {
4148                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4149                lit_exprs.insert(i, expr);
4150            }
4151
4152            let mut vec_exprs = FastHashMap::default();
4153            for c in 2..=4 {
4154                let expr = expressions.append(
4155                    Expression::Compose {
4156                        ty: *vec_tys.get(&c).unwrap(),
4157                        components: (0..c)
4158                            .map(|i| *lit_exprs.get(&i).unwrap())
4159                            .collect::<Vec<_>>(),
4160                    },
4161                    span,
4162                );
4163                vec_exprs.insert(c, expr);
4164            }
4165
4166            let mut mat_exprs = FastHashMap::default();
4167            for c in 2..=4 {
4168                for r in 2..=4 {
4169                    let mut columns = Vec::with_capacity(c);
4170                    for cc in 0..c {
4171                        let start = cc * r;
4172                        let expr = expressions.append(
4173                            Expression::Compose {
4174                                ty: *vec_tys.get(&r).unwrap(),
4175                                components: (start..start + r)
4176                                    .map(|i| *lit_exprs.get(&i).unwrap())
4177                                    .collect::<Vec<_>>(),
4178                            },
4179                            span,
4180                        );
4181                        columns.push(expr);
4182                    }
4183
4184                    let expr = expressions.append(
4185                        Expression::Compose {
4186                            ty: *mat_tys.get(&(c, r)).unwrap(),
4187                            components: columns,
4188                        },
4189                        span,
4190                    );
4191                    mat_exprs.insert((c, r), expr);
4192                }
4193            }
4194
4195            Self {
4196                types,
4197                expressions,
4198                vec_exprs,
4199                mat_exprs,
4200            }
4201        }
4202
4203        /// Evaluates vec[0..nr] * mat[0..nc*nr] and returns the result as f32s.
4204        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4205            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4206            let mut solver = ConstantEvaluator {
4207                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4208                types: &mut self.types,
4209                constants: &Arena::new(),
4210                overrides: &Arena::new(),
4211                expressions: &mut self.expressions,
4212                expression_kind_tracker,
4213                layouter: &mut crate::proc::Layouter::default(),
4214            };
4215
4216            let result = solver
4217                .try_eval_and_append(
4218                    Expression::Binary {
4219                        op: BinaryOperator::Multiply,
4220                        left: *self.vec_exprs.get(&nr).unwrap(),
4221                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4222                    },
4223                    Default::default(),
4224                )
4225                .unwrap();
4226            self.flatten(result)
4227        }
4228
4229        /// Evaluates mat[0..nc*nr] * vec[0..nc] and returns the result as f32s.
4230        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4231            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4232            let mut solver = ConstantEvaluator {
4233                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4234                types: &mut self.types,
4235                constants: &Arena::new(),
4236                overrides: &Arena::new(),
4237                expressions: &mut self.expressions,
4238                expression_kind_tracker,
4239                layouter: &mut crate::proc::Layouter::default(),
4240            };
4241
4242            let result = solver
4243                .try_eval_and_append(
4244                    Expression::Binary {
4245                        op: BinaryOperator::Multiply,
4246                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4247                        right: *self.vec_exprs.get(&nc).unwrap(),
4248                    },
4249                    Default::default(),
4250                )
4251                .unwrap();
4252            self.flatten(result)
4253        }
4254
4255        /// Evaluates mat[0..k*l_nr] * mat[0..r_nc*k] and returns the result as
4256        /// f32s.
4257        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4258            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4259            let mut solver = ConstantEvaluator {
4260                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4261                types: &mut self.types,
4262                constants: &Arena::new(),
4263                overrides: &Arena::new(),
4264                expressions: &mut self.expressions,
4265                expression_kind_tracker,
4266                layouter: &mut crate::proc::Layouter::default(),
4267            };
4268
4269            let result = solver
4270                .try_eval_and_append(
4271                    Expression::Binary {
4272                        op: BinaryOperator::Multiply,
4273                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4274                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4275                    },
4276                    Default::default(),
4277                )
4278                .unwrap();
4279            self.flatten(result)
4280        }
4281
4282        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4283            let Expression::Compose {
4284                ref components,
4285                ref ty,
4286            } = self.expressions[expr]
4287            else {
4288                unreachable!()
4289            };
4290
4291            match self.types[*ty].inner {
4292                TypeInner::Vector { .. } => components
4293                    .iter()
4294                    .map(|&comp| {
4295                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4296                            unreachable!()
4297                        };
4298                        v
4299                    })
4300                    .collect(),
4301                TypeInner::Matrix { .. } => components
4302                    .iter()
4303                    .flat_map(|&comp| self.flatten(comp))
4304                    .collect(),
4305                _ => unreachable!(),
4306            }
4307        }
4308
4309        fn int_to_vector_size(int: usize) -> VectorSize {
4310            match int {
4311                2 => VectorSize::Bi,
4312                3 => VectorSize::Tri,
4313                4 => VectorSize::Quad,
4314                _ => unreachable!(),
4315            }
4316        }
4317    }
4318
4319    #[test]
4320    fn cast() {
4321        let mut types = UniqueArena::new();
4322        let mut constants = Arena::new();
4323        let overrides = Arena::new();
4324        let mut global_expressions = Arena::new();
4325
4326        let scalar_ty = types.insert(
4327            Type {
4328                name: None,
4329                inner: TypeInner::Scalar(crate::Scalar::I32),
4330            },
4331            Default::default(),
4332        );
4333
4334        let h = constants.append(
4335            Constant {
4336                name: None,
4337                ty: scalar_ty,
4338                init: global_expressions
4339                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4340            },
4341            Default::default(),
4342        );
4343
4344        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4345
4346        let root = Expression::As {
4347            expr,
4348            kind: ScalarKind::Bool,
4349            convert: Some(crate::BOOL_WIDTH),
4350        };
4351
4352        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4353        let mut solver = ConstantEvaluator {
4354            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4355            types: &mut types,
4356            constants: &constants,
4357            overrides: &overrides,
4358            expressions: &mut global_expressions,
4359            expression_kind_tracker,
4360            layouter: &mut crate::proc::Layouter::default(),
4361        };
4362
4363        let res = solver
4364            .try_eval_and_append(root, Default::default())
4365            .unwrap();
4366
4367        assert_eq!(
4368            global_expressions[res],
4369            Expression::Literal(Literal::Bool(true))
4370        );
4371    }
4372
4373    #[test]
4374    fn access() {
4375        let mut types = UniqueArena::new();
4376        let mut constants = Arena::new();
4377        let overrides = Arena::new();
4378        let mut global_expressions = Arena::new();
4379
4380        let matrix_ty = types.insert(
4381            Type {
4382                name: None,
4383                inner: TypeInner::Matrix {
4384                    columns: VectorSize::Bi,
4385                    rows: VectorSize::Tri,
4386                    scalar: crate::Scalar::F32,
4387                },
4388            },
4389            Default::default(),
4390        );
4391
4392        let vec_ty = types.insert(
4393            Type {
4394                name: None,
4395                inner: TypeInner::Vector {
4396                    size: VectorSize::Tri,
4397                    scalar: crate::Scalar::F32,
4398                },
4399            },
4400            Default::default(),
4401        );
4402
4403        let mut vec1_components = Vec::with_capacity(3);
4404        let mut vec2_components = Vec::with_capacity(3);
4405
4406        for i in 0..3 {
4407            let h = global_expressions.append(
4408                Expression::Literal(Literal::F32(i as f32)),
4409                Default::default(),
4410            );
4411
4412            vec1_components.push(h)
4413        }
4414
4415        for i in 3..6 {
4416            let h = global_expressions.append(
4417                Expression::Literal(Literal::F32(i as f32)),
4418                Default::default(),
4419            );
4420
4421            vec2_components.push(h)
4422        }
4423
4424        let vec1 = constants.append(
4425            Constant {
4426                name: None,
4427                ty: vec_ty,
4428                init: global_expressions.append(
4429                    Expression::Compose {
4430                        ty: vec_ty,
4431                        components: vec1_components,
4432                    },
4433                    Default::default(),
4434                ),
4435            },
4436            Default::default(),
4437        );
4438
4439        let vec2 = constants.append(
4440            Constant {
4441                name: None,
4442                ty: vec_ty,
4443                init: global_expressions.append(
4444                    Expression::Compose {
4445                        ty: vec_ty,
4446                        components: vec2_components,
4447                    },
4448                    Default::default(),
4449                ),
4450            },
4451            Default::default(),
4452        );
4453
4454        let h = constants.append(
4455            Constant {
4456                name: None,
4457                ty: matrix_ty,
4458                init: global_expressions.append(
4459                    Expression::Compose {
4460                        ty: matrix_ty,
4461                        components: vec![constants[vec1].init, constants[vec2].init],
4462                    },
4463                    Default::default(),
4464                ),
4465            },
4466            Default::default(),
4467        );
4468
4469        let base = global_expressions.append(Expression::Constant(h), Default::default());
4470
4471        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4472        let mut solver = ConstantEvaluator {
4473            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4474            types: &mut types,
4475            constants: &constants,
4476            overrides: &overrides,
4477            expressions: &mut global_expressions,
4478            expression_kind_tracker,
4479            layouter: &mut crate::proc::Layouter::default(),
4480        };
4481
4482        let root1 = Expression::AccessIndex { base, index: 1 };
4483
4484        let res1 = solver
4485            .try_eval_and_append(root1, Default::default())
4486            .unwrap();
4487
4488        let root2 = Expression::AccessIndex {
4489            base: res1,
4490            index: 2,
4491        };
4492
4493        let res2 = solver
4494            .try_eval_and_append(root2, Default::default())
4495            .unwrap();
4496
4497        match global_expressions[res1] {
4498            Expression::Compose {
4499                ref ty,
4500                ref components,
4501            } => {
4502                assert_eq!(*ty, vec_ty);
4503                let mut components_iter = components.iter().copied();
4504                assert_eq!(
4505                    global_expressions[components_iter.next().unwrap()],
4506                    Expression::Literal(Literal::F32(3.))
4507                );
4508                assert_eq!(
4509                    global_expressions[components_iter.next().unwrap()],
4510                    Expression::Literal(Literal::F32(4.))
4511                );
4512                assert_eq!(
4513                    global_expressions[components_iter.next().unwrap()],
4514                    Expression::Literal(Literal::F32(5.))
4515                );
4516                assert!(components_iter.next().is_none());
4517            }
4518            _ => panic!("Expected vector"),
4519        }
4520
4521        assert_eq!(
4522            global_expressions[res2],
4523            Expression::Literal(Literal::F32(5.))
4524        );
4525    }
4526
4527    #[test]
4528    fn compose_of_constants() {
4529        let mut types = UniqueArena::new();
4530        let mut constants = Arena::new();
4531        let overrides = Arena::new();
4532        let mut global_expressions = Arena::new();
4533
4534        let i32_ty = types.insert(
4535            Type {
4536                name: None,
4537                inner: TypeInner::Scalar(crate::Scalar::I32),
4538            },
4539            Default::default(),
4540        );
4541
4542        let vec2_i32_ty = types.insert(
4543            Type {
4544                name: None,
4545                inner: TypeInner::Vector {
4546                    size: VectorSize::Bi,
4547                    scalar: crate::Scalar::I32,
4548                },
4549            },
4550            Default::default(),
4551        );
4552
4553        let h = constants.append(
4554            Constant {
4555                name: None,
4556                ty: i32_ty,
4557                init: global_expressions
4558                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4559            },
4560            Default::default(),
4561        );
4562
4563        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4564
4565        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4566        let mut solver = ConstantEvaluator {
4567            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4568            types: &mut types,
4569            constants: &constants,
4570            overrides: &overrides,
4571            expressions: &mut global_expressions,
4572            expression_kind_tracker,
4573            layouter: &mut crate::proc::Layouter::default(),
4574        };
4575
4576        let solved_compose = solver
4577            .try_eval_and_append(
4578                Expression::Compose {
4579                    ty: vec2_i32_ty,
4580                    components: vec![h_expr, h_expr],
4581                },
4582                Default::default(),
4583            )
4584            .unwrap();
4585        let solved_negate = solver
4586            .try_eval_and_append(
4587                Expression::Unary {
4588                    op: UnaryOperator::Negate,
4589                    expr: solved_compose,
4590                },
4591                Default::default(),
4592            )
4593            .unwrap();
4594
4595        let pass = match global_expressions[solved_negate] {
4596            Expression::Compose { ty, ref components } => {
4597                ty == vec2_i32_ty
4598                    && components.iter().all(|&component| {
4599                        let component = &global_expressions[component];
4600                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4601                    })
4602            }
4603            _ => false,
4604        };
4605        if !pass {
4606            panic!("unexpected evaluation result")
4607        }
4608    }
4609
4610    #[test]
4611    fn splat_of_constant() {
4612        let mut types = UniqueArena::new();
4613        let mut constants = Arena::new();
4614        let overrides = Arena::new();
4615        let mut global_expressions = Arena::new();
4616
4617        let i32_ty = types.insert(
4618            Type {
4619                name: None,
4620                inner: TypeInner::Scalar(crate::Scalar::I32),
4621            },
4622            Default::default(),
4623        );
4624
4625        let vec2_i32_ty = types.insert(
4626            Type {
4627                name: None,
4628                inner: TypeInner::Vector {
4629                    size: VectorSize::Bi,
4630                    scalar: crate::Scalar::I32,
4631                },
4632            },
4633            Default::default(),
4634        );
4635
4636        let h = constants.append(
4637            Constant {
4638                name: None,
4639                ty: i32_ty,
4640                init: global_expressions
4641                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4642            },
4643            Default::default(),
4644        );
4645
4646        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4647
4648        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4649        let mut solver = ConstantEvaluator {
4650            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4651            types: &mut types,
4652            constants: &constants,
4653            overrides: &overrides,
4654            expressions: &mut global_expressions,
4655            expression_kind_tracker,
4656            layouter: &mut crate::proc::Layouter::default(),
4657        };
4658
4659        let solved_compose = solver
4660            .try_eval_and_append(
4661                Expression::Splat {
4662                    size: VectorSize::Bi,
4663                    value: h_expr,
4664                },
4665                Default::default(),
4666            )
4667            .unwrap();
4668        let solved_negate = solver
4669            .try_eval_and_append(
4670                Expression::Unary {
4671                    op: UnaryOperator::Negate,
4672                    expr: solved_compose,
4673                },
4674                Default::default(),
4675            )
4676            .unwrap();
4677
4678        let pass = match global_expressions[solved_negate] {
4679            Expression::Compose { ty, ref components } => {
4680                ty == vec2_i32_ty
4681                    && components.iter().all(|&component| {
4682                        let component = &global_expressions[component];
4683                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4684                    })
4685            }
4686            _ => false,
4687        };
4688        if !pass {
4689            panic!("unexpected evaluation result")
4690        }
4691    }
4692
4693    #[test]
4694    fn splat_of_zero_value() {
4695        let mut types = UniqueArena::new();
4696        let constants = Arena::new();
4697        let overrides = Arena::new();
4698        let mut global_expressions = Arena::new();
4699
4700        let f32_ty = types.insert(
4701            Type {
4702                name: None,
4703                inner: TypeInner::Scalar(crate::Scalar::F32),
4704            },
4705            Default::default(),
4706        );
4707
4708        let vec2_f32_ty = types.insert(
4709            Type {
4710                name: None,
4711                inner: TypeInner::Vector {
4712                    size: VectorSize::Bi,
4713                    scalar: crate::Scalar::F32,
4714                },
4715            },
4716            Default::default(),
4717        );
4718
4719        let five =
4720            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4721        let five_splat = global_expressions.append(
4722            Expression::Splat {
4723                size: VectorSize::Bi,
4724                value: five,
4725            },
4726            Default::default(),
4727        );
4728        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4729        let zero_splat = global_expressions.append(
4730            Expression::Splat {
4731                size: VectorSize::Bi,
4732                value: zero,
4733            },
4734            Default::default(),
4735        );
4736
4737        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4738        let mut solver = ConstantEvaluator {
4739            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4740            types: &mut types,
4741            constants: &constants,
4742            overrides: &overrides,
4743            expressions: &mut global_expressions,
4744            expression_kind_tracker,
4745            layouter: &mut crate::proc::Layouter::default(),
4746        };
4747
4748        let solved_add = solver
4749            .try_eval_and_append(
4750                Expression::Binary {
4751                    op: BinaryOperator::Add,
4752                    left: zero_splat,
4753                    right: five_splat,
4754                },
4755                Default::default(),
4756            )
4757            .unwrap();
4758
4759        let pass = match global_expressions[solved_add] {
4760            Expression::Compose { ty, ref components } => {
4761                ty == vec2_f32_ty
4762                    && components.iter().all(|&component| {
4763                        let component = &global_expressions[component];
4764                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4765                    })
4766            }
4767            _ => false,
4768        };
4769        if !pass {
4770            panic!("unexpected evaluation result")
4771        }
4772    }
4773}