naga/proc/
constant_evaluator.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14    arena::{Arena, Handle, HandleVec, UniqueArena},
15    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16    ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
///
/// The tokens passed in are used verbatim as the rule set of a helper macro, which is then
/// invoked with a single literal `$` token. A rule of the form `($d:tt) => { ... }` therefore
/// binds `$d` to `$`, so the expansion can spell `$d name` wherever a nested macro needs
/// its own `$name` metavariable.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define a throwaway macro from the caller-supplied rules, then feed it a bare `$`.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
33
/// Generates the machinery for evaluating a numeric built-in component-wise over a chosen
/// subset of [`Literal`] variants.
///
/// Given `$ident -> $target`, a `literals` list of
/// `Literal variant => enum variant: Rust type` entries, and a `scalar_kinds` list, one
/// invocation emits:
///
/// - an enum `$target<N>` with one variant per `literals` entry, each holding `[$ty; N]`;
/// - `impl From<$target<1>> for Expression`, which wraps the single value back into the
///   matching [`Literal`];
/// - a function `$ident` that extracts `N` argument expressions into a `$target<N>`, passes
///   it to a handler and registers the result, recursing per component when the arguments
///   are `Compose`d vectors whose scalar kind appears in `scalar_kinds`;
/// - a macro `$ident!` wrapping that function so a single handler body can be reused for
///   every `$target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // Only the single-element form converts to an `Expression`; this is what satisfies
        // the `$target<M>: Into<Expression>` bound on the generated function below.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first argument is unwrapped unconditionally below, so there must be one.
            assert!(N > 0);
            // Every shape or type mismatch in this function is reported with this one error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an argument through `eval_zero_value_and_splat`, then borrows the
            // resulting expression from the arena so it can be pattern-matched.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument picks the branch; the remaining arguments must match it.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every remaining argument must be the same literal variant.
                    // The values are gathered into `[$ty; N]` and handed to `handler`.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: only `Compose` of a vector type whose scalar kind is listed
                // in `scalar_kinds` is accepted.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // Each remaining argument must be a `Compose` whose type has the
                            // same `inner` as the first argument's type.
                            // NOTE(review): this intermediate buffer has capacity
                            // `VectorSize::MAX` rather than `N`; presumably `N - 1` never
                            // exceeds it — confirm.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, take that component from every
                            // argument and recurse with a clone of the handler. A group that
                            // is too short to have the index yields the shared error.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result is composed with the first argument's vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` stands in for a literal `$`, so the nested macro below can declare its own
        // metavariables (`$d expr`, `$d arg`, `$d tt`) alongside this macro's.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        // One match arm per variant: bind the array elements to the caller's
                        // argument names, evaluate the shared body, and wrap a successful
                        // result back into the same variant.
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Generates `component_wise_scalar` / `Scalar`: accepts the abstract and concrete float and
// integer literals listed here (`F64` and `Bool` are not part of this set).
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// Generates `component_wise_float` / `Float`: floating-point literals only. Note that the
// `AbstractFloat` literal maps to the enum variant named `Abstract` here.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// Generates `component_wise_concrete_int` / `ConcreteInt`: the 32-bit concrete integer
// literals (`U32` and `I32`) only.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// Generates `component_wise_signed` / `Signed`: the float literals plus `AbstractInt` and
// `I32`; no unsigned integer literal is included.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// Vectors with a concrete element type.
///
/// Each variant stores its components inline in an [`ArrayVec`] whose capacity is
/// `crate::VectorSize::MAX`, with one variant per scalar representation.
// NOTE(review): despite "concrete" in the summary, `AbstractInt` and `AbstractFloat`
// variants are present as well — confirm whether the summary should be reworded.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    // Abstract integers are carried as `i64`, abstract floats as `f64`.
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288    fn len(&self) -> usize {
289        match *self {
290            LiteralVector::F64(ref v) => v.len(),
291            LiteralVector::F32(ref v) => v.len(),
292            LiteralVector::F16(ref v) => v.len(),
293            LiteralVector::U32(ref v) => v.len(),
294            LiteralVector::I32(ref v) => v.len(),
295            LiteralVector::U64(ref v) => v.len(),
296            LiteralVector::I64(ref v) => v.len(),
297            LiteralVector::Bool(ref v) => v.len(),
298            LiteralVector::AbstractInt(ref v) => v.len(),
299            LiteralVector::AbstractFloat(ref v) => v.len(),
300        }
301    }
302
303    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304    fn from_literal(literal: Literal) -> Self {
305        match literal {
306            Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307            Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308            Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309            Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310            Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311            Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312            Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313            Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314            Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315            Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316        }
317    }
318
319    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320    /// Returns error if components types do not match.
321    /// # Panics
322    /// Panics if vector is empty
323    fn from_literal_vec(
324        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325    ) -> Result<Self, ConstantEvaluatorError> {
326        assert!(!components.is_empty());
327        Ok(match components[0] {
328            Literal::I32(_) => Self::I32(
329                components
330                    .iter()
331                    .map(|l| match l {
332                        &Literal::I32(v) => Ok(v),
333                        // TODO: should we handle abstract int here?
334                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
335                    })
336                    .collect::<Result<_, _>>()?,
337            ),
338            Literal::U32(_) => Self::U32(
339                components
340                    .iter()
341                    .map(|l| match l {
342                        &Literal::U32(v) => Ok(v),
343                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
344                    })
345                    .collect::<Result<_, _>>()?,
346            ),
347            Literal::I64(_) => Self::I64(
348                components
349                    .iter()
350                    .map(|l| match l {
351                        &Literal::I64(v) => Ok(v),
352                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
353                    })
354                    .collect::<Result<_, _>>()?,
355            ),
356            Literal::U64(_) => Self::U64(
357                components
358                    .iter()
359                    .map(|l| match l {
360                        &Literal::U64(v) => Ok(v),
361                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
362                    })
363                    .collect::<Result<_, _>>()?,
364            ),
365            Literal::F32(_) => Self::F32(
366                components
367                    .iter()
368                    .map(|l| match l {
369                        &Literal::F32(v) => Ok(v),
370                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
371                    })
372                    .collect::<Result<_, _>>()?,
373            ),
374            Literal::F64(_) => Self::F64(
375                components
376                    .iter()
377                    .map(|l| match l {
378                        &Literal::F64(v) => Ok(v),
379                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
380                    })
381                    .collect::<Result<_, _>>()?,
382            ),
383            Literal::Bool(_) => Self::Bool(
384                components
385                    .iter()
386                    .map(|l| match l {
387                        &Literal::Bool(v) => Ok(v),
388                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
389                    })
390                    .collect::<Result<_, _>>()?,
391            ),
392            Literal::AbstractInt(_) => Self::AbstractInt(
393                components
394                    .iter()
395                    .map(|l| match l {
396                        &Literal::AbstractInt(v) => Ok(v),
397                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
398                    })
399                    .collect::<Result<_, _>>()?,
400            ),
401            Literal::AbstractFloat(_) => Self::AbstractFloat(
402                components
403                    .iter()
404                    .map(|l| match l {
405                        &Literal::AbstractFloat(v) => Ok(v),
406                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
407                    })
408                    .collect::<Result<_, _>>()?,
409            ),
410            Literal::F16(_) => Self::F16(
411                components
412                    .iter()
413                    .map(|l| match l {
414                        &Literal::F16(v) => Ok(v),
415                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
416                    })
417                    .collect::<Result<_, _>>()?,
418            ),
419        })
420    }
421
422    #[allow(dead_code)]
423    /// Returns [`ArrayVec`] of [`Literal`]s
424    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425        match *self {
426            LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427            LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428            LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429            LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430            LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431            LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432            LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433            LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434            LiteralVector::AbstractInt(ref v) => {
435                v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436            }
437            LiteralVector::AbstractFloat(ref v) => {
438                v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439            }
440        }
441    }
442
443    #[allow(dead_code)]
444    /// Puts self into eval's expressions arena and returns handle to it
445    fn register_as_evaluated_expr(
446        &self,
447        eval: &mut ConstantEvaluator<'_>,
448        span: Span,
449    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450        let lit_vec = self.to_literal_vec();
451        assert!(!lit_vec.is_empty());
452        let expr = if lit_vec.len() == 1 {
453            Expression::Literal(lit_vec[0])
454        } else {
455            Expression::Compose {
456                ty: eval.types.insert(
457                    Type {
458                        name: None,
459                        inner: TypeInner::Vector {
460                            size: match lit_vec.len() {
461                                2 => crate::VectorSize::Bi,
462                                3 => crate::VectorSize::Tri,
463                                4 => crate::VectorSize::Quad,
464                                _ => unreachable!(),
465                            },
466                            scalar: lit_vec[0].scalar(),
467                        },
468                    },
469                    Span::UNDEFINED,
470                ),
471                components: lit_vec
472                    .iter()
473                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474                    .collect::<Result<_, _>>()?,
475            }
476        };
477        eval.register_evaluated_expr(expr, span)
478    }
479}
480
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// The expansion is a `match` expression of type `Result<_, ConstantEvaluatorError>`: each
/// listed arm yields `Ok($out::Variant(body))`, with the closure parameters bound by `ref`
/// to the variant payloads, and any unlisted combination yields
/// `Err(ConstantEvaluatorError::InvalidMathArg)`.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Public entry point: split the arms into a list of type names and a parallel list of
    // `{ vars ; optional return override ; body }` records.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Seed the accumulator: `[finished arms] <> [pending arms]`, with nothing finished yet.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // No "current" type selected: pop the head of the type list and make it current (it
    // becomes the first token after `@inner`).
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // Current type is the `Integer` group: replace it with the five integer variants and
    // duplicate the head pending arm once per variant. This rule must precede the generic
    // `$ty:ident` rules below so the group name is not treated as a real variant.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is the `Float` group: same treatment with the four float variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is a plain variant and more types remain: tag the head pending arm with
    // it, append the arm to the finished list, and make the next type current.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Current type is a plain variant and it is the last one: finish the list and emit.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the final `match`. With several closure parameters the pattern is a tuple whose
    // elements must all be the same variant; with one parameter it is a parenthesized
    // pattern, hence the `unused_parens` allowance.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the body in the output variant: the matched variant by default...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...or the explicitly requested `-> $ret` variant when one was given.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
/// Selects which language's rules the evaluator follows, carrying that language's
/// restrictions as the payload.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL rules, limited by the given [`WgslRestrictions`].
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL rules, limited by the given [`GlslRestrictions`].
    Glsl(GlslRestrictions<'a>),
}
662
663impl Behavior<'_> {
664    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
665    const fn has_runtime_restrictions(&self) -> bool {
666        matches!(
667            self,
668            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670        )
671    }
672}
673
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// A mutably borrowed [`Layouter`](crate::proc::Layouter).
    // NOTE(review): presumably used to compute type layouts during evaluation; its use
    // is not visible in this part of the file — confirm.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// Restrictions applied when following WGSL rules.
///
/// Each variant lists which classes of expression are evaluated and which are merely
/// inserted into the arena.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    // NOTE(review): the optional `FunctionLocalData` is presumably present when the
    // const-expression appears inside a function body — confirm against the constructors.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
731
/// Restrictions applied when following GLSL rules.
///
/// Each variant lists which classes of expression are evaluated and which are merely
/// inserted into the arena.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
741
/// Borrowed state carried by the restriction variants that hold function-local data
/// (see [`WgslRestrictions`] and [`GlslRestrictions`]).
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    // NOTE(review): presumably the emitter that tracks which expressions of the function
    // have been emitted — confirm against `super::Emitter`.
    emitter: &'a mut super::Emitter,
    // NOTE(review): presumably the statement block currently being built — confirm.
    block: &'a mut crate::Block,
}
749
/// How "constant" an expression is.
///
/// The declaration order matters: the derived `Ord` makes `Const < Override < Runtime`,
/// and [`ExpressionKindTracker`] combines operand kinds with `max`, so an expression is as
/// non-constant as its least constant operand.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    Const,
    Override,
    Runtime,
}
756
/// Records an [`ExpressionKind`] for each expression handle of an arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// The recorded kind of each expression, indexed by its handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763    pub const fn new() -> Self {
764        Self {
765            inner: HandleVec::new(),
766        }
767    }
768
769    /// Forces the the expression to not be const
770    pub fn force_non_const(&mut self, value: Handle<Expression>) {
771        self.inner[value] = ExpressionKind::Runtime;
772    }
773
774    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775        self.inner.insert(value, expr_type);
776    }
777
778    pub fn is_const(&self, h: Handle<Expression>) -> bool {
779        matches!(self.type_of(h), ExpressionKind::Const)
780    }
781
782    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783        matches!(
784            self.type_of(h),
785            ExpressionKind::Const | ExpressionKind::Override
786        )
787    }
788
789    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790        self.inner[value]
791    }
792
793    pub fn from_arena(arena: &Arena<Expression>) -> Self {
794        let mut tracker = Self {
795            inner: HandleVec::with_capacity(arena.len()),
796        };
797        for (handle, expr) in arena.iter() {
798            tracker
799                .inner
800                .insert(handle, tracker.type_of_with_expr(expr));
801        }
802        tracker
803    }
804
805    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806        match *expr {
807            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808                ExpressionKind::Const
809            }
810            Expression::Override(_) => ExpressionKind::Override,
811            Expression::Compose { ref components, .. } => {
812                let mut expr_type = ExpressionKind::Const;
813                for component in components {
814                    expr_type = expr_type.max(self.type_of(*component))
815                }
816                expr_type
817            }
818            Expression::Splat { value, .. } => self.type_of(value),
819            Expression::AccessIndex { base, .. } => self.type_of(base),
820            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821            Expression::Swizzle { vector, .. } => self.type_of(vector),
822            Expression::Unary { expr, .. } => self.type_of(expr),
823            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824            Expression::Math {
825                arg,
826                arg1,
827                arg2,
828                arg3,
829                ..
830            } => self
831                .type_of(arg)
832                .max(
833                    arg1.map(|arg| self.type_of(arg))
834                        .unwrap_or(ExpressionKind::Const),
835                )
836                .max(
837                    arg2.map(|arg| self.type_of(arg))
838                        .unwrap_or(ExpressionKind::Const),
839                )
840                .max(
841                    arg3.map(|arg| self.type_of(arg))
842                        .unwrap_or(ExpressionKind::Const),
843                ),
844            Expression::As { expr, .. } => self.type_of(expr),
845            Expression::Select {
846                condition,
847                accept,
848                reject,
849            } => self
850                .type_of(condition)
851                .max(self.type_of(accept))
852                .max(self.type_of(reject)),
853            Expression::Relational { argument, .. } => self.type_of(argument),
854            Expression::ArrayLength(expr) => self.type_of(expr),
855            _ => ExpressionKind::Runtime,
856        }
857    }
858}
859
860#[derive(Clone, Debug, thiserror::Error)]
861#[cfg_attr(test, derive(PartialEq))]
862pub enum ConstantEvaluatorError {
863    #[error("Constants cannot access function arguments")]
864    FunctionArg,
865    #[error("Constants cannot access global variables")]
866    GlobalVariable,
867    #[error("Constants cannot access local variables")]
868    LocalVariable,
869    #[error("Cannot get the array length of a non array type")]
870    InvalidArrayLengthArg,
871    #[error("Constants cannot get the array length of a dynamically sized array")]
872    ArrayLengthDynamic,
873    #[error("Cannot call arrayLength on array sized by override-expression")]
874    ArrayLengthOverridden,
875    #[error("Constants cannot call functions")]
876    Call,
877    #[error("Constants don't support workGroupUniformLoad")]
878    WorkGroupUniformLoadResult,
879    #[error("Constants don't support atomic functions")]
880    Atomic,
881    #[error("Constants don't support derivative functions")]
882    Derivative,
883    #[error("Constants don't support load expressions")]
884    Load,
885    #[error("Constants don't support image expressions")]
886    ImageExpression,
887    #[error("Constants don't support ray query expressions")]
888    RayQueryExpression,
889    #[error("Constants don't support subgroup expressions")]
890    SubgroupExpression,
891    #[error("Cannot access the type")]
892    InvalidAccessBase,
893    #[error("Cannot access at the index")]
894    InvalidAccessIndex,
895    #[error("Cannot access with index of type")]
896    InvalidAccessIndexTy,
897    #[error("Constants don't support array length expressions")]
898    ArrayLength,
899    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
900    InvalidCastArg { from: String, to: String },
901    #[error("Cannot apply the unary op to the argument")]
902    InvalidUnaryOpArg,
903    #[error("Cannot apply the binary op to the arguments")]
904    InvalidBinaryOpArgs,
905    #[error("Cannot apply math function to type")]
906    InvalidMathArg,
907    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
908    InvalidMathArgCount(crate::MathFunction, usize, usize),
909    #[error("Cannot apply relational function to type")]
910    InvalidRelationalArg(RelationalFunction),
911    #[error("value of `low` is greater than `high` for clamp built-in function")]
912    InvalidClamp,
913    #[error("Constructor expects {expected} components, found {actual}")]
914    InvalidVectorComposeLength { expected: usize, actual: usize },
915    #[error("Constructor must only contain vector or scalar arguments")]
916    InvalidVectorComposeComponent,
917    #[error("Splat is defined only on scalar values")]
918    SplatScalarOnly,
919    #[error("Can only swizzle vector constants")]
920    SwizzleVectorOnly,
921    #[error("swizzle component not present in source expression")]
922    SwizzleOutOfBounds,
923    #[error("Type is not constructible")]
924    TypeNotConstructible,
925    #[error("Subexpression(s) are not constant")]
926    SubexpressionsAreNotConstant,
927    #[error("Not implemented as constant expression: {0}")]
928    NotImplemented(String),
929    #[error("{0} operation overflowed")]
930    Overflow(String),
931    #[error(
932        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
933    )]
934    AutomaticConversionLossy {
935        value: String,
936        to_type: &'static str,
937    },
938    #[error("Division by zero")]
939    DivisionByZero,
940    #[error("Remainder by zero")]
941    RemainderByZero,
942    #[error("RHS of shift operation is greater than or equal to 32")]
943    ShiftedMoreThan32Bits,
944    #[error(transparent)]
945    Literal(#[from] crate::valid::LiteralError),
946    #[error("Can't use pipeline-overridable constants in const-expressions")]
947    Override,
948    #[error("Unexpected runtime-expression")]
949    RuntimeExpr,
950    #[error("Unexpected override-expression")]
951    OverrideExpr,
952    #[error("Expected boolean expression for condition argument of `select`, got something else")]
953    SelectScalarConditionNotABool,
954    #[error(
955        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
956        reject,
957        accept
958    )]
959    SelectVecRejectAcceptSizeMismatch {
960        reject: crate::VectorSize,
961        accept: crate::VectorSize,
962    },
963    #[error("Expected boolean vector for condition arg., got something else")]
964    SelectConditionNotAVecBool,
965    #[error(
966        "Expected same number of vector components between condition, accept, and reject args., got something else",
967    )]
968    SelectConditionVecSizeMismatch,
969    #[error(
970        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
971    )]
972    SelectAcceptRejectTypeMismatch,
973    #[error("Cooperative operations can't be constant")]
974    CooperativeOperation,
975}
976
977impl<'a> ConstantEvaluator<'a> {
978    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
979    /// constant expression arena.
980    ///
981    /// Report errors according to WGSL's rules for constant evaluation.
982    pub fn for_wgsl_module(
983        module: &'a mut crate::Module,
984        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
985        layouter: &'a mut crate::proc::Layouter,
986        in_override_ctx: bool,
987    ) -> Self {
988        Self::for_module(
989            Behavior::Wgsl(if in_override_ctx {
990                WgslRestrictions::Override
991            } else {
992                WgslRestrictions::Const(None)
993            }),
994            module,
995            global_expression_kind_tracker,
996            layouter,
997        )
998    }
999
1000    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
1001    /// constant expression arena.
1002    ///
1003    /// Report errors according to GLSL's rules for constant evaluation.
1004    pub fn for_glsl_module(
1005        module: &'a mut crate::Module,
1006        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1007        layouter: &'a mut crate::proc::Layouter,
1008    ) -> Self {
1009        Self::for_module(
1010            Behavior::Glsl(GlslRestrictions::Const),
1011            module,
1012            global_expression_kind_tracker,
1013            layouter,
1014        )
1015    }
1016
1017    fn for_module(
1018        behavior: Behavior<'a>,
1019        module: &'a mut crate::Module,
1020        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1021        layouter: &'a mut crate::proc::Layouter,
1022    ) -> Self {
1023        Self {
1024            behavior,
1025            types: &mut module.types,
1026            constants: &module.constants,
1027            overrides: &module.overrides,
1028            expressions: &mut module.global_expressions,
1029            expression_kind_tracker: global_expression_kind_tracker,
1030            layouter,
1031        }
1032    }
1033
1034    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1035    /// expression arena.
1036    ///
1037    /// Report errors according to WGSL's rules for constant evaluation.
1038    pub fn for_wgsl_function(
1039        module: &'a mut crate::Module,
1040        expressions: &'a mut Arena<Expression>,
1041        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1042        layouter: &'a mut crate::proc::Layouter,
1043        emitter: &'a mut super::Emitter,
1044        block: &'a mut crate::Block,
1045        is_const: bool,
1046    ) -> Self {
1047        let local_data = FunctionLocalData {
1048            global_expressions: &module.global_expressions,
1049            emitter,
1050            block,
1051        };
1052        Self {
1053            behavior: Behavior::Wgsl(if is_const {
1054                WgslRestrictions::Const(Some(local_data))
1055            } else {
1056                WgslRestrictions::Runtime(local_data)
1057            }),
1058            types: &mut module.types,
1059            constants: &module.constants,
1060            overrides: &module.overrides,
1061            expressions,
1062            expression_kind_tracker: local_expression_kind_tracker,
1063            layouter,
1064        }
1065    }
1066
1067    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1068    /// expression arena.
1069    ///
1070    /// Report errors according to GLSL's rules for constant evaluation.
1071    pub fn for_glsl_function(
1072        module: &'a mut crate::Module,
1073        expressions: &'a mut Arena<Expression>,
1074        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1075        layouter: &'a mut crate::proc::Layouter,
1076        emitter: &'a mut super::Emitter,
1077        block: &'a mut crate::Block,
1078    ) -> Self {
1079        Self {
1080            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1081                global_expressions: &module.global_expressions,
1082                emitter,
1083                block,
1084            })),
1085            types: &mut module.types,
1086            constants: &module.constants,
1087            overrides: &module.overrides,
1088            expressions,
1089            expression_kind_tracker: local_expression_kind_tracker,
1090            layouter,
1091        }
1092    }
1093
1094    pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1095        crate::proc::GlobalCtx {
1096            types: self.types,
1097            constants: self.constants,
1098            overrides: self.overrides,
1099            global_expressions: match self.function_local_data() {
1100                Some(data) => data.global_expressions,
1101                None => self.expressions,
1102            },
1103        }
1104    }
1105
1106    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1107        if !self.expression_kind_tracker.is_const(expr) {
1108            log::debug!("check: SubexpressionsAreNotConstant");
1109            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1110        }
1111        Ok(())
1112    }
1113
1114    fn check_and_get(
1115        &mut self,
1116        expr: Handle<Expression>,
1117    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1118        match self.expressions[expr] {
1119            Expression::Constant(c) => {
1120                // Are we working in a function's expression arena, or the
1121                // module's constant expression arena?
1122                if let Some(function_local_data) = self.function_local_data() {
1123                    // Deep-copy the constant's value into our arena.
1124                    self.copy_from(
1125                        self.constants[c].init,
1126                        function_local_data.global_expressions,
1127                    )
1128                } else {
1129                    // "See through" the constant and use its initializer.
1130                    Ok(self.constants[c].init)
1131                }
1132            }
1133            _ => {
1134                self.check(expr)?;
1135                Ok(expr)
1136            }
1137        }
1138    }
1139
1140    /// Try to evaluate `expr` at compile time.
1141    ///
1142    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1143    /// we can determine its value at compile time, we append an expression
1144    /// representing its value - a tree of [`Literal`], [`Compose`],
1145    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1146    /// `self` contributes to.
1147    ///
1148    /// If `expr`'s value cannot be determined at compile time, and `self` is
1149    /// contributing to some function's expression arena, then append `expr` to
1150    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1151    /// contributing to the module's constant expression arena; since `expr`'s
1152    /// value is not a constant, return an error.
1153    ///
1154    /// We only consider `expr` itself, without recursing into its operands. Its
1155    /// operands must all have been produced by prior calls to
1156    /// `try_eval_and_append`, to ensure that they have already been reduced to
1157    /// an evaluated form if possible.
1158    ///
1159    /// [`Literal`]: Expression::Literal
1160    /// [`Compose`]: Expression::Compose
1161    /// [`ZeroValue`]: Expression::ZeroValue
1162    /// [`Swizzle`]: Expression::Swizzle
1163    pub fn try_eval_and_append(
1164        &mut self,
1165        expr: Expression,
1166        span: Span,
1167    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1168        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1169            ExpressionKind::Const => {
1170                let eval_result = self.try_eval_and_append_impl(&expr, span);
1171                // We should be able to evaluate `Const` expressions at this
1172                // point. If we failed to, then that probably means we just
1173                // haven't implemented that part of constant evaluation. Work
1174                // around this by simply emitting it as a run-time expression.
1175                if self.behavior.has_runtime_restrictions()
1176                    && matches!(
1177                        eval_result,
1178                        Err(ConstantEvaluatorError::NotImplemented(_)
1179                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1180                    )
1181                {
1182                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1183                } else {
1184                    eval_result
1185                }
1186            }
1187            ExpressionKind::Override => match self.behavior {
1188                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1189                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1190                }
1191                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1192                    Err(ConstantEvaluatorError::OverrideExpr)
1193                }
1194
1195                // GLSL specialization constants (constant_id) become Override expressions
1196                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1197                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1198                }
1199                Behavior::Glsl(GlslRestrictions::Const) => {
1200                    Err(ConstantEvaluatorError::OverrideExpr)
1201                }
1202            },
1203            ExpressionKind::Runtime => {
1204                if self.behavior.has_runtime_restrictions() {
1205                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1206                } else {
1207                    Err(ConstantEvaluatorError::RuntimeExpr)
1208                }
1209            }
1210        }
1211    }
1212
1213    /// Is the [`Self::expressions`] arena the global module expression arena?
1214    const fn is_global_arena(&self) -> bool {
1215        matches!(
1216            self.behavior,
1217            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1218                | Behavior::Glsl(GlslRestrictions::Const)
1219        )
1220    }
1221
1222    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1223        match self.behavior {
1224            Behavior::Wgsl(
1225                WgslRestrictions::Runtime(ref function_local_data)
1226                | WgslRestrictions::Const(Some(ref function_local_data)),
1227            )
1228            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1229                Some(function_local_data)
1230            }
1231            _ => None,
1232        }
1233    }
1234
1235    fn try_eval_and_append_impl(
1236        &mut self,
1237        expr: &Expression,
1238        span: Span,
1239    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1240        log::trace!("try_eval_and_append: {expr:?}");
1241        match *expr {
1242            Expression::Constant(c) if self.is_global_arena() => {
1243                // "See through" the constant and use its initializer.
1244                // This is mainly done to avoid having constants pointing to other constants.
1245                Ok(self.constants[c].init)
1246            }
1247            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1248            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1249                self.register_evaluated_expr(expr.clone(), span)
1250            }
1251            Expression::Compose { ty, ref components } => {
1252                let components = components
1253                    .iter()
1254                    .map(|component| self.check_and_get(*component))
1255                    .collect::<Result<Vec<_>, _>>()?;
1256                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1257            }
1258            Expression::Splat { size, value } => {
1259                let value = self.check_and_get(value)?;
1260                self.register_evaluated_expr(Expression::Splat { size, value }, span)
1261            }
1262            Expression::AccessIndex { base, index } => {
1263                let base = self.check_and_get(base)?;
1264
1265                self.access(base, index as usize, span)
1266            }
1267            Expression::Access { base, index } => {
1268                let base = self.check_and_get(base)?;
1269                let index = self.check_and_get(index)?;
1270
1271                let index_val: u32 = self
1272                    .to_ctx()
1273                    .get_const_val_from(index, self.expressions)
1274                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1275                self.access(base, index_val as usize, span)
1276            }
1277            Expression::Swizzle {
1278                size,
1279                vector,
1280                pattern,
1281            } => {
1282                let vector = self.check_and_get(vector)?;
1283
1284                self.swizzle(size, span, vector, pattern)
1285            }
1286            Expression::Unary { expr, op } => {
1287                let expr = self.check_and_get(expr)?;
1288
1289                self.unary_op(op, expr, span)
1290            }
1291            Expression::Binary { left, right, op } => {
1292                let left = self.check_and_get(left)?;
1293                let right = self.check_and_get(right)?;
1294
1295                self.binary_op(op, left, right, span)
1296            }
1297            Expression::Math {
1298                fun,
1299                arg,
1300                arg1,
1301                arg2,
1302                arg3,
1303            } => {
1304                let arg = self.check_and_get(arg)?;
1305                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1306                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1307                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1308
1309                self.math(arg, arg1, arg2, arg3, fun, span)
1310            }
1311            Expression::As {
1312                convert,
1313                expr,
1314                kind,
1315            } => {
1316                let expr = self.check_and_get(expr)?;
1317
1318                match convert {
1319                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1320                    None => Err(ConstantEvaluatorError::NotImplemented(
1321                        "bitcast built-in function".into(),
1322                    )),
1323                }
1324            }
1325            Expression::Select {
1326                reject,
1327                accept,
1328                condition,
1329            } => {
1330                let mut arg = |expr| self.check_and_get(expr);
1331
1332                let reject = arg(reject)?;
1333                let accept = arg(accept)?;
1334                let condition = arg(condition)?;
1335
1336                self.select(reject, accept, condition, span)
1337            }
1338            Expression::Relational { fun, argument } => {
1339                let argument = self.check_and_get(argument)?;
1340                self.relational(fun, argument, span)
1341            }
1342            Expression::ArrayLength(expr) => match self.behavior {
1343                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1344                Behavior::Glsl(_) => {
1345                    let expr = self.check_and_get(expr)?;
1346                    self.array_length(expr, span)
1347                }
1348            },
1349            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1350            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1351            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1352            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1353            Expression::WorkGroupUniformLoadResult { .. } => {
1354                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1355            }
1356            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1357            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1358            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1359            Expression::ImageSample { .. }
1360            | Expression::ImageLoad { .. }
1361            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1362            Expression::RayQueryProceedResult
1363            | Expression::RayQueryGetIntersection { .. }
1364            | Expression::RayQueryVertexPositions { .. } => {
1365                Err(ConstantEvaluatorError::RayQueryExpression)
1366            }
1367            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1368            Expression::SubgroupOperationResult { .. } => {
1369                Err(ConstantEvaluatorError::SubgroupExpression)
1370            }
1371            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1372                Err(ConstantEvaluatorError::CooperativeOperation)
1373            }
1374        }
1375    }
1376
1377    /// Splat `value` to `size`, without using [`Splat`] expressions.
1378    ///
1379    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1380    /// build a vector with the given `size` whose components are all
1381    /// `value`.
1382    ///
1383    /// Use `span` as the span of the inserted expressions and
1384    /// resulting types.
1385    ///
1386    /// [`Splat`]: Expression::Splat
1387    /// [`Compose`]: Expression::Compose
1388    /// [`ZeroValue`]: Expression::ZeroValue
1389    fn splat(
1390        &mut self,
1391        value: Handle<Expression>,
1392        size: crate::VectorSize,
1393        span: Span,
1394    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1395        match self.expressions[value] {
1396            Expression::Literal(literal) => {
1397                let scalar = literal.scalar();
1398                let ty = self.types.insert(
1399                    Type {
1400                        name: None,
1401                        inner: TypeInner::Vector { size, scalar },
1402                    },
1403                    span,
1404                );
1405                let expr = Expression::Compose {
1406                    ty,
1407                    components: vec![value; size as usize],
1408                };
1409                self.register_evaluated_expr(expr, span)
1410            }
1411            Expression::ZeroValue(ty) => {
1412                let inner = match self.types[ty].inner {
1413                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1414                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1415                };
1416                let res_ty = self.types.insert(Type { name: None, inner }, span);
1417                let expr = Expression::ZeroValue(res_ty);
1418                self.register_evaluated_expr(expr, span)
1419            }
1420            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1421        }
1422    }
1423
1424    fn swizzle(
1425        &mut self,
1426        size: crate::VectorSize,
1427        span: Span,
1428        src_constant: Handle<Expression>,
1429        pattern: [crate::SwizzleComponent; 4],
1430    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1431        let mut get_dst_ty = |ty| match self.types[ty].inner {
1432            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1433                Type {
1434                    name: None,
1435                    inner: TypeInner::Vector { size, scalar },
1436                },
1437                span,
1438            )),
1439            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1440        };
1441
1442        match self.expressions[src_constant] {
1443            Expression::ZeroValue(ty) => {
1444                let dst_ty = get_dst_ty(ty)?;
1445                let expr = Expression::ZeroValue(dst_ty);
1446                self.register_evaluated_expr(expr, span)
1447            }
1448            Expression::Splat { value, .. } => {
1449                let expr = Expression::Splat { size, value };
1450                self.register_evaluated_expr(expr, span)
1451            }
1452            Expression::Compose { ty, ref components } => {
1453                let dst_ty = get_dst_ty(ty)?;
1454
1455                let mut flattened = [src_constant; 4]; // dummy value
1456                let len =
1457                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1458                        .zip(flattened.iter_mut())
1459                        .map(|(component, elt)| *elt = component)
1460                        .count();
1461                let flattened = &flattened[..len];
1462
1463                let swizzled_components = pattern[..size as usize]
1464                    .iter()
1465                    .map(|&sc| {
1466                        let sc = sc as usize;
1467                        if let Some(elt) = flattened.get(sc) {
1468                            Ok(*elt)
1469                        } else {
1470                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1471                        }
1472                    })
1473                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1474                let expr = Expression::Compose {
1475                    ty: dst_ty,
1476                    components: swizzled_components,
1477                };
1478                self.register_evaluated_expr(expr, span)
1479            }
1480            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1481        }
1482    }
1483
1484    fn math(
1485        &mut self,
1486        arg: Handle<Expression>,
1487        arg1: Option<Handle<Expression>>,
1488        arg2: Option<Handle<Expression>>,
1489        arg3: Option<Handle<Expression>>,
1490        fun: crate::MathFunction,
1491        span: Span,
1492    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1493        let expected = fun.argument_count();
1494        let given = Some(arg)
1495            .into_iter()
1496            .chain(arg1)
1497            .chain(arg2)
1498            .chain(arg3)
1499            .count();
1500        if expected != given {
1501            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1502                fun, expected, given,
1503            ));
1504        }
1505
1506        // NOTE: We try to match the declaration order of `MathFunction` here.
1507        match fun {
1508            // comparison
1509            crate::MathFunction::Abs => {
1510                component_wise_scalar(self, span, [arg], |args| match args {
1511                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1512                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1513                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1514                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1515                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1516                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1517                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1518                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1519                })
1520            }
1521            crate::MathFunction::Min => {
1522                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1523                    Ok([e1.min(e2)])
1524                })
1525            }
1526            crate::MathFunction::Max => {
1527                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1528                    Ok([e1.max(e2)])
1529                })
1530            }
1531            crate::MathFunction::Clamp => {
1532                component_wise_scalar!(
1533                    self,
1534                    span,
1535                    [arg, arg1.unwrap(), arg2.unwrap()],
1536                    |e, low, high| {
1537                        if low > high {
1538                            Err(ConstantEvaluatorError::InvalidClamp)
1539                        } else {
1540                            Ok([e.clamp(low, high)])
1541                        }
1542                    }
1543                )
1544            }
1545            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1546                Float::F16([e]) => Ok(Float::F16(
1547                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1548                )),
1549                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1550                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1551            }),
1552
1553            // trigonometry
1554            crate::MathFunction::Cos => {
1555                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1556            }
1557            crate::MathFunction::Cosh => {
1558                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1559            }
1560            crate::MathFunction::Sin => {
1561                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1562            }
1563            crate::MathFunction::Sinh => {
1564                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1565            }
1566            crate::MathFunction::Tan => {
1567                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1568            }
1569            crate::MathFunction::Tanh => {
1570                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1571            }
1572            crate::MathFunction::Acos => {
1573                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1574            }
1575            crate::MathFunction::Asin => {
1576                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1577            }
1578            crate::MathFunction::Atan => {
1579                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1580            }
1581            crate::MathFunction::Atan2 => {
1582                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1583                    Ok([y.atan2(x)])
1584                })
1585            }
1586            crate::MathFunction::Asinh => {
1587                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1588            }
1589            crate::MathFunction::Acosh => {
1590                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1591            }
1592            crate::MathFunction::Atanh => {
1593                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1594            }
1595            crate::MathFunction::Radians => {
1596                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1597            }
1598            crate::MathFunction::Degrees => {
1599                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1600            }
1601
1602            // decomposition
1603            crate::MathFunction::Ceil => {
1604                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1605            }
1606            crate::MathFunction::Floor => {
1607                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1608            }
1609            crate::MathFunction::Round => {
1610                component_wise_float(self, span, [arg], |e| match e {
1611                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1612                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1613                    Float::F16([e]) => {
1614                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1615                        //
1616                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1617                        // which has licensing compatible with ours. See also
1618                        // <https://github.com/rust-lang/rust/issues/96710>.
1619                        //
1620                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1621                        fn round_ties_even(x: f64) -> f64 {
1622                            let i = x as i64;
1623                            let f = (x - i as f64).abs();
1624                            if f == 0.5 {
1625                                if i & 1 == 1 {
1626                                    // -1.5, 1.5, 3.5, ...
1627                                    (x.abs() + 0.5).copysign(x)
1628                                } else {
1629                                    (x.abs() - 0.5).copysign(x)
1630                                }
1631                            } else {
1632                                x.round()
1633                            }
1634                        }
1635
1636                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1637                    }
1638                })
1639            }
1640            crate::MathFunction::Fract => {
1641                component_wise_float!(self, span, [arg], |e| {
1642                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1643                    // here.
1644                    Ok([e - e.floor()])
1645                })
1646            }
1647            crate::MathFunction::Trunc => {
1648                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1649            }
1650
1651            // exponent
1652            crate::MathFunction::Exp => {
1653                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1654            }
1655            crate::MathFunction::Exp2 => {
1656                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1657            }
1658            crate::MathFunction::Log => {
1659                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1660            }
1661            crate::MathFunction::Log2 => {
1662                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1663            }
1664            crate::MathFunction::Pow => {
1665                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1666                    Ok([e1.powf(e2)])
1667                })
1668            }
1669
1670            // computational
1671            crate::MathFunction::Sign => {
1672                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1673            }
1674            crate::MathFunction::Fma => {
1675                component_wise_float!(
1676                    self,
1677                    span,
1678                    [arg, arg1.unwrap(), arg2.unwrap()],
1679                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1680                )
1681            }
1682            crate::MathFunction::Step => {
1683                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1684                    Float::Abstract([edge, x]) => {
1685                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1686                    }
1687                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1688                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1689                        f16::one()
1690                    } else {
1691                        f16::zero()
1692                    }])),
1693                })
1694            }
1695            crate::MathFunction::Sqrt => {
1696                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1697            }
1698            crate::MathFunction::InverseSqrt => {
1699                component_wise_float(self, span, [arg], |e| match e {
1700                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1701                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1702                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1703                })
1704            }
1705
1706            // bits
1707            crate::MathFunction::CountTrailingZeros => {
1708                component_wise_concrete_int!(self, span, [arg], |e| {
1709                    #[allow(clippy::useless_conversion)]
1710                    Ok([e
1711                        .trailing_zeros()
1712                        .try_into()
1713                        .expect("bit count overflowed 32 bits, somehow!?")])
1714                })
1715            }
1716            crate::MathFunction::CountLeadingZeros => {
1717                component_wise_concrete_int!(self, span, [arg], |e| {
1718                    #[allow(clippy::useless_conversion)]
1719                    Ok([e
1720                        .leading_zeros()
1721                        .try_into()
1722                        .expect("bit count overflowed 32 bits, somehow!?")])
1723                })
1724            }
1725            crate::MathFunction::CountOneBits => {
1726                component_wise_concrete_int!(self, span, [arg], |e| {
1727                    #[allow(clippy::useless_conversion)]
1728                    Ok([e
1729                        .count_ones()
1730                        .try_into()
1731                        .expect("bit count overflowed 32 bits, somehow!?")])
1732                })
1733            }
1734            crate::MathFunction::ReverseBits => {
1735                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1736            }
1737            crate::MathFunction::FirstTrailingBit => {
1738                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1739            }
1740            crate::MathFunction::FirstLeadingBit => {
1741                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1742            }
1743
1744            // vector
1745            crate::MathFunction::Dot4I8Packed => {
1746                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1747            }
1748            crate::MathFunction::Dot4U8Packed => {
1749                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1750            }
1751            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1752            crate::MathFunction::Dot => {
1753                // https://www.w3.org/TR/WGSL/#dot-builtin
1754                let e1 = self.extract_vec(arg, false)?;
1755                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1756                if e1.len() != e2.len() {
1757                    return Err(ConstantEvaluatorError::InvalidMathArg);
1758                }
1759
1760                fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1761                where
1762                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1763                {
1764                    a.iter()
1765                        .zip(b.iter())
1766                        .map(|(&aa, bb)| aa.checked_mul(bb))
1767                        .try_fold(P::zero(), |acc, x| {
1768                            if let Some(x) = x {
1769                                acc.checked_add(&x)
1770                            } else {
1771                                None
1772                            }
1773                        })
1774                        .ok_or(ConstantEvaluatorError::Overflow(
1775                            "in dot built-in".to_string(),
1776                        ))
1777                }
1778
1779                let result = match_literal_vector!(match (e1, e2) => Literal {
1780                    Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1781                    Integer => |e1, e2| { int_dot(e1, e2)? },
1782                })?;
1783                self.register_evaluated_expr(Expression::Literal(result), span)
1784            }
1785            crate::MathFunction::Length => {
1786                // https://www.w3.org/TR/WGSL/#length-builtin
1787                let e1 = self.extract_vec(arg, true)?;
1788
1789                fn float_length<F>(e: &[F]) -> F
1790                where
1791                    F: core::ops::Mul<F>,
1792                    F: num_traits::Float + iter::Sum,
1793                {
1794                    e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1795                }
1796
1797                let result = match_literal_vector!(match e1 => Literal {
1798                    Float => |e1| { float_length(e1) },
1799                })?;
1800                self.register_evaluated_expr(Expression::Literal(result), span)
1801            }
1802            crate::MathFunction::Distance => {
1803                // https://www.w3.org/TR/WGSL/#distance-builtin
1804                let e1 = self.extract_vec(arg, true)?;
1805                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1806                if e1.len() != e2.len() {
1807                    return Err(ConstantEvaluatorError::InvalidMathArg);
1808                }
1809
1810                fn float_distance<F>(a: &[F], b: &[F]) -> F
1811                where
1812                    F: core::ops::Mul<F>,
1813                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1814                {
1815                    a.iter()
1816                        .zip(b.iter())
1817                        .map(|(&aa, &bb)| aa - bb)
1818                        .map(|ei| ei * ei)
1819                        .sum::<F>()
1820                        .sqrt()
1821                }
1822                let result = match_literal_vector!(match (e1, e2) => Literal {
1823                    Float => |e1, e2| { float_distance(e1, e2) },
1824                })?;
1825                self.register_evaluated_expr(Expression::Literal(result), span)
1826            }
1827            crate::MathFunction::Normalize => {
1828                // https://www.w3.org/TR/WGSL/#normalize-builtin
1829                let e1 = self.extract_vec(arg, true)?;
1830
1831                fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1832                where
1833                    F: core::ops::Mul<F>,
1834                    F: num_traits::Float + iter::Sum,
1835                {
1836                    let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1837                    e.iter().map(|&ei| ei / len).collect()
1838                }
1839
1840                let result = match_literal_vector!(match e1 => LiteralVector {
1841                    Float => |e1| { float_normalize(e1) },
1842                })?;
1843                result.register_as_evaluated_expr(self, span)
1844            }
1845
1846            // unimplemented
1847            crate::MathFunction::Modf
1848            | crate::MathFunction::Frexp
1849            | crate::MathFunction::Ldexp
1850            | crate::MathFunction::Outer
1851            | crate::MathFunction::FaceForward
1852            | crate::MathFunction::Reflect
1853            | crate::MathFunction::Refract
1854            | crate::MathFunction::Mix
1855            | crate::MathFunction::SmoothStep
1856            | crate::MathFunction::Inverse
1857            | crate::MathFunction::Transpose
1858            | crate::MathFunction::Determinant
1859            | crate::MathFunction::QuantizeToF16
1860            | crate::MathFunction::ExtractBits
1861            | crate::MathFunction::InsertBits
1862            | crate::MathFunction::Pack4x8snorm
1863            | crate::MathFunction::Pack4x8unorm
1864            | crate::MathFunction::Pack2x16snorm
1865            | crate::MathFunction::Pack2x16unorm
1866            | crate::MathFunction::Pack2x16float
1867            | crate::MathFunction::Pack4xI8
1868            | crate::MathFunction::Pack4xU8
1869            | crate::MathFunction::Pack4xI8Clamp
1870            | crate::MathFunction::Pack4xU8Clamp
1871            | crate::MathFunction::Unpack4x8snorm
1872            | crate::MathFunction::Unpack4x8unorm
1873            | crate::MathFunction::Unpack2x16snorm
1874            | crate::MathFunction::Unpack2x16unorm
1875            | crate::MathFunction::Unpack2x16float
1876            | crate::MathFunction::Unpack4xI8
1877            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1878                format!("{fun:?} built-in function"),
1879            )),
1880        }
1881    }
1882
1883    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1884    fn packed_dot_product(
1885        &mut self,
1886        a: Handle<Expression>,
1887        b: Handle<Expression>,
1888        span: Span,
1889        signed: bool,
1890    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1891        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1892            return Err(ConstantEvaluatorError::InvalidMathArg);
1893        };
1894        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1895            return Err(ConstantEvaluatorError::InvalidMathArg);
1896        };
1897
1898        let result = if signed {
1899            Literal::I32(
1900                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1901                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1902                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1903                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1904            )
1905        } else {
1906            Literal::U32(
1907                (a & 0xFF) * (b & 0xFF)
1908                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1909                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1910                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1911            )
1912        };
1913
1914        self.register_evaluated_expr(Expression::Literal(result), span)
1915    }
1916
1917    /// Vector cross product.
1918    fn cross_product(
1919        &mut self,
1920        a: Handle<Expression>,
1921        b: Handle<Expression>,
1922        span: Span,
1923    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1924        use Literal as Li;
1925
1926        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1927        let (b, _) = self.extract_vec_with_size::<3>(b)?;
1928
1929        let product = match (a, b) {
1930            (
1931                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1932                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1933            ) => {
1934                // `cross` has no overload for AbstractInt, so AbstractInt
1935                // arguments are automatically converted to AbstractFloat. Since
1936                // `f64` has a much wider range than `i64`, there's no danger of
1937                // overflow here.
1938                let p = cross_product(
1939                    [a0 as f64, a1 as f64, a2 as f64],
1940                    [b0 as f64, b1 as f64, b2 as f64],
1941                );
1942                [
1943                    Li::AbstractFloat(p[0]),
1944                    Li::AbstractFloat(p[1]),
1945                    Li::AbstractFloat(p[2]),
1946                ]
1947            }
1948            (
1949                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1950                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1951            ) => {
1952                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1953                [
1954                    Li::AbstractFloat(p[0]),
1955                    Li::AbstractFloat(p[1]),
1956                    Li::AbstractFloat(p[2]),
1957                ]
1958            }
1959            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1960                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1961                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1962            }
1963            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1964                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1965                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1966            }
1967            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1968                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1969                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1970            }
1971            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1972        };
1973
1974        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1975        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1976        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1977
1978        self.register_evaluated_expr(
1979            Expression::Compose {
1980                ty,
1981                components: vec![p0, p1, p2],
1982            },
1983            span,
1984        )
1985    }
1986
1987    /// Extract the values of a `vecN` from `expr`.
1988    ///
1989    /// Return the value of `expr`, whose type is `vecN<S>` for some
1990    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1991    /// values.
1992    ///
1993    /// Also return the type handle from the `Compose` expression.
1994    fn extract_vec_with_size<const N: usize>(
1995        &mut self,
1996        expr: Handle<Expression>,
1997    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1998        let span = self.expressions.get_span(expr);
1999        let expr = self.eval_zero_value_and_splat(expr, span)?;
2000        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2001            return Err(ConstantEvaluatorError::InvalidMathArg);
2002        };
2003
2004        let mut value = [Literal::Bool(false); N];
2005        for (component, elt) in
2006            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2007                .zip(value.iter_mut())
2008        {
2009            let Expression::Literal(literal) = self.expressions[component] else {
2010                return Err(ConstantEvaluatorError::InvalidMathArg);
2011            };
2012            *elt = literal;
2013        }
2014
2015        Ok((value, ty))
2016    }
2017
2018    /// Extract the values of a `vecN` from `expr`.
2019    ///
2020    /// Return the value of `expr`, whose type is `vecN<S>` for some
2021    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2022    /// values.
2023    ///
2024    /// Also return the type handle from the `Compose` expression.
2025    fn extract_vec(
2026        &mut self,
2027        expr: Handle<Expression>,
2028        allow_single: bool,
2029    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2030        let span = self.expressions.get_span(expr);
2031        let expr = self.eval_zero_value_and_splat(expr, span)?;
2032
2033        match self.expressions[expr] {
2034            Expression::Literal(literal) if allow_single => {
2035                Ok(LiteralVector::from_literal(literal))
2036            }
2037            Expression::Compose { ty, ref components } => {
2038                let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2039                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2040                        .map(|expr| match self.expressions[expr] {
2041                            Expression::Literal(l) => Ok(l),
2042                            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2043                        })
2044                        .collect::<Result<_, ConstantEvaluatorError>>()?;
2045                LiteralVector::from_literal_vec(components)
2046            }
2047            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2048        }
2049    }
2050
2051    fn array_length(
2052        &mut self,
2053        array: Handle<Expression>,
2054        span: Span,
2055    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2056        match self.expressions[array] {
2057            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2058                match self.types[ty].inner {
2059                    TypeInner::Array { size, .. } => match size {
2060                        ArraySize::Constant(len) => {
2061                            let expr = Expression::Literal(Literal::U32(len.get()));
2062                            self.register_evaluated_expr(expr, span)
2063                        }
2064                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2065                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2066                    },
2067                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2068                }
2069            }
2070            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2071        }
2072    }
2073
2074    fn access(
2075        &mut self,
2076        base: Handle<Expression>,
2077        index: usize,
2078        span: Span,
2079    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2080        match self.expressions[base] {
2081            Expression::ZeroValue(ty) => {
2082                let ty_inner = &self.types[ty].inner;
2083                let components = ty_inner
2084                    .components()
2085                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2086
2087                if index >= components as usize {
2088                    Err(ConstantEvaluatorError::InvalidAccessBase)
2089                } else {
2090                    let ty_res = ty_inner
2091                        .component_type(index)
2092                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2093                    let ty = match ty_res {
2094                        crate::proc::TypeResolution::Handle(ty) => ty,
2095                        crate::proc::TypeResolution::Value(inner) => {
2096                            self.types.insert(Type { name: None, inner }, span)
2097                        }
2098                    };
2099                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2100                }
2101            }
2102            Expression::Splat { size, value } => {
2103                if index >= size as usize {
2104                    Err(ConstantEvaluatorError::InvalidAccessBase)
2105                } else {
2106                    Ok(value)
2107                }
2108            }
2109            Expression::Compose { ty, ref components } => {
2110                let _ = self.types[ty]
2111                    .inner
2112                    .components()
2113                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2114
2115                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2116                    .nth(index)
2117                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2118            }
2119            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2120        }
2121    }
2122
2123    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2124    ///
2125    /// [`ZeroValue`]: Expression::ZeroValue
2126    /// [`Splat`]: Expression::Splat
2127    /// [`Literal`]: Expression::Literal
2128    /// [`Compose`]: Expression::Compose
2129    fn eval_zero_value_and_splat(
2130        &mut self,
2131        mut expr: Handle<Expression>,
2132        span: Span,
2133    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2134        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2135        // each of its components.
2136        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2137            let components = components
2138                .clone()
2139                .iter()
2140                .map(|component| self.eval_zero_value_and_splat(*component, span))
2141                .collect::<Result<_, _>>()?;
2142            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2143        }
2144
2145        // The result of the splat() for a Splat of a scalar ZeroValue is a
2146        // vector ZeroValue, so we must call eval_zero_value_impl() after
2147        // splat() in order to ensure we have no ZeroValues remaining.
2148        if let Expression::Splat { size, value } = self.expressions[expr] {
2149            expr = self.splat(value, size, span)?;
2150        }
2151        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2152            expr = self.eval_zero_value_impl(ty, span)?;
2153        }
2154        Ok(expr)
2155    }
2156
2157    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2158    ///
2159    /// [`ZeroValue`]: Expression::ZeroValue
2160    /// [`Literal`]: Expression::Literal
2161    /// [`Compose`]: Expression::Compose
2162    fn eval_zero_value(
2163        &mut self,
2164        expr: Handle<Expression>,
2165        span: Span,
2166    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2167        match self.expressions[expr] {
2168            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2169            _ => Ok(expr),
2170        }
2171    }
2172
2173    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2174    ///
2175    /// [`ZeroValue`]: Expression::ZeroValue
2176    /// [`Literal`]: Expression::Literal
2177    /// [`Compose`]: Expression::Compose
2178    fn eval_zero_value_impl(
2179        &mut self,
2180        ty: Handle<Type>,
2181        span: Span,
2182    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2183        match self.types[ty].inner {
2184            TypeInner::Scalar(scalar) => {
2185                let expr = Expression::Literal(
2186                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2187                );
2188                self.register_evaluated_expr(expr, span)
2189            }
2190            TypeInner::Vector { size, scalar } => {
2191                let scalar_ty = self.types.insert(
2192                    Type {
2193                        name: None,
2194                        inner: TypeInner::Scalar(scalar),
2195                    },
2196                    span,
2197                );
2198                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2199                let expr = Expression::Compose {
2200                    ty,
2201                    components: vec![el; size as usize],
2202                };
2203                self.register_evaluated_expr(expr, span)
2204            }
2205            TypeInner::Matrix {
2206                columns,
2207                rows,
2208                scalar,
2209            } => {
2210                let vec_ty = self.types.insert(
2211                    Type {
2212                        name: None,
2213                        inner: TypeInner::Vector { size: rows, scalar },
2214                    },
2215                    span,
2216                );
2217                let el = self.eval_zero_value_impl(vec_ty, span)?;
2218                let expr = Expression::Compose {
2219                    ty,
2220                    components: vec![el; columns as usize],
2221                };
2222                self.register_evaluated_expr(expr, span)
2223            }
2224            TypeInner::Array {
2225                base,
2226                size: ArraySize::Constant(size),
2227                ..
2228            } => {
2229                let el = self.eval_zero_value_impl(base, span)?;
2230                let expr = Expression::Compose {
2231                    ty,
2232                    components: vec![el; size.get() as usize],
2233                };
2234                self.register_evaluated_expr(expr, span)
2235            }
2236            TypeInner::Struct { ref members, .. } => {
2237                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2238                let mut components = Vec::with_capacity(members.len());
2239                for ty in types {
2240                    components.push(self.eval_zero_value_impl(ty, span)?);
2241                }
2242                let expr = Expression::Compose { ty, components };
2243                self.register_evaluated_expr(expr, span)
2244            }
2245            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2246        }
2247    }
2248
2249    /// Convert the scalar components of `expr` to `target`.
2250    ///
2251    /// Treat `span` as the location of the resulting expression.
2252    pub fn cast(
2253        &mut self,
2254        expr: Handle<Expression>,
2255        target: crate::Scalar,
2256        span: Span,
2257    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2258        use crate::Scalar as Sc;
2259
2260        let expr = self.eval_zero_value(expr, span)?;
2261
2262        let make_error = || -> Result<_, ConstantEvaluatorError> {
2263            let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2264
2265            #[cfg(feature = "wgsl-in")]
2266            let to = target.to_wgsl_for_diagnostics();
2267
2268            #[cfg(not(feature = "wgsl-in"))]
2269            let to = format!("{target:?}");
2270
2271            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2272        };
2273
2274        use crate::proc::type_methods::IntFloatLimits;
2275
2276        let expr = match self.expressions[expr] {
2277            Expression::Literal(literal) => {
2278                let literal = match target {
2279                    Sc::I32 => Literal::I32(match literal {
2280                        Literal::I32(v) => v,
2281                        Literal::U32(v) => v as i32,
2282                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2283                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
2284                        Literal::Bool(v) => v as i32,
2285                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2286                            return make_error();
2287                        }
2288                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2289                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2290                    }),
2291                    Sc::U32 => Literal::U32(match literal {
2292                        Literal::I32(v) => v as u32,
2293                        Literal::U32(v) => v,
2294                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2295                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2296                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2297                        Literal::Bool(v) => v as u32,
2298                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2299                            return make_error();
2300                        }
2301                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2302                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2303                    }),
2304                    Sc::I64 => Literal::I64(match literal {
2305                        Literal::I32(v) => v as i64,
2306                        Literal::U32(v) => v as i64,
2307                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2308                        Literal::Bool(v) => v as i64,
2309                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2310                        Literal::I64(v) => v,
2311                        Literal::U64(v) => v as i64,
2312                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
2313                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2314                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2315                    }),
2316                    Sc::U64 => Literal::U64(match literal {
2317                        Literal::I32(v) => v as u64,
2318                        Literal::U32(v) => v as u64,
2319                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2320                        Literal::Bool(v) => v as u64,
2321                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2322                        Literal::I64(v) => v as u64,
2323                        Literal::U64(v) => v,
2324                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2325                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2326                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2327                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2328                    }),
2329                    Sc::F16 => Literal::F16(match literal {
2330                        Literal::F16(v) => v,
2331                        Literal::F32(v) => f16::from_f32(v),
2332                        Literal::F64(v) => f16::from_f64(v),
2333                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2334                        Literal::I64(v) => f16::from_i64(v).unwrap(),
2335                        Literal::U64(v) => f16::from_u64(v).unwrap(),
2336                        Literal::I32(v) => f16::from_i32(v).unwrap(),
2337                        Literal::U32(v) => f16::from_u32(v).unwrap(),
2338                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2339                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2340                    }),
2341                    Sc::F32 => Literal::F32(match literal {
2342                        Literal::I32(v) => v as f32,
2343                        Literal::U32(v) => v as f32,
2344                        Literal::F32(v) => v,
2345                        Literal::Bool(v) => v as u32 as f32,
2346                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2347                            return make_error();
2348                        }
2349                        Literal::F16(v) => f16::to_f32(v),
2350                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2351                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2352                    }),
2353                    Sc::F64 => Literal::F64(match literal {
2354                        Literal::I32(v) => v as f64,
2355                        Literal::U32(v) => v as f64,
2356                        Literal::F16(v) => f16::to_f64(v),
2357                        Literal::F32(v) => v as f64,
2358                        Literal::F64(v) => v,
2359                        Literal::Bool(v) => v as u32 as f64,
2360                        Literal::I64(_) | Literal::U64(_) => return make_error(),
2361                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2362                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2363                    }),
2364                    Sc::BOOL => Literal::Bool(match literal {
2365                        Literal::I32(v) => v != 0,
2366                        Literal::U32(v) => v != 0,
2367                        Literal::F32(v) => v != 0.0,
2368                        Literal::F16(v) => v != f16::zero(),
2369                        Literal::Bool(v) => v,
2370                        Literal::AbstractInt(v) => v != 0,
2371                        Literal::AbstractFloat(v) => v != 0.0,
2372                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2373                            return make_error();
2374                        }
2375                    }),
2376                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2377                        Literal::AbstractInt(v) => {
2378                            // Overflow is forbidden, but inexact conversions
2379                            // are fine. The range of f64 is far larger than
2380                            // that of i64, so we don't have to check anything
2381                            // here.
2382                            v as f64
2383                        }
2384                        Literal::AbstractFloat(v) => v,
2385                        _ => return make_error(),
2386                    }),
2387                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2388                        Literal::AbstractInt(v) => v,
2389                        _ => return make_error(),
2390                    }),
2391                    _ => {
2392                        log::debug!("Constant evaluator refused to convert value to {target:?}");
2393                        return make_error();
2394                    }
2395                };
2396                Expression::Literal(literal)
2397            }
2398            Expression::Compose {
2399                ty,
2400                components: ref src_components,
2401            } => {
2402                let ty_inner = match self.types[ty].inner {
2403                    TypeInner::Vector { size, .. } => TypeInner::Vector {
2404                        size,
2405                        scalar: target,
2406                    },
2407                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2408                        columns,
2409                        rows,
2410                        scalar: target,
2411                    },
2412                    _ => return make_error(),
2413                };
2414
2415                let mut components = src_components.clone();
2416                for component in &mut components {
2417                    *component = self.cast(*component, target, span)?;
2418                }
2419
2420                let ty = self.types.insert(
2421                    Type {
2422                        name: None,
2423                        inner: ty_inner,
2424                    },
2425                    span,
2426                );
2427
2428                Expression::Compose { ty, components }
2429            }
2430            Expression::Splat { size, value } => {
2431                let value_span = self.expressions.get_span(value);
2432                let cast_value = self.cast(value, target, value_span)?;
2433                Expression::Splat {
2434                    size,
2435                    value: cast_value,
2436                }
2437            }
2438            _ => return make_error(),
2439        };
2440
2441        self.register_evaluated_expr(expr, span)
2442    }
2443
2444    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2445    ///
2446    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2447    /// matrix, or nested arrays of such.
2448    ///
2449    /// This is basically the same as the [`cast`] method, except that that
2450    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2451    ///
2452    /// Treat `span` as the location of the resulting expression.
2453    ///
2454    /// [`cast`]: ConstantEvaluator::cast
2455    /// [`As`]: crate::Expression::As
2456    pub fn cast_array(
2457        &mut self,
2458        expr: Handle<Expression>,
2459        target: crate::Scalar,
2460        span: Span,
2461    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2462        let expr = self.check_and_get(expr)?;
2463
2464        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2465            return self.cast(expr, target, span);
2466        };
2467
2468        let TypeInner::Array {
2469            base: _,
2470            size,
2471            stride: _,
2472        } = self.types[ty].inner
2473        else {
2474            return self.cast(expr, target, span);
2475        };
2476
2477        let mut components = components.clone();
2478        for component in &mut components {
2479            *component = self.cast_array(*component, target, span)?;
2480        }
2481
2482        let first = components.first().unwrap();
2483        let new_base = match self.resolve_type(*first)? {
2484            crate::proc::TypeResolution::Handle(ty) => ty,
2485            crate::proc::TypeResolution::Value(inner) => {
2486                self.types.insert(Type { name: None, inner }, span)
2487            }
2488        };
2489        let mut layouter = core::mem::take(self.layouter);
2490        layouter.update(self.to_ctx()).unwrap();
2491        *self.layouter = layouter;
2492
2493        let new_base_stride = self.layouter[new_base].to_stride();
2494        let new_array_ty = self.types.insert(
2495            Type {
2496                name: None,
2497                inner: TypeInner::Array {
2498                    base: new_base,
2499                    size,
2500                    stride: new_base_stride,
2501                },
2502            },
2503            span,
2504        );
2505
2506        let compose = Expression::Compose {
2507            ty: new_array_ty,
2508            components,
2509        };
2510        self.register_evaluated_expr(compose, span)
2511    }
2512
2513    fn unary_op(
2514        &mut self,
2515        op: UnaryOperator,
2516        expr: Handle<Expression>,
2517        span: Span,
2518    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2519        let expr = self.eval_zero_value_and_splat(expr, span)?;
2520
2521        let expr = match self.expressions[expr] {
2522            Expression::Literal(value) => Expression::Literal(match op {
2523                UnaryOperator::Negate => match value {
2524                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2525                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2526                    Literal::F32(v) => Literal::F32(-v),
2527                    Literal::F16(v) => Literal::F16(-v),
2528                    Literal::F64(v) => Literal::F64(-v),
2529                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2530                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2531                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2532                },
2533                UnaryOperator::LogicalNot => match value {
2534                    Literal::Bool(v) => Literal::Bool(!v),
2535                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2536                },
2537                UnaryOperator::BitwiseNot => match value {
2538                    Literal::I32(v) => Literal::I32(!v),
2539                    Literal::I64(v) => Literal::I64(!v),
2540                    Literal::U32(v) => Literal::U32(!v),
2541                    Literal::U64(v) => Literal::U64(!v),
2542                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2543                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2544                },
2545            }),
2546            Expression::Compose {
2547                ty,
2548                components: ref src_components,
2549            } => {
2550                match self.types[ty].inner {
2551                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2552                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2553                }
2554
2555                let mut components = src_components.clone();
2556                for component in &mut components {
2557                    *component = self.unary_op(op, *component, span)?;
2558                }
2559
2560                Expression::Compose { ty, components }
2561            }
2562            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2563        };
2564
2565        self.register_evaluated_expr(expr, span)
2566    }
2567
2568    fn binary_op(
2569        &mut self,
2570        op: BinaryOperator,
2571        left: Handle<Expression>,
2572        right: Handle<Expression>,
2573        span: Span,
2574    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2575        let left = self.eval_zero_value_and_splat(left, span)?;
2576        let right = self.eval_zero_value_and_splat(right, span)?;
2577
2578        // Note: in most cases constant evaluation checks for overflow, but for
2579        // i32/u32, it uses wrapping arithmetic. See
2580        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.
2581
2582        let expr = match (&self.expressions[left], &self.expressions[right]) {
2583            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2584                let literal = match op {
2585                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2586                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2587                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2588                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2589                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2590                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2591
2592                    _ => match (left_value, right_value) {
2593                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2594                            BinaryOperator::Add => a.wrapping_add(b),
2595                            BinaryOperator::Subtract => a.wrapping_sub(b),
2596                            BinaryOperator::Multiply => a.wrapping_mul(b),
2597                            BinaryOperator::Divide => {
2598                                if b == 0 {
2599                                    return Err(ConstantEvaluatorError::DivisionByZero);
2600                                } else {
2601                                    a.wrapping_div(b)
2602                                }
2603                            }
2604                            BinaryOperator::Modulo => {
2605                                if b == 0 {
2606                                    return Err(ConstantEvaluatorError::RemainderByZero);
2607                                } else {
2608                                    a.wrapping_rem(b)
2609                                }
2610                            }
2611                            BinaryOperator::And => a & b,
2612                            BinaryOperator::ExclusiveOr => a ^ b,
2613                            BinaryOperator::InclusiveOr => a | b,
2614                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2615                        }),
2616                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2617                            BinaryOperator::ShiftLeft => {
2618                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2619                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2620                                }
2621                                a.checked_shl(b)
2622                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2623                            }
2624                            BinaryOperator::ShiftRight => a
2625                                .checked_shr(b)
2626                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2627                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2628                        }),
2629                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2630                            BinaryOperator::Add => a.wrapping_add(b),
2631                            BinaryOperator::Subtract => a.wrapping_sub(b),
2632                            BinaryOperator::Multiply => a.wrapping_mul(b),
2633                            BinaryOperator::Divide => a
2634                                .checked_div(b)
2635                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2636                            BinaryOperator::Modulo => a
2637                                .checked_rem(b)
2638                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2639                            BinaryOperator::And => a & b,
2640                            BinaryOperator::ExclusiveOr => a ^ b,
2641                            BinaryOperator::InclusiveOr => a | b,
2642                            BinaryOperator::ShiftLeft => a
2643                                .checked_mul(
2644                                    1u32.checked_shl(b)
2645                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2646                                )
2647                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2648                            BinaryOperator::ShiftRight => a
2649                                .checked_shr(b)
2650                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2651                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2652                        }),
2653                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2654                            BinaryOperator::Add => a + b,
2655                            BinaryOperator::Subtract => a - b,
2656                            BinaryOperator::Multiply => a * b,
2657                            BinaryOperator::Divide => a / b,
2658                            BinaryOperator::Modulo => a % b,
2659                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2660                        }),
2661                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2662                            Literal::AbstractInt(match op {
2663                                BinaryOperator::ShiftLeft => {
2664                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2665                                        return Err(ConstantEvaluatorError::Overflow(
2666                                            "<<".to_string(),
2667                                        ));
2668                                    }
2669                                    a.checked_shl(b).unwrap_or(0)
2670                                }
2671                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2672                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2673                            })
2674                        }
2675                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2676                            BinaryOperator::Add => a + b,
2677                            BinaryOperator::Subtract => a - b,
2678                            BinaryOperator::Multiply => a * b,
2679                            BinaryOperator::Divide => a / b,
2680                            BinaryOperator::Modulo => a % b,
2681                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2682                        }),
2683                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2684                            Literal::AbstractInt(match op {
2685                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2686                                    ConstantEvaluatorError::Overflow("addition".into())
2687                                })?,
2688                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2689                                    ConstantEvaluatorError::Overflow("subtraction".into())
2690                                })?,
2691                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2692                                    ConstantEvaluatorError::Overflow("multiplication".into())
2693                                })?,
2694                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2695                                    if b == 0 {
2696                                        ConstantEvaluatorError::DivisionByZero
2697                                    } else {
2698                                        ConstantEvaluatorError::Overflow("division".into())
2699                                    }
2700                                })?,
2701                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2702                                    if b == 0 {
2703                                        ConstantEvaluatorError::RemainderByZero
2704                                    } else {
2705                                        ConstantEvaluatorError::Overflow("remainder".into())
2706                                    }
2707                                })?,
2708                                BinaryOperator::And => a & b,
2709                                BinaryOperator::ExclusiveOr => a ^ b,
2710                                BinaryOperator::InclusiveOr => a | b,
2711                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2712                            })
2713                        }
2714                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2715                            Literal::AbstractFloat(match op {
2716                                BinaryOperator::Add => a + b,
2717                                BinaryOperator::Subtract => a - b,
2718                                BinaryOperator::Multiply => a * b,
2719                                BinaryOperator::Divide => a / b,
2720                                BinaryOperator::Modulo => a % b,
2721                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2722                            })
2723                        }
2724                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2725                            BinaryOperator::LogicalAnd => a && b,
2726                            BinaryOperator::LogicalOr => a || b,
2727                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2728                        }),
2729                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2730                    },
2731                };
2732                Expression::Literal(literal)
2733            }
2734            (
2735                &Expression::Compose {
2736                    components: ref src_components,
2737                    ty,
2738                },
2739                &Expression::Literal(_),
2740            ) => {
2741                let mut components = src_components.clone();
2742                for component in &mut components {
2743                    *component = self.binary_op(op, *component, right, span)?;
2744                }
2745                Expression::Compose { ty, components }
2746            }
2747            (
2748                &Expression::Literal(_),
2749                &Expression::Compose {
2750                    components: ref src_components,
2751                    ty,
2752                },
2753            ) => {
2754                let mut components = src_components.clone();
2755                for component in &mut components {
2756                    *component = self.binary_op(op, left, *component, span)?;
2757                }
2758                Expression::Compose { ty, components }
2759            }
2760            (
2761                &Expression::Compose {
2762                    components: ref left_components,
2763                    ty: left_ty,
2764                },
2765                &Expression::Compose {
2766                    components: ref right_components,
2767                    ty: right_ty,
2768                },
2769            ) => {
2770                // We have to make a copy of the component lists, because the
2771                // call to `binary_op_vector` needs `&mut self`, but `self` owns
2772                // the component lists.
2773                let left_flattened = crate::proc::flatten_compose(
2774                    left_ty,
2775                    left_components,
2776                    self.expressions,
2777                    self.types,
2778                );
2779                let right_flattened = crate::proc::flatten_compose(
2780                    right_ty,
2781                    right_components,
2782                    self.expressions,
2783                    self.types,
2784                );
2785
2786                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
2787                // make a reasonable guess of the capacity we'll need.
2788                let mut flattened = Vec::with_capacity(left_components.len());
2789                flattened.extend(left_flattened.zip(right_flattened));
2790
2791                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2792                    (
2793                        &TypeInner::Vector {
2794                            size: left_size, ..
2795                        },
2796                        &TypeInner::Vector {
2797                            size: right_size, ..
2798                        },
2799                    ) if left_size == right_size => {
2800                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2801                    }
2802                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2803                }
2804            }
2805            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2806        };
2807
2808        self.register_evaluated_expr(expr, span)
2809    }
2810
2811    fn binary_op_vector(
2812        &mut self,
2813        op: BinaryOperator,
2814        size: crate::VectorSize,
2815        components: &[(Handle<Expression>, Handle<Expression>)],
2816        left_ty: Handle<Type>,
2817        span: Span,
2818    ) -> Result<Expression, ConstantEvaluatorError> {
2819        let ty = match op {
2820            // Relational operators produce vectors of booleans.
2821            BinaryOperator::Equal
2822            | BinaryOperator::NotEqual
2823            | BinaryOperator::Less
2824            | BinaryOperator::LessEqual
2825            | BinaryOperator::Greater
2826            | BinaryOperator::GreaterEqual => self.types.insert(
2827                Type {
2828                    name: None,
2829                    inner: TypeInner::Vector {
2830                        size,
2831                        scalar: crate::Scalar::BOOL,
2832                    },
2833                },
2834                span,
2835            ),
2836
2837            // Other operators produce the same type as their left
2838            // operand.
2839            BinaryOperator::Add
2840            | BinaryOperator::Subtract
2841            | BinaryOperator::Multiply
2842            | BinaryOperator::Divide
2843            | BinaryOperator::Modulo
2844            | BinaryOperator::And
2845            | BinaryOperator::ExclusiveOr
2846            | BinaryOperator::InclusiveOr
2847            | BinaryOperator::ShiftLeft
2848            | BinaryOperator::ShiftRight => left_ty,
2849
2850            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2851                // Not supported on vectors
2852                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2853            }
2854        };
2855
2856        let components = components
2857            .iter()
2858            .map(|&(left, right)| self.binary_op(op, left, right, span))
2859            .collect::<Result<Vec<_>, _>>()?;
2860
2861        Ok(Expression::Compose { ty, components })
2862    }
2863
2864    fn relational(
2865        &mut self,
2866        fun: RelationalFunction,
2867        arg: Handle<Expression>,
2868        span: Span,
2869    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2870        let arg = self.eval_zero_value_and_splat(arg, span)?;
2871        match fun {
2872            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2873                Expression::Literal(Literal::Bool(_)) => Ok(arg),
2874                Expression::Compose { ty, ref components }
2875                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2876                {
2877                    let components =
2878                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2879                            .map(|component| match self.expressions[component] {
2880                                Expression::Literal(Literal::Bool(val)) => Ok(val),
2881                                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2882                            })
2883                            .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2884                    let result = match fun {
2885                        RelationalFunction::All => components.iter().all(|c| *c),
2886                        RelationalFunction::Any => components.iter().any(|c| *c),
2887                        _ => unreachable!(),
2888                    };
2889                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2890                }
2891                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2892            },
2893            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2894                "{fun:?} built-in function"
2895            ))),
2896        }
2897    }
2898
2899    /// Deep copy `expr` from `expressions` into `self.expressions`.
2900    ///
2901    /// Return the root of the new copy.
2902    ///
2903    /// This is used when we're evaluating expressions in a function's
2904    /// expression arena that refer to a constant: we need to copy the
2905    /// constant's value into the function's arena so we can operate on it.
2906    fn copy_from(
2907        &mut self,
2908        expr: Handle<Expression>,
2909        expressions: &Arena<Expression>,
2910    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2911        let span = expressions.get_span(expr);
2912        match expressions[expr] {
2913            ref expr @ (Expression::Literal(_)
2914            | Expression::Constant(_)
2915            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2916            Expression::Compose { ty, ref components } => {
2917                let mut components = components.clone();
2918                for component in &mut components {
2919                    *component = self.copy_from(*component, expressions)?;
2920                }
2921                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2922            }
2923            Expression::Splat { size, value } => {
2924                let value = self.copy_from(value, expressions)?;
2925                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2926            }
2927            _ => {
2928                log::debug!("copy_from: SubexpressionsAreNotConstant");
2929                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2930            }
2931        }
2932    }
2933
2934    /// Returns the total number of components, after flattening, of a vector compose expression.
2935    fn vector_compose_flattened_size(
2936        &self,
2937        components: &[Handle<Expression>],
2938    ) -> Result<usize, ConstantEvaluatorError> {
2939        components
2940            .iter()
2941            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2942                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2943                    TypeInner::Scalar(_) => 1,
2944                    // We trust that the vector size of `component` is correct,
2945                    // as it will have already been validated when `component`
2946                    // was registered.
2947                    TypeInner::Vector { size, .. } => size as usize,
2948                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2949                };
2950                Ok(acc + size)
2951            })
2952    }
2953
2954    fn register_evaluated_expr(
2955        &mut self,
2956        expr: Expression,
2957        span: Span,
2958    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2959        // It suffices to only check_literal_value() for `Literal` expressions,
2960        // since we only register one expression at a time, `Compose`
2961        // expressions can only refer to other expressions, and `ZeroValue`
2962        // expressions are always okay.
2963        if let Expression::Literal(literal) = expr {
2964            crate::valid::check_literal_value(literal)?;
2965        }
2966
2967        // Ensure vector composes contain the correct number of components. We
2968        // do so here when each compose is registered to avoid having to deal
2969        // with the mess each time the compose is used in another expression.
2970        if let Expression::Compose { ty, ref components } = expr {
2971            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2972                let expected = size as usize;
2973                let actual = self.vector_compose_flattened_size(components)?;
2974                if expected != actual {
2975                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2976                        expected,
2977                        actual,
2978                    });
2979                }
2980            }
2981        }
2982
2983        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2984    }
2985
2986    fn append_expr(
2987        &mut self,
2988        expr: Expression,
2989        span: Span,
2990        expr_type: ExpressionKind,
2991    ) -> Handle<Expression> {
2992        let h = match self.behavior {
2993            Behavior::Wgsl(
2994                WgslRestrictions::Runtime(ref mut function_local_data)
2995                | WgslRestrictions::Const(Some(ref mut function_local_data)),
2996            )
2997            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
2998                let is_running = function_local_data.emitter.is_running();
2999                let needs_pre_emit = expr.needs_pre_emit();
3000                if is_running && needs_pre_emit {
3001                    function_local_data
3002                        .block
3003                        .extend(function_local_data.emitter.finish(self.expressions));
3004                    let h = self.expressions.append(expr, span);
3005                    function_local_data.emitter.start(self.expressions);
3006                    h
3007                } else {
3008                    self.expressions.append(expr, span)
3009                }
3010            }
3011            _ => self.expressions.append(expr, span),
3012        };
3013        self.expression_kind_tracker.insert(h, expr_type);
3014        h
3015    }
3016
3017    fn resolve_type(
3018        &self,
3019        expr: Handle<Expression>,
3020    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3021        use crate::proc::TypeResolution as Tr;
3022        use crate::Expression as Ex;
3023        let resolution = match self.expressions[expr] {
3024            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3025            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3026            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3027            Ex::Splat { size, value } => {
3028                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3029                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3030                };
3031                Tr::Value(TypeInner::Vector { scalar, size })
3032            }
3033            _ => {
3034                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3035                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3036            }
3037        };
3038
3039        Ok(resolution)
3040    }
3041
3042    fn select(
3043        &mut self,
3044        reject: Handle<Expression>,
3045        accept: Handle<Expression>,
3046        condition: Handle<Expression>,
3047        span: Span,
3048    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3049        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3050
3051        let reject = arg(reject)?;
3052        let accept = arg(accept)?;
3053        let condition = arg(condition)?;
3054
3055        let select_single_component =
3056            |this: &mut Self, reject_scalar, reject, accept, condition| {
3057                let accept = this.cast(accept, reject_scalar, span)?;
3058                if condition {
3059                    Ok(accept)
3060                } else {
3061                    Ok(reject)
3062                }
3063            };
3064
3065        match (&self.expressions[reject], &self.expressions[accept]) {
3066            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3067                let reject_scalar = reject_lit.scalar();
3068                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3069                else {
3070                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3071                };
3072                select_single_component(self, reject_scalar, reject, accept, condition)
3073            }
3074            (
3075                &Expression::Compose {
3076                    ty: reject_ty,
3077                    components: ref reject_components,
3078                },
3079                &Expression::Compose {
3080                    ty: accept_ty,
3081                    components: ref accept_components,
3082                },
3083            ) => {
3084                let ty_deets = |ty| {
3085                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3086                    (size.unwrap(), scalar)
3087                };
3088
3089                let expected_vec_size = {
3090                    let [(reject_vec_size, _), (accept_vec_size, _)] =
3091                        [reject_ty, accept_ty].map(ty_deets);
3092
3093                    if reject_vec_size != accept_vec_size {
3094                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3095                            reject: reject_vec_size,
3096                            accept: accept_vec_size,
3097                        });
3098                    }
3099                    reject_vec_size
3100                };
3101
3102                let condition_components = match self.expressions[condition] {
3103                    Expression::Literal(Literal::Bool(condition)) => {
3104                        vec![condition; (expected_vec_size as u8).into()]
3105                    }
3106                    Expression::Compose {
3107                        ty: condition_ty,
3108                        components: ref condition_components,
3109                    } => {
3110                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3111                        if condition_scalar.kind != ScalarKind::Bool {
3112                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3113                        }
3114                        if condition_vec_size != expected_vec_size {
3115                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3116                        }
3117                        condition_components
3118                            .iter()
3119                            .copied()
3120                            .map(|component| match &self.expressions[component] {
3121                                &Expression::Literal(Literal::Bool(condition)) => condition,
3122                                _ => unreachable!(),
3123                            })
3124                            .collect()
3125                    }
3126
3127                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3128                };
3129
3130                let evaluated = Expression::Compose {
3131                    ty: reject_ty,
3132                    components: reject_components
3133                        .clone()
3134                        .into_iter()
3135                        .zip(accept_components.clone().into_iter())
3136                        .zip(condition_components.into_iter())
3137                        .map(|((reject, accept), condition)| {
3138                            let reject_scalar = match &self.expressions[reject] {
3139                                &Expression::Literal(lit) => lit.scalar(),
3140                                _ => unreachable!(),
3141                            };
3142                            select_single_component(self, reject_scalar, reject, accept, condition)
3143                        })
3144                        .collect::<Result<_, _>>()?,
3145                };
3146                self.register_evaluated_expr(evaluated, span)
3147            }
3148            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3149        }
3150    }
3151}
3152
3153fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3154    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3155    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3156    // return a right-to-left bit index of 0.
3157    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3158        match e {
3159            idx @ 0..=31 => idx,
3160            32 => u32::MAX,
3161            _ => unreachable!(),
3162        }
3163    };
3164    match concrete_int {
3165        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3166        ConcreteInt::I32([e]) => {
3167            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3168        }
3169    }
3170}
3171
#[test]
fn first_trailing_bit_smoke() {
    // Signed inputs, as `(input, expected index)`.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Every single-bit signed value reports that bit's own index.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    // Unsigned inputs, as `(input, expected index)`.
    let unsigned_cases = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    // Every single-bit unsigned value reports that bit's own index.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3232
3233fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3234    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3235    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3236    // index of 0.
3237    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3238        match e {
3239            idx @ 0..=31 => 31 - idx,
3240            32 => u32::MAX,
3241            _ => unreachable!(),
3242        }
3243    };
3244    match concrete_int {
3245        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3246            let rtl_bit_index = if e.is_negative() {
3247                e.leading_ones()
3248            } else {
3249                e.leading_zeros()
3250            };
3251            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3252        }]),
3253        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3254    }
3255}
3256
#[test]
fn first_leading_bit_smoke() {
    // Signed inputs, as `(input, expected index)`.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // NOTE: The sign bit is left out here; it is covered by the table above.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negated powers of two have their highest clear bit one position lower.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // Unsigned inputs, as `(input, expected index)`.
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3320
/// Trait for conversions of abstract values to concrete types.
///
/// `T` is the Rust representation of the abstract source value; the
/// implementations in this file convert from `i64` and `f64`.
trait TryFromAbstract<T>: Sized {
    /// Convert an abstract literal `value` to `Self`.
    ///
    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
    /// WGSL, we follow WGSL's conversion rules here:
    ///
    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
    ///   from [`AbstractInt`] to an integer type are either lossless or an
    ///   error.
    ///
    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
    ///   to floating point in constant expressions and override
    ///   expressions are errors if the value is out of range for the
    ///   destination type, but rounding is okay.
    ///
    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
    ///   other floating point type, following the scalar floating point to
    ///   integral conversion algorithm (§15.7.6). There is no automatic
    ///   conversion from AbstractFloat to integer types.
    ///
    /// # Errors
    ///
    /// Implementations that reject a value report it as
    /// `ConstantEvaluatorError::AutomaticConversionLossy`, carrying the
    /// formatted value and the name of the target type. Several
    /// implementations never fail and always return `Ok`.
    ///
    /// [`AbstractInt`]: crate::Literal::AbstractInt
    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3346
3347impl TryFromAbstract<i64> for i32 {
3348    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3349        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3350            value: format!("{value:?}"),
3351            to_type: "i32",
3352        })
3353    }
3354}
3355
3356impl TryFromAbstract<i64> for u32 {
3357    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3358        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3359            value: format!("{value:?}"),
3360            to_type: "u32",
3361        })
3362    }
3363}
3364
3365impl TryFromAbstract<i64> for u64 {
3366    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3367        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3368            value: format!("{value:?}"),
3369            to_type: "u64",
3370        })
3371    }
3372}
3373
3374impl TryFromAbstract<i64> for i64 {
3375    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3376        Ok(value)
3377    }
3378}
3379
3380impl TryFromAbstract<i64> for f32 {
3381    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3382        let f = value as f32;
3383        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3384        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3385        // overflow here.
3386        Ok(f)
3387    }
3388}
3389
3390impl TryFromAbstract<f64> for f32 {
3391    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3392        let f = value as f32;
3393        if f.is_infinite() {
3394            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3395                value: format!("{value:?}"),
3396                to_type: "f32",
3397            });
3398        }
3399        Ok(f)
3400    }
3401}
3402
3403impl TryFromAbstract<i64> for f64 {
3404    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3405        let f = value as f64;
3406        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3407        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3408        // overflow here.
3409        Ok(f)
3410    }
3411}
3412
3413impl TryFromAbstract<f64> for f64 {
3414    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3415        Ok(value)
3416    }
3417}
3418
3419impl TryFromAbstract<f64> for i32 {
3420    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3421        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3422        // To convert a floating point scalar value X to an integer scalar type T:
3423        // * If X is a NaN, the result is an indeterminate value in T.
3424        // * If X is exactly representable in the target type T, then the
3425        //   result is that value.
3426        // * Otherwise, the result is the value in T closest to truncate(X) and
3427        //   also exactly representable in the original floating point type.
3428        //
3429        // A rust cast satisfies these requirements apart from "the result
3430        // is... exactly representable in the original floating point type".
3431        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3432        // we're all good.
3433        Ok(value as i32)
3434    }
3435}
3436
3437impl TryFromAbstract<f64> for u32 {
3438    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3439        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3440        // so a simple rust cast is sufficient.
3441        Ok(value as u32)
3442    }
3443}
3444
3445impl TryFromAbstract<f64> for i64 {
3446    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3447        // As above, except we clamp to the minimum and maximum values
3448        // representable by both f64 and i64.
3449        use crate::proc::type_methods::IntFloatLimits;
3450        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3451    }
3452}
3453
3454impl TryFromAbstract<f64> for u64 {
3455    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3456        // As above, this time clamping to the minimum and maximum values
3457        // representable by both f64 and u64.
3458        use crate::proc::type_methods::IntFloatLimits;
3459        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3460    }
3461}
3462
3463impl TryFromAbstract<f64> for f16 {
3464    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3465        let f = f16::from_f64(value);
3466        if f.is_infinite() {
3467            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3468                value: format!("{value:?}"),
3469                to_type: "f16",
3470            });
3471        }
3472        Ok(f)
3473    }
3474}
3475
3476impl TryFromAbstract<i64> for f16 {
3477    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3478        let f = f16::from_i64(value);
3479        if f.is_none() {
3480            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3481                value: format!("{value:?}"),
3482                to_type: "f16",
3483            });
3484        }
3485        Ok(f.unwrap())
3486    }
3487}
3488
/// Compute the cross product `a × b` of two three-component vectors.
///
/// Works for any copyable component type that supports multiplication and
/// subtraction with itself.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [
        ay * bz - az * by,
        az * bx - ax * bz,
        ax * by - ay * bx,
    ]
}
3501
3502#[cfg(test)]
3503mod tests {
3504    use alloc::{vec, vec::Vec};
3505
3506    use crate::{
3507        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
3508        UniqueArena, VectorSize,
3509    };
3510
3511    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3512
3513    #[test]
3514    fn unary_op() {
3515        let mut types = UniqueArena::new();
3516        let mut constants = Arena::new();
3517        let overrides = Arena::new();
3518        let mut global_expressions = Arena::new();
3519
3520        let scalar_ty = types.insert(
3521            Type {
3522                name: None,
3523                inner: TypeInner::Scalar(crate::Scalar::I32),
3524            },
3525            Default::default(),
3526        );
3527
3528        let vec_ty = types.insert(
3529            Type {
3530                name: None,
3531                inner: TypeInner::Vector {
3532                    size: VectorSize::Bi,
3533                    scalar: crate::Scalar::I32,
3534                },
3535            },
3536            Default::default(),
3537        );
3538
3539        let h = constants.append(
3540            Constant {
3541                name: None,
3542                ty: scalar_ty,
3543                init: global_expressions
3544                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3545            },
3546            Default::default(),
3547        );
3548
3549        let h1 = constants.append(
3550            Constant {
3551                name: None,
3552                ty: scalar_ty,
3553                init: global_expressions
3554                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
3555            },
3556            Default::default(),
3557        );
3558
3559        let vec_h = constants.append(
3560            Constant {
3561                name: None,
3562                ty: vec_ty,
3563                init: global_expressions.append(
3564                    Expression::Compose {
3565                        ty: vec_ty,
3566                        components: vec![constants[h].init, constants[h1].init],
3567                    },
3568                    Default::default(),
3569                ),
3570            },
3571            Default::default(),
3572        );
3573
3574        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3575        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3576
3577        let expr2 = Expression::Unary {
3578            op: UnaryOperator::Negate,
3579            expr,
3580        };
3581
3582        let expr3 = Expression::Unary {
3583            op: UnaryOperator::BitwiseNot,
3584            expr,
3585        };
3586
3587        let expr4 = Expression::Unary {
3588            op: UnaryOperator::BitwiseNot,
3589            expr: expr1,
3590        };
3591
3592        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3593        let mut solver = ConstantEvaluator {
3594            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3595            types: &mut types,
3596            constants: &constants,
3597            overrides: &overrides,
3598            expressions: &mut global_expressions,
3599            expression_kind_tracker,
3600            layouter: &mut crate::proc::Layouter::default(),
3601        };
3602
3603        let res1 = solver
3604            .try_eval_and_append(expr2, Default::default())
3605            .unwrap();
3606        let res2 = solver
3607            .try_eval_and_append(expr3, Default::default())
3608            .unwrap();
3609        let res3 = solver
3610            .try_eval_and_append(expr4, Default::default())
3611            .unwrap();
3612
3613        assert_eq!(
3614            global_expressions[res1],
3615            Expression::Literal(Literal::I32(-4))
3616        );
3617
3618        assert_eq!(
3619            global_expressions[res2],
3620            Expression::Literal(Literal::I32(!4))
3621        );
3622
3623        let res3_inner = &global_expressions[res3];
3624
3625        match *res3_inner {
3626            Expression::Compose {
3627                ref ty,
3628                ref components,
3629            } => {
3630                assert_eq!(*ty, vec_ty);
3631                let mut components_iter = components.iter().copied();
3632                assert_eq!(
3633                    global_expressions[components_iter.next().unwrap()],
3634                    Expression::Literal(Literal::I32(!4))
3635                );
3636                assert_eq!(
3637                    global_expressions[components_iter.next().unwrap()],
3638                    Expression::Literal(Literal::I32(!8))
3639                );
3640                assert!(components_iter.next().is_none());
3641            }
3642            _ => panic!("Expected vector"),
3643        }
3644    }
3645
3646    #[test]
3647    fn cast() {
3648        let mut types = UniqueArena::new();
3649        let mut constants = Arena::new();
3650        let overrides = Arena::new();
3651        let mut global_expressions = Arena::new();
3652
3653        let scalar_ty = types.insert(
3654            Type {
3655                name: None,
3656                inner: TypeInner::Scalar(crate::Scalar::I32),
3657            },
3658            Default::default(),
3659        );
3660
3661        let h = constants.append(
3662            Constant {
3663                name: None,
3664                ty: scalar_ty,
3665                init: global_expressions
3666                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3667            },
3668            Default::default(),
3669        );
3670
3671        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3672
3673        let root = Expression::As {
3674            expr,
3675            kind: ScalarKind::Bool,
3676            convert: Some(crate::BOOL_WIDTH),
3677        };
3678
3679        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3680        let mut solver = ConstantEvaluator {
3681            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3682            types: &mut types,
3683            constants: &constants,
3684            overrides: &overrides,
3685            expressions: &mut global_expressions,
3686            expression_kind_tracker,
3687            layouter: &mut crate::proc::Layouter::default(),
3688        };
3689
3690        let res = solver
3691            .try_eval_and_append(root, Default::default())
3692            .unwrap();
3693
3694        assert_eq!(
3695            global_expressions[res],
3696            Expression::Literal(Literal::Bool(true))
3697        );
3698    }
3699
3700    #[test]
3701    fn access() {
3702        let mut types = UniqueArena::new();
3703        let mut constants = Arena::new();
3704        let overrides = Arena::new();
3705        let mut global_expressions = Arena::new();
3706
3707        let matrix_ty = types.insert(
3708            Type {
3709                name: None,
3710                inner: TypeInner::Matrix {
3711                    columns: VectorSize::Bi,
3712                    rows: VectorSize::Tri,
3713                    scalar: crate::Scalar::F32,
3714                },
3715            },
3716            Default::default(),
3717        );
3718
3719        let vec_ty = types.insert(
3720            Type {
3721                name: None,
3722                inner: TypeInner::Vector {
3723                    size: VectorSize::Tri,
3724                    scalar: crate::Scalar::F32,
3725                },
3726            },
3727            Default::default(),
3728        );
3729
3730        let mut vec1_components = Vec::with_capacity(3);
3731        let mut vec2_components = Vec::with_capacity(3);
3732
3733        for i in 0..3 {
3734            let h = global_expressions.append(
3735                Expression::Literal(Literal::F32(i as f32)),
3736                Default::default(),
3737            );
3738
3739            vec1_components.push(h)
3740        }
3741
3742        for i in 3..6 {
3743            let h = global_expressions.append(
3744                Expression::Literal(Literal::F32(i as f32)),
3745                Default::default(),
3746            );
3747
3748            vec2_components.push(h)
3749        }
3750
3751        let vec1 = constants.append(
3752            Constant {
3753                name: None,
3754                ty: vec_ty,
3755                init: global_expressions.append(
3756                    Expression::Compose {
3757                        ty: vec_ty,
3758                        components: vec1_components,
3759                    },
3760                    Default::default(),
3761                ),
3762            },
3763            Default::default(),
3764        );
3765
3766        let vec2 = constants.append(
3767            Constant {
3768                name: None,
3769                ty: vec_ty,
3770                init: global_expressions.append(
3771                    Expression::Compose {
3772                        ty: vec_ty,
3773                        components: vec2_components,
3774                    },
3775                    Default::default(),
3776                ),
3777            },
3778            Default::default(),
3779        );
3780
3781        let h = constants.append(
3782            Constant {
3783                name: None,
3784                ty: matrix_ty,
3785                init: global_expressions.append(
3786                    Expression::Compose {
3787                        ty: matrix_ty,
3788                        components: vec![constants[vec1].init, constants[vec2].init],
3789                    },
3790                    Default::default(),
3791                ),
3792            },
3793            Default::default(),
3794        );
3795
3796        let base = global_expressions.append(Expression::Constant(h), Default::default());
3797
3798        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3799        let mut solver = ConstantEvaluator {
3800            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3801            types: &mut types,
3802            constants: &constants,
3803            overrides: &overrides,
3804            expressions: &mut global_expressions,
3805            expression_kind_tracker,
3806            layouter: &mut crate::proc::Layouter::default(),
3807        };
3808
3809        let root1 = Expression::AccessIndex { base, index: 1 };
3810
3811        let res1 = solver
3812            .try_eval_and_append(root1, Default::default())
3813            .unwrap();
3814
3815        let root2 = Expression::AccessIndex {
3816            base: res1,
3817            index: 2,
3818        };
3819
3820        let res2 = solver
3821            .try_eval_and_append(root2, Default::default())
3822            .unwrap();
3823
3824        match global_expressions[res1] {
3825            Expression::Compose {
3826                ref ty,
3827                ref components,
3828            } => {
3829                assert_eq!(*ty, vec_ty);
3830                let mut components_iter = components.iter().copied();
3831                assert_eq!(
3832                    global_expressions[components_iter.next().unwrap()],
3833                    Expression::Literal(Literal::F32(3.))
3834                );
3835                assert_eq!(
3836                    global_expressions[components_iter.next().unwrap()],
3837                    Expression::Literal(Literal::F32(4.))
3838                );
3839                assert_eq!(
3840                    global_expressions[components_iter.next().unwrap()],
3841                    Expression::Literal(Literal::F32(5.))
3842                );
3843                assert!(components_iter.next().is_none());
3844            }
3845            _ => panic!("Expected vector"),
3846        }
3847
3848        assert_eq!(
3849            global_expressions[res2],
3850            Expression::Literal(Literal::F32(5.))
3851        );
3852    }
3853
3854    #[test]
3855    fn compose_of_constants() {
3856        let mut types = UniqueArena::new();
3857        let mut constants = Arena::new();
3858        let overrides = Arena::new();
3859        let mut global_expressions = Arena::new();
3860
3861        let i32_ty = types.insert(
3862            Type {
3863                name: None,
3864                inner: TypeInner::Scalar(crate::Scalar::I32),
3865            },
3866            Default::default(),
3867        );
3868
3869        let vec2_i32_ty = types.insert(
3870            Type {
3871                name: None,
3872                inner: TypeInner::Vector {
3873                    size: VectorSize::Bi,
3874                    scalar: crate::Scalar::I32,
3875                },
3876            },
3877            Default::default(),
3878        );
3879
3880        let h = constants.append(
3881            Constant {
3882                name: None,
3883                ty: i32_ty,
3884                init: global_expressions
3885                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3886            },
3887            Default::default(),
3888        );
3889
3890        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3891
3892        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3893        let mut solver = ConstantEvaluator {
3894            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3895            types: &mut types,
3896            constants: &constants,
3897            overrides: &overrides,
3898            expressions: &mut global_expressions,
3899            expression_kind_tracker,
3900            layouter: &mut crate::proc::Layouter::default(),
3901        };
3902
3903        let solved_compose = solver
3904            .try_eval_and_append(
3905                Expression::Compose {
3906                    ty: vec2_i32_ty,
3907                    components: vec![h_expr, h_expr],
3908                },
3909                Default::default(),
3910            )
3911            .unwrap();
3912        let solved_negate = solver
3913            .try_eval_and_append(
3914                Expression::Unary {
3915                    op: UnaryOperator::Negate,
3916                    expr: solved_compose,
3917                },
3918                Default::default(),
3919            )
3920            .unwrap();
3921
3922        let pass = match global_expressions[solved_negate] {
3923            Expression::Compose { ty, ref components } => {
3924                ty == vec2_i32_ty
3925                    && components.iter().all(|&component| {
3926                        let component = &global_expressions[component];
3927                        matches!(*component, Expression::Literal(Literal::I32(-4)))
3928                    })
3929            }
3930            _ => false,
3931        };
3932        if !pass {
3933            panic!("unexpected evaluation result")
3934        }
3935    }
3936
3937    #[test]
3938    fn splat_of_constant() {
3939        let mut types = UniqueArena::new();
3940        let mut constants = Arena::new();
3941        let overrides = Arena::new();
3942        let mut global_expressions = Arena::new();
3943
3944        let i32_ty = types.insert(
3945            Type {
3946                name: None,
3947                inner: TypeInner::Scalar(crate::Scalar::I32),
3948            },
3949            Default::default(),
3950        );
3951
3952        let vec2_i32_ty = types.insert(
3953            Type {
3954                name: None,
3955                inner: TypeInner::Vector {
3956                    size: VectorSize::Bi,
3957                    scalar: crate::Scalar::I32,
3958                },
3959            },
3960            Default::default(),
3961        );
3962
3963        let h = constants.append(
3964            Constant {
3965                name: None,
3966                ty: i32_ty,
3967                init: global_expressions
3968                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3969            },
3970            Default::default(),
3971        );
3972
3973        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3974
3975        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3976        let mut solver = ConstantEvaluator {
3977            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3978            types: &mut types,
3979            constants: &constants,
3980            overrides: &overrides,
3981            expressions: &mut global_expressions,
3982            expression_kind_tracker,
3983            layouter: &mut crate::proc::Layouter::default(),
3984        };
3985
3986        let solved_compose = solver
3987            .try_eval_and_append(
3988                Expression::Splat {
3989                    size: VectorSize::Bi,
3990                    value: h_expr,
3991                },
3992                Default::default(),
3993            )
3994            .unwrap();
3995        let solved_negate = solver
3996            .try_eval_and_append(
3997                Expression::Unary {
3998                    op: UnaryOperator::Negate,
3999                    expr: solved_compose,
4000                },
4001                Default::default(),
4002            )
4003            .unwrap();
4004
4005        let pass = match global_expressions[solved_negate] {
4006            Expression::Compose { ty, ref components } => {
4007                ty == vec2_i32_ty
4008                    && components.iter().all(|&component| {
4009                        let component = &global_expressions[component];
4010                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4011                    })
4012            }
4013            _ => false,
4014        };
4015        if !pass {
4016            panic!("unexpected evaluation result")
4017        }
4018    }
4019
4020    #[test]
4021    fn splat_of_zero_value() {
4022        let mut types = UniqueArena::new();
4023        let constants = Arena::new();
4024        let overrides = Arena::new();
4025        let mut global_expressions = Arena::new();
4026
4027        let f32_ty = types.insert(
4028            Type {
4029                name: None,
4030                inner: TypeInner::Scalar(crate::Scalar::F32),
4031            },
4032            Default::default(),
4033        );
4034
4035        let vec2_f32_ty = types.insert(
4036            Type {
4037                name: None,
4038                inner: TypeInner::Vector {
4039                    size: VectorSize::Bi,
4040                    scalar: crate::Scalar::F32,
4041                },
4042            },
4043            Default::default(),
4044        );
4045
4046        let five =
4047            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4048        let five_splat = global_expressions.append(
4049            Expression::Splat {
4050                size: VectorSize::Bi,
4051                value: five,
4052            },
4053            Default::default(),
4054        );
4055        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4056        let zero_splat = global_expressions.append(
4057            Expression::Splat {
4058                size: VectorSize::Bi,
4059                value: zero,
4060            },
4061            Default::default(),
4062        );
4063
4064        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4065        let mut solver = ConstantEvaluator {
4066            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4067            types: &mut types,
4068            constants: &constants,
4069            overrides: &overrides,
4070            expressions: &mut global_expressions,
4071            expression_kind_tracker,
4072            layouter: &mut crate::proc::Layouter::default(),
4073        };
4074
4075        let solved_add = solver
4076            .try_eval_and_append(
4077                Expression::Binary {
4078                    op: crate::BinaryOperator::Add,
4079                    left: zero_splat,
4080                    right: five_splat,
4081                },
4082                Default::default(),
4083            )
4084            .unwrap();
4085
4086        let pass = match global_expressions[solved_add] {
4087            Expression::Compose { ty, ref components } => {
4088                ty == vec2_f32_ty
4089                    && components.iter().all(|&component| {
4090                        let component = &global_expressions[component];
4091                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4092                    })
4093            }
4094            _ => false,
4095        };
4096        if !pass {
4097            panic!("unexpected evaluation result")
4098        }
4099    }
4100}