naga/proc/
constant_evaluator.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14    arena::{Arena, Handle, HandleVec, UniqueArena},
15    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16    ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// The tokens passed in are used verbatim as the rule list of a helper macro named
/// `__with_dollar_sign`, which is then invoked with a single literal `$` token. A rule written
/// as `($d:tt) => { ... }` therefore binds `$d` to `$`, and `$d name` inside that rule's body
/// comes out as `$name` in the generated code.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper from the caller-supplied rules, then feed it the `$` token.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
33
/// Generates the machinery for evaluating a built-in component-wise over a restricted set of
/// [`Literal`] kinds.
///
/// For an invocation `$ident -> $target` this emits:
///
/// - an enum `$target<N>` with one variant per `literals` entry, each holding `[$ty; N]`;
/// - `impl From<$target<1>> for Expression`, turning a single value back into a literal;
/// - a function `$ident` that unpacks its argument expressions into a `$target<N>`, passes it
///   to a handler, and registers the result as a new expression;
/// - a macro, also named `$ident!`, that applies one handler body to every `$target` variant.
///
/// Each `literals` entry has the form `LiteralVariant => TargetVariant: rust_type`.
/// `scalar_kinds` lists the [`ScalarKind`]s accepted when the first argument is a vector.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // Only the single-element form converts back into an expression; this is the only
        // `Into<Expression>` impl generated here for `$target`.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first expression selects which branch below is taken, so one must exist.
            assert!(N > 0);
            // Every kind/shape mismatch in this function is reported as this one error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Passes a handle through `eval_zero_value_and_splat` and borrows the expression
            // the returned handle points at.
            // NOTE(review): presumably this expands zero values and splats into
            // literals / `Compose`s so the match below can see them; confirm against that
            // method, which is not visible here.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar case: the first argument is a literal of an accepted kind. Every
                // remaining argument must be a literal of that same kind.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        // One value was collected per argument (N in total), so the
                        // `ArrayVec` is full and `into_inner` succeeds.
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: the first argument is a `Compose` of a vector type whose scalar
                // kind is one of `scalar_kinds`.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            // The remaining arguments must be `Compose`s whose
                                            // type compares equal to the first argument's.
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            // One group was added per argument, so the `ArrayVec` is full.
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component position, gather that component from every
                            // argument and evaluate the group with a recursive call.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    // A flattened argument shorter than `size` is an error.
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result reuses the first argument's vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` is bound to a literal `$` by `with_dollar_sign!`, so `$d expr` below becomes
        // `$expr` in the generated macro, while plain `$ident` / `$target` / `$mapping` are
        // substituted now, by this outer macro.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        // One match arm per variant: the handler body is pasted into each arm
                        // and its `Ok` value is wrapped back into the same variant.
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Generates `Scalar<N>`, `component_wise_scalar()` and `component_wise_scalar!`.
// Covers the numeric literal kinds listed below: abstract and concrete floats and integers.
// `Literal::F64` and `Literal::Bool` are not part of this set.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// Generates `Float<N>`, `component_wise_float()` and `component_wise_float!`.
// Floating-point literals only; note that `Literal::AbstractFloat` maps to the variant
// named `Float::Abstract` here, not `Float::AbstractFloat`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// Generates `ConcreteInt<N>`, `component_wise_concrete_int()` and
// `component_wise_concrete_int!`.
// Only the 32-bit concrete integer literals (`U32`, `I32`) are accepted.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// Generates `Signed<N>`, `component_wise_signed()` and `component_wise_signed!`.
// Covers the literal kinds listed below: abstract and concrete floats, plus `AbstractInt`
// and `I32`. Unsigned and 64-bit integer literals are not part of this set.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// Vectors with a concrete element type.
///
/// Each variant stores up to [`crate::VectorSize::MAX`] values that all share one scalar
/// type; there is one variant per [`Literal`] variant. A single scalar is represented as a
/// vector of length one (see `from_literal`).
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    // Abstract numeric kinds use the same storage types as `I64` / `F64` but stay distinct
    // variants, mirroring `Literal::AbstractInt` / `Literal::AbstractFloat`.
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288    fn len(&self) -> usize {
289        match *self {
290            LiteralVector::F64(ref v) => v.len(),
291            LiteralVector::F32(ref v) => v.len(),
292            LiteralVector::F16(ref v) => v.len(),
293            LiteralVector::U32(ref v) => v.len(),
294            LiteralVector::I32(ref v) => v.len(),
295            LiteralVector::U64(ref v) => v.len(),
296            LiteralVector::I64(ref v) => v.len(),
297            LiteralVector::Bool(ref v) => v.len(),
298            LiteralVector::AbstractInt(ref v) => v.len(),
299            LiteralVector::AbstractFloat(ref v) => v.len(),
300        }
301    }
302
303    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304    fn from_literal(literal: Literal) -> Self {
305        match literal {
306            Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307            Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308            Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309            Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310            Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311            Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312            Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313            Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314            Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315            Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316        }
317    }
318
319    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320    /// Returns error if components types do not match.
321    /// # Panics
322    /// Panics if vector is empty
323    fn from_literal_vec(
324        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325    ) -> Result<Self, ConstantEvaluatorError> {
326        assert!(!components.is_empty());
327        Ok(match components[0] {
328            Literal::I32(_) => Self::I32(
329                components
330                    .iter()
331                    .map(|l| match l {
332                        &Literal::I32(v) => Ok(v),
333                        // TODO: should we handle abstract int here?
334                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
335                    })
336                    .collect::<Result<_, _>>()?,
337            ),
338            Literal::U32(_) => Self::U32(
339                components
340                    .iter()
341                    .map(|l| match l {
342                        &Literal::U32(v) => Ok(v),
343                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
344                    })
345                    .collect::<Result<_, _>>()?,
346            ),
347            Literal::I64(_) => Self::I64(
348                components
349                    .iter()
350                    .map(|l| match l {
351                        &Literal::I64(v) => Ok(v),
352                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
353                    })
354                    .collect::<Result<_, _>>()?,
355            ),
356            Literal::U64(_) => Self::U64(
357                components
358                    .iter()
359                    .map(|l| match l {
360                        &Literal::U64(v) => Ok(v),
361                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
362                    })
363                    .collect::<Result<_, _>>()?,
364            ),
365            Literal::F32(_) => Self::F32(
366                components
367                    .iter()
368                    .map(|l| match l {
369                        &Literal::F32(v) => Ok(v),
370                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
371                    })
372                    .collect::<Result<_, _>>()?,
373            ),
374            Literal::F64(_) => Self::F64(
375                components
376                    .iter()
377                    .map(|l| match l {
378                        &Literal::F64(v) => Ok(v),
379                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
380                    })
381                    .collect::<Result<_, _>>()?,
382            ),
383            Literal::Bool(_) => Self::Bool(
384                components
385                    .iter()
386                    .map(|l| match l {
387                        &Literal::Bool(v) => Ok(v),
388                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
389                    })
390                    .collect::<Result<_, _>>()?,
391            ),
392            Literal::AbstractInt(_) => Self::AbstractInt(
393                components
394                    .iter()
395                    .map(|l| match l {
396                        &Literal::AbstractInt(v) => Ok(v),
397                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
398                    })
399                    .collect::<Result<_, _>>()?,
400            ),
401            Literal::AbstractFloat(_) => Self::AbstractFloat(
402                components
403                    .iter()
404                    .map(|l| match l {
405                        &Literal::AbstractFloat(v) => Ok(v),
406                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
407                    })
408                    .collect::<Result<_, _>>()?,
409            ),
410            Literal::F16(_) => Self::F16(
411                components
412                    .iter()
413                    .map(|l| match l {
414                        &Literal::F16(v) => Ok(v),
415                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
416                    })
417                    .collect::<Result<_, _>>()?,
418            ),
419        })
420    }
421
422    #[allow(dead_code)]
423    /// Returns [`ArrayVec`] of [`Literal`]s
424    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425        match *self {
426            LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427            LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428            LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429            LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430            LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431            LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432            LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433            LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434            LiteralVector::AbstractInt(ref v) => {
435                v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436            }
437            LiteralVector::AbstractFloat(ref v) => {
438                v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439            }
440        }
441    }
442
443    #[allow(dead_code)]
444    /// Puts self into eval's expressions arena and returns handle to it
445    fn register_as_evaluated_expr(
446        &self,
447        eval: &mut ConstantEvaluator<'_>,
448        span: Span,
449    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450        let lit_vec = self.to_literal_vec();
451        assert!(!lit_vec.is_empty());
452        let expr = if lit_vec.len() == 1 {
453            Expression::Literal(lit_vec[0])
454        } else {
455            Expression::Compose {
456                ty: eval.types.insert(
457                    Type {
458                        name: None,
459                        inner: TypeInner::Vector {
460                            size: match lit_vec.len() {
461                                2 => crate::VectorSize::Bi,
462                                3 => crate::VectorSize::Tri,
463                                4 => crate::VectorSize::Quad,
464                                _ => unreachable!(),
465                            },
466                            scalar: lit_vec[0].scalar(),
467                        },
468                    },
469                    Span::UNDEFINED,
470                ),
471                components: lit_vec
472                    .iter()
473                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474                    .collect::<Result<_, _>>()?,
475            }
476        };
477        eval.register_evaluated_expr(expr, span)
478    }
479}
480
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// The whole invocation evaluates to a `Result`: `Ok` wrapping the arm's value in the chosen
/// output variant, or `Err(ConstantEvaluatorError::InvalidMathArg)` when no arm matches.
/// Closure parameters are bound by `ref` to the variants' payloads.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Entry point: split the arms into a list of type names and a parallel list of
    // `{ vars ; ret ; body }` records.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Seed the work list `[done] <> [pending]` with nothing done yet.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // No current type selected: pop the first type name off the list and make it current.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // Current type is the `Integer` group: put the five integer variants at the front of the
    // type list and repeat the first pending record once per variant.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is the `Float` group: same expansion as `Integer`, for the four float
    // variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is a plain variant and more types remain: pair it with the first pending
    // record, append the pair to the done list, and make the next type current.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Current type is the last one: pair it with the single remaining record and hand the
    // completed list to `@inner_finish`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the actual `match`: one arm per record, in which every bound variable must be
    // the same `LiteralVector` variant, plus a fallback arm returning `InvalidMathArg`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                // With a single bound variable the pattern is a parenthesised one-element
                // group, which would otherwise trigger `unused_parens`.
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the arm body in an `$out` variant: the matched variant by default...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...or the variant named by the `-> Ret` override when one was given.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
/// Selects the language-specific rule set a [`ConstantEvaluator`] follows, together with the
/// restrictions that apply in the current evaluation context.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL rules, under the given restrictions.
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL rules, under the given restrictions.
    Glsl(GlslRestrictions<'a>),
}
662
663impl Behavior<'_> {
664    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
665    const fn has_runtime_restrictions(&self) -> bool {
666        matches!(
667            self,
668            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670        )
671    }
672}
673
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Layouter supplied by the caller.
    ///
    // NOTE(review): presumably used to compute or refresh type layouts during evaluation;
    // its uses are not visible in this part of the file, so confirm before relying on this.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// Evaluation modes under WGSL rules, carried by [`Behavior::Wgsl`].
///
/// Each variant lists what happens to expressions of each kind.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    // NOTE(review): the optional `FunctionLocalData` is presumably present when the
    // const-expression lives inside a function body; confirm against the constructors.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
731
/// Evaluation modes under GLSL rules, carried by [`Behavior::Glsl`].
///
/// Each variant lists what happens to expressions of each kind. Unlike
/// [`WgslRestrictions`], there is no override-only mode.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
741
/// Extra state carried by the restriction variants that hold function-local data
/// (see [`WgslRestrictions`] and [`GlslRestrictions`]).
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    // NOTE(review): presumably the emitter and the statement block of the function being
    // processed; how they are used is not visible in this part of the file.
    emitter: &'a mut super::Emitter,
    block: &'a mut crate::Block,
}
749
/// How "constant" an expression is.
///
/// The variant order matters: the derived [`Ord`] gives `Const < Override < Runtime`, and
/// [`ExpressionKindTracker`] combines operand kinds with `max`, so an expression is as
/// non-constant as its least constant operand. Do not reorder the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    Const,
    Override,
    Runtime,
}
756
/// Records an [`ExpressionKind`] for each expression handle of an expression arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// The recorded kind of each expression, indexed by its handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763    pub const fn new() -> Self {
764        Self {
765            inner: HandleVec::new(),
766        }
767    }
768
769    /// Forces the the expression to not be const
770    pub fn force_non_const(&mut self, value: Handle<Expression>) {
771        self.inner[value] = ExpressionKind::Runtime;
772    }
773
774    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775        self.inner.insert(value, expr_type);
776    }
777
778    pub fn is_const(&self, h: Handle<Expression>) -> bool {
779        matches!(self.type_of(h), ExpressionKind::Const)
780    }
781
782    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783        matches!(
784            self.type_of(h),
785            ExpressionKind::Const | ExpressionKind::Override
786        )
787    }
788
789    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790        self.inner[value]
791    }
792
793    pub fn from_arena(arena: &Arena<Expression>) -> Self {
794        let mut tracker = Self {
795            inner: HandleVec::with_capacity(arena.len()),
796        };
797        for (handle, expr) in arena.iter() {
798            tracker
799                .inner
800                .insert(handle, tracker.type_of_with_expr(expr));
801        }
802        tracker
803    }
804
805    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806        match *expr {
807            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808                ExpressionKind::Const
809            }
810            Expression::Override(_) => ExpressionKind::Override,
811            Expression::Compose { ref components, .. } => {
812                let mut expr_type = ExpressionKind::Const;
813                for component in components {
814                    expr_type = expr_type.max(self.type_of(*component))
815                }
816                expr_type
817            }
818            Expression::Splat { value, .. } => self.type_of(value),
819            Expression::AccessIndex { base, .. } => self.type_of(base),
820            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821            Expression::Swizzle { vector, .. } => self.type_of(vector),
822            Expression::Unary { expr, .. } => self.type_of(expr),
823            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824            Expression::Math {
825                arg,
826                arg1,
827                arg2,
828                arg3,
829                ..
830            } => self
831                .type_of(arg)
832                .max(
833                    arg1.map(|arg| self.type_of(arg))
834                        .unwrap_or(ExpressionKind::Const),
835                )
836                .max(
837                    arg2.map(|arg| self.type_of(arg))
838                        .unwrap_or(ExpressionKind::Const),
839                )
840                .max(
841                    arg3.map(|arg| self.type_of(arg))
842                        .unwrap_or(ExpressionKind::Const),
843                ),
844            Expression::As { expr, .. } => self.type_of(expr),
845            Expression::Select {
846                condition,
847                accept,
848                reject,
849            } => self
850                .type_of(condition)
851                .max(self.type_of(accept))
852                .max(self.type_of(reject)),
853            Expression::Relational { argument, .. } => self.type_of(argument),
854            Expression::ArrayLength(expr) => self.type_of(expr),
855            _ => ExpressionKind::Runtime,
856        }
857    }
858}
859
860#[derive(Clone, Debug, thiserror::Error)]
861#[cfg_attr(test, derive(PartialEq))]
862pub enum ConstantEvaluatorError {
863    #[error("Constants cannot access function arguments")]
864    FunctionArg,
865    #[error("Constants cannot access global variables")]
866    GlobalVariable,
867    #[error("Constants cannot access local variables")]
868    LocalVariable,
869    #[error("Cannot get the array length of a non array type")]
870    InvalidArrayLengthArg,
871    #[error("Constants cannot get the array length of a dynamically sized array")]
872    ArrayLengthDynamic,
873    #[error("Cannot call arrayLength on array sized by override-expression")]
874    ArrayLengthOverridden,
875    #[error("Constants cannot call functions")]
876    Call,
877    #[error("Constants don't support workGroupUniformLoad")]
878    WorkGroupUniformLoadResult,
879    #[error("Constants don't support atomic functions")]
880    Atomic,
881    #[error("Constants don't support derivative functions")]
882    Derivative,
883    #[error("Constants don't support load expressions")]
884    Load,
885    #[error("Constants don't support image expressions")]
886    ImageExpression,
887    #[error("Constants don't support ray query expressions")]
888    RayQueryExpression,
889    #[error("Constants don't support subgroup expressions")]
890    SubgroupExpression,
891    #[error("Cannot access the type")]
892    InvalidAccessBase,
893    #[error("Cannot access at the index")]
894    InvalidAccessIndex,
895    #[error("Cannot access with index of type")]
896    InvalidAccessIndexTy,
897    #[error("Constants don't support array length expressions")]
898    ArrayLength,
899    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
900    InvalidCastArg { from: String, to: String },
901    #[error("Cannot apply the unary op to the argument")]
902    InvalidUnaryOpArg,
903    #[error("Cannot apply the binary op to the arguments")]
904    InvalidBinaryOpArgs,
905    #[error("Cannot apply math function to type")]
906    InvalidMathArg,
907    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
908    InvalidMathArgCount(crate::MathFunction, usize, usize),
909    #[error("Cannot apply relational function to type")]
910    InvalidRelationalArg(RelationalFunction),
911    #[error("value of `low` is greater than `high` for clamp built-in function")]
912    InvalidClamp,
913    #[error("Constructor expects {expected} components, found {actual}")]
914    InvalidVectorComposeLength { expected: usize, actual: usize },
915    #[error("Constructor must only contain vector or scalar arguments")]
916    InvalidVectorComposeComponent,
917    #[error("Splat is defined only on scalar values")]
918    SplatScalarOnly,
919    #[error("Can only swizzle vector constants")]
920    SwizzleVectorOnly,
921    #[error("swizzle component not present in source expression")]
922    SwizzleOutOfBounds,
923    #[error("Type is not constructible")]
924    TypeNotConstructible,
925    #[error("Subexpression(s) are not constant")]
926    SubexpressionsAreNotConstant,
927    #[error("Not implemented as constant expression: {0}")]
928    NotImplemented(String),
929    #[error("{0} operation overflowed")]
930    Overflow(String),
931    #[error(
932        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
933    )]
934    AutomaticConversionLossy {
935        value: String,
936        to_type: &'static str,
937    },
938    #[error("Division by zero")]
939    DivisionByZero,
940    #[error("Remainder by zero")]
941    RemainderByZero,
942    #[error("RHS of shift operation is greater than or equal to 32")]
943    ShiftedMoreThan32Bits,
944    #[error(transparent)]
945    Literal(#[from] crate::valid::LiteralError),
946    #[error("Can't use pipeline-overridable constants in const-expressions")]
947    Override,
948    #[error("Unexpected runtime-expression")]
949    RuntimeExpr,
950    #[error("Unexpected override-expression")]
951    OverrideExpr,
952    #[error("Expected boolean expression for condition argument of `select`, got something else")]
953    SelectScalarConditionNotABool,
954    #[error(
955        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
956        reject,
957        accept
958    )]
959    SelectVecRejectAcceptSizeMismatch {
960        reject: crate::VectorSize,
961        accept: crate::VectorSize,
962    },
963    #[error("Expected boolean vector for condition arg., got something else")]
964    SelectConditionNotAVecBool,
965    #[error(
966        "Expected same number of vector components between condition, accept, and reject args., got something else",
967    )]
968    SelectConditionVecSizeMismatch,
969    #[error(
970        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
971    )]
972    SelectAcceptRejectTypeMismatch,
973}
974
975impl<'a> ConstantEvaluator<'a> {
976    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
977    /// constant expression arena.
978    ///
979    /// Report errors according to WGSL's rules for constant evaluation.
980    pub fn for_wgsl_module(
981        module: &'a mut crate::Module,
982        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
983        layouter: &'a mut crate::proc::Layouter,
984        in_override_ctx: bool,
985    ) -> Self {
986        Self::for_module(
987            Behavior::Wgsl(if in_override_ctx {
988                WgslRestrictions::Override
989            } else {
990                WgslRestrictions::Const(None)
991            }),
992            module,
993            global_expression_kind_tracker,
994            layouter,
995        )
996    }
997
998    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
999    /// constant expression arena.
1000    ///
1001    /// Report errors according to GLSL's rules for constant evaluation.
1002    pub fn for_glsl_module(
1003        module: &'a mut crate::Module,
1004        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1005        layouter: &'a mut crate::proc::Layouter,
1006    ) -> Self {
1007        Self::for_module(
1008            Behavior::Glsl(GlslRestrictions::Const),
1009            module,
1010            global_expression_kind_tracker,
1011            layouter,
1012        )
1013    }
1014
1015    fn for_module(
1016        behavior: Behavior<'a>,
1017        module: &'a mut crate::Module,
1018        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1019        layouter: &'a mut crate::proc::Layouter,
1020    ) -> Self {
1021        Self {
1022            behavior,
1023            types: &mut module.types,
1024            constants: &module.constants,
1025            overrides: &module.overrides,
1026            expressions: &mut module.global_expressions,
1027            expression_kind_tracker: global_expression_kind_tracker,
1028            layouter,
1029        }
1030    }
1031
1032    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1033    /// expression arena.
1034    ///
1035    /// Report errors according to WGSL's rules for constant evaluation.
1036    pub fn for_wgsl_function(
1037        module: &'a mut crate::Module,
1038        expressions: &'a mut Arena<Expression>,
1039        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1040        layouter: &'a mut crate::proc::Layouter,
1041        emitter: &'a mut super::Emitter,
1042        block: &'a mut crate::Block,
1043        is_const: bool,
1044    ) -> Self {
1045        let local_data = FunctionLocalData {
1046            global_expressions: &module.global_expressions,
1047            emitter,
1048            block,
1049        };
1050        Self {
1051            behavior: Behavior::Wgsl(if is_const {
1052                WgslRestrictions::Const(Some(local_data))
1053            } else {
1054                WgslRestrictions::Runtime(local_data)
1055            }),
1056            types: &mut module.types,
1057            constants: &module.constants,
1058            overrides: &module.overrides,
1059            expressions,
1060            expression_kind_tracker: local_expression_kind_tracker,
1061            layouter,
1062        }
1063    }
1064
1065    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1066    /// expression arena.
1067    ///
1068    /// Report errors according to GLSL's rules for constant evaluation.
1069    pub fn for_glsl_function(
1070        module: &'a mut crate::Module,
1071        expressions: &'a mut Arena<Expression>,
1072        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1073        layouter: &'a mut crate::proc::Layouter,
1074        emitter: &'a mut super::Emitter,
1075        block: &'a mut crate::Block,
1076    ) -> Self {
1077        Self {
1078            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1079                global_expressions: &module.global_expressions,
1080                emitter,
1081                block,
1082            })),
1083            types: &mut module.types,
1084            constants: &module.constants,
1085            overrides: &module.overrides,
1086            expressions,
1087            expression_kind_tracker: local_expression_kind_tracker,
1088            layouter,
1089        }
1090    }
1091
1092    pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1093        crate::proc::GlobalCtx {
1094            types: self.types,
1095            constants: self.constants,
1096            overrides: self.overrides,
1097            global_expressions: match self.function_local_data() {
1098                Some(data) => data.global_expressions,
1099                None => self.expressions,
1100            },
1101        }
1102    }
1103
1104    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1105        if !self.expression_kind_tracker.is_const(expr) {
1106            log::debug!("check: SubexpressionsAreNotConstant");
1107            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1108        }
1109        Ok(())
1110    }
1111
1112    fn check_and_get(
1113        &mut self,
1114        expr: Handle<Expression>,
1115    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1116        match self.expressions[expr] {
1117            Expression::Constant(c) => {
1118                // Are we working in a function's expression arena, or the
1119                // module's constant expression arena?
1120                if let Some(function_local_data) = self.function_local_data() {
1121                    // Deep-copy the constant's value into our arena.
1122                    self.copy_from(
1123                        self.constants[c].init,
1124                        function_local_data.global_expressions,
1125                    )
1126                } else {
1127                    // "See through" the constant and use its initializer.
1128                    Ok(self.constants[c].init)
1129                }
1130            }
1131            _ => {
1132                self.check(expr)?;
1133                Ok(expr)
1134            }
1135        }
1136    }
1137
1138    /// Try to evaluate `expr` at compile time.
1139    ///
1140    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1141    /// we can determine its value at compile time, we append an expression
1142    /// representing its value - a tree of [`Literal`], [`Compose`],
1143    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1144    /// `self` contributes to.
1145    ///
1146    /// If `expr`'s value cannot be determined at compile time, and `self` is
1147    /// contributing to some function's expression arena, then append `expr` to
1148    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1149    /// contributing to the module's constant expression arena; since `expr`'s
1150    /// value is not a constant, return an error.
1151    ///
1152    /// We only consider `expr` itself, without recursing into its operands. Its
1153    /// operands must all have been produced by prior calls to
1154    /// `try_eval_and_append`, to ensure that they have already been reduced to
1155    /// an evaluated form if possible.
1156    ///
1157    /// [`Literal`]: Expression::Literal
1158    /// [`Compose`]: Expression::Compose
1159    /// [`ZeroValue`]: Expression::ZeroValue
1160    /// [`Swizzle`]: Expression::Swizzle
1161    pub fn try_eval_and_append(
1162        &mut self,
1163        expr: Expression,
1164        span: Span,
1165    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1166        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1167            ExpressionKind::Const => {
1168                let eval_result = self.try_eval_and_append_impl(&expr, span);
1169                // We should be able to evaluate `Const` expressions at this
1170                // point. If we failed to, then that probably means we just
1171                // haven't implemented that part of constant evaluation. Work
1172                // around this by simply emitting it as a run-time expression.
1173                if self.behavior.has_runtime_restrictions()
1174                    && matches!(
1175                        eval_result,
1176                        Err(ConstantEvaluatorError::NotImplemented(_)
1177                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1178                    )
1179                {
1180                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1181                } else {
1182                    eval_result
1183                }
1184            }
1185            ExpressionKind::Override => match self.behavior {
1186                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1187                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1188                }
1189                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1190                    Err(ConstantEvaluatorError::OverrideExpr)
1191                }
1192                Behavior::Glsl(_) => {
1193                    unreachable!()
1194                }
1195            },
1196            ExpressionKind::Runtime => {
1197                if self.behavior.has_runtime_restrictions() {
1198                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1199                } else {
1200                    Err(ConstantEvaluatorError::RuntimeExpr)
1201                }
1202            }
1203        }
1204    }
1205
1206    /// Is the [`Self::expressions`] arena the global module expression arena?
1207    const fn is_global_arena(&self) -> bool {
1208        matches!(
1209            self.behavior,
1210            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1211                | Behavior::Glsl(GlslRestrictions::Const)
1212        )
1213    }
1214
1215    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1216        match self.behavior {
1217            Behavior::Wgsl(
1218                WgslRestrictions::Runtime(ref function_local_data)
1219                | WgslRestrictions::Const(Some(ref function_local_data)),
1220            )
1221            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1222                Some(function_local_data)
1223            }
1224            _ => None,
1225        }
1226    }
1227
1228    fn try_eval_and_append_impl(
1229        &mut self,
1230        expr: &Expression,
1231        span: Span,
1232    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1233        log::trace!("try_eval_and_append: {expr:?}");
1234        match *expr {
1235            Expression::Constant(c) if self.is_global_arena() => {
1236                // "See through" the constant and use its initializer.
1237                // This is mainly done to avoid having constants pointing to other constants.
1238                Ok(self.constants[c].init)
1239            }
1240            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1241            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1242                self.register_evaluated_expr(expr.clone(), span)
1243            }
1244            Expression::Compose { ty, ref components } => {
1245                let components = components
1246                    .iter()
1247                    .map(|component| self.check_and_get(*component))
1248                    .collect::<Result<Vec<_>, _>>()?;
1249                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1250            }
1251            Expression::Splat { size, value } => {
1252                let value = self.check_and_get(value)?;
1253                self.register_evaluated_expr(Expression::Splat { size, value }, span)
1254            }
1255            Expression::AccessIndex { base, index } => {
1256                let base = self.check_and_get(base)?;
1257
1258                self.access(base, index as usize, span)
1259            }
1260            Expression::Access { base, index } => {
1261                let base = self.check_and_get(base)?;
1262                let index = self.check_and_get(index)?;
1263
1264                self.access(base, self.constant_index(index)?, span)
1265            }
1266            Expression::Swizzle {
1267                size,
1268                vector,
1269                pattern,
1270            } => {
1271                let vector = self.check_and_get(vector)?;
1272
1273                self.swizzle(size, span, vector, pattern)
1274            }
1275            Expression::Unary { expr, op } => {
1276                let expr = self.check_and_get(expr)?;
1277
1278                self.unary_op(op, expr, span)
1279            }
1280            Expression::Binary { left, right, op } => {
1281                let left = self.check_and_get(left)?;
1282                let right = self.check_and_get(right)?;
1283
1284                self.binary_op(op, left, right, span)
1285            }
1286            Expression::Math {
1287                fun,
1288                arg,
1289                arg1,
1290                arg2,
1291                arg3,
1292            } => {
1293                let arg = self.check_and_get(arg)?;
1294                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1295                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1296                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1297
1298                self.math(arg, arg1, arg2, arg3, fun, span)
1299            }
1300            Expression::As {
1301                convert,
1302                expr,
1303                kind,
1304            } => {
1305                let expr = self.check_and_get(expr)?;
1306
1307                match convert {
1308                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1309                    None => Err(ConstantEvaluatorError::NotImplemented(
1310                        "bitcast built-in function".into(),
1311                    )),
1312                }
1313            }
1314            Expression::Select {
1315                reject,
1316                accept,
1317                condition,
1318            } => {
1319                let mut arg = |expr| self.check_and_get(expr);
1320
1321                let reject = arg(reject)?;
1322                let accept = arg(accept)?;
1323                let condition = arg(condition)?;
1324
1325                self.select(reject, accept, condition, span)
1326            }
1327            Expression::Relational { fun, argument } => {
1328                let argument = self.check_and_get(argument)?;
1329                self.relational(fun, argument, span)
1330            }
1331            Expression::ArrayLength(expr) => match self.behavior {
1332                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1333                Behavior::Glsl(_) => {
1334                    let expr = self.check_and_get(expr)?;
1335                    self.array_length(expr, span)
1336                }
1337            },
1338            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1339            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1340            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1341            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1342            Expression::WorkGroupUniformLoadResult { .. } => {
1343                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1344            }
1345            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1346            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1347            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1348            Expression::ImageSample { .. }
1349            | Expression::ImageLoad { .. }
1350            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1351            Expression::RayQueryProceedResult
1352            | Expression::RayQueryGetIntersection { .. }
1353            | Expression::RayQueryVertexPositions { .. } => {
1354                Err(ConstantEvaluatorError::RayQueryExpression)
1355            }
1356            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1357            Expression::SubgroupOperationResult { .. } => {
1358                Err(ConstantEvaluatorError::SubgroupExpression)
1359            }
1360        }
1361    }
1362
1363    /// Splat `value` to `size`, without using [`Splat`] expressions.
1364    ///
1365    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1366    /// build a vector with the given `size` whose components are all
1367    /// `value`.
1368    ///
1369    /// Use `span` as the span of the inserted expressions and
1370    /// resulting types.
1371    ///
1372    /// [`Splat`]: Expression::Splat
1373    /// [`Compose`]: Expression::Compose
1374    /// [`ZeroValue`]: Expression::ZeroValue
1375    fn splat(
1376        &mut self,
1377        value: Handle<Expression>,
1378        size: crate::VectorSize,
1379        span: Span,
1380    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1381        match self.expressions[value] {
1382            Expression::Literal(literal) => {
1383                let scalar = literal.scalar();
1384                let ty = self.types.insert(
1385                    Type {
1386                        name: None,
1387                        inner: TypeInner::Vector { size, scalar },
1388                    },
1389                    span,
1390                );
1391                let expr = Expression::Compose {
1392                    ty,
1393                    components: vec![value; size as usize],
1394                };
1395                self.register_evaluated_expr(expr, span)
1396            }
1397            Expression::ZeroValue(ty) => {
1398                let inner = match self.types[ty].inner {
1399                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1400                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1401                };
1402                let res_ty = self.types.insert(Type { name: None, inner }, span);
1403                let expr = Expression::ZeroValue(res_ty);
1404                self.register_evaluated_expr(expr, span)
1405            }
1406            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1407        }
1408    }
1409
1410    fn swizzle(
1411        &mut self,
1412        size: crate::VectorSize,
1413        span: Span,
1414        src_constant: Handle<Expression>,
1415        pattern: [crate::SwizzleComponent; 4],
1416    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1417        let mut get_dst_ty = |ty| match self.types[ty].inner {
1418            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1419                Type {
1420                    name: None,
1421                    inner: TypeInner::Vector { size, scalar },
1422                },
1423                span,
1424            )),
1425            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1426        };
1427
1428        match self.expressions[src_constant] {
1429            Expression::ZeroValue(ty) => {
1430                let dst_ty = get_dst_ty(ty)?;
1431                let expr = Expression::ZeroValue(dst_ty);
1432                self.register_evaluated_expr(expr, span)
1433            }
1434            Expression::Splat { value, .. } => {
1435                let expr = Expression::Splat { size, value };
1436                self.register_evaluated_expr(expr, span)
1437            }
1438            Expression::Compose { ty, ref components } => {
1439                let dst_ty = get_dst_ty(ty)?;
1440
1441                let mut flattened = [src_constant; 4]; // dummy value
1442                let len =
1443                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1444                        .zip(flattened.iter_mut())
1445                        .map(|(component, elt)| *elt = component)
1446                        .count();
1447                let flattened = &flattened[..len];
1448
1449                let swizzled_components = pattern[..size as usize]
1450                    .iter()
1451                    .map(|&sc| {
1452                        let sc = sc as usize;
1453                        if let Some(elt) = flattened.get(sc) {
1454                            Ok(*elt)
1455                        } else {
1456                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1457                        }
1458                    })
1459                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1460                let expr = Expression::Compose {
1461                    ty: dst_ty,
1462                    components: swizzled_components,
1463                };
1464                self.register_evaluated_expr(expr, span)
1465            }
1466            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1467        }
1468    }
1469
1470    fn math(
1471        &mut self,
1472        arg: Handle<Expression>,
1473        arg1: Option<Handle<Expression>>,
1474        arg2: Option<Handle<Expression>>,
1475        arg3: Option<Handle<Expression>>,
1476        fun: crate::MathFunction,
1477        span: Span,
1478    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1479        let expected = fun.argument_count();
1480        let given = Some(arg)
1481            .into_iter()
1482            .chain(arg1)
1483            .chain(arg2)
1484            .chain(arg3)
1485            .count();
1486        if expected != given {
1487            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1488                fun, expected, given,
1489            ));
1490        }
1491
1492        // NOTE: We try to match the declaration order of `MathFunction` here.
1493        match fun {
1494            // comparison
1495            crate::MathFunction::Abs => {
1496                component_wise_scalar(self, span, [arg], |args| match args {
1497                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1498                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1499                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1500                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1501                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1502                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1503                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1504                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1505                })
1506            }
1507            crate::MathFunction::Min => {
1508                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1509                    Ok([e1.min(e2)])
1510                })
1511            }
1512            crate::MathFunction::Max => {
1513                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1514                    Ok([e1.max(e2)])
1515                })
1516            }
1517            crate::MathFunction::Clamp => {
1518                component_wise_scalar!(
1519                    self,
1520                    span,
1521                    [arg, arg1.unwrap(), arg2.unwrap()],
1522                    |e, low, high| {
1523                        if low > high {
1524                            Err(ConstantEvaluatorError::InvalidClamp)
1525                        } else {
1526                            Ok([e.clamp(low, high)])
1527                        }
1528                    }
1529                )
1530            }
1531            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1532                Float::F16([e]) => Ok(Float::F16(
1533                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1534                )),
1535                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1536                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1537            }),
1538
1539            // trigonometry
1540            crate::MathFunction::Cos => {
1541                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1542            }
1543            crate::MathFunction::Cosh => {
1544                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1545            }
1546            crate::MathFunction::Sin => {
1547                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1548            }
1549            crate::MathFunction::Sinh => {
1550                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1551            }
1552            crate::MathFunction::Tan => {
1553                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1554            }
1555            crate::MathFunction::Tanh => {
1556                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1557            }
1558            crate::MathFunction::Acos => {
1559                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1560            }
1561            crate::MathFunction::Asin => {
1562                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1563            }
1564            crate::MathFunction::Atan => {
1565                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1566            }
1567            crate::MathFunction::Atan2 => {
1568                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1569                    Ok([y.atan2(x)])
1570                })
1571            }
1572            crate::MathFunction::Asinh => {
1573                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1574            }
1575            crate::MathFunction::Acosh => {
1576                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1577            }
1578            crate::MathFunction::Atanh => {
1579                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1580            }
1581            crate::MathFunction::Radians => {
1582                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1583            }
1584            crate::MathFunction::Degrees => {
1585                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1586            }
1587
1588            // decomposition
1589            crate::MathFunction::Ceil => {
1590                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1591            }
1592            crate::MathFunction::Floor => {
1593                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1594            }
1595            crate::MathFunction::Round => {
1596                component_wise_float(self, span, [arg], |e| match e {
1597                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1598                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1599                    Float::F16([e]) => {
1600                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1601                        //
1602                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1603                        // which has licensing compatible with ours. See also
1604                        // <https://github.com/rust-lang/rust/issues/96710>.
1605                        //
1606                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1607                        fn round_ties_even(x: f64) -> f64 {
1608                            let i = x as i64;
1609                            let f = (x - i as f64).abs();
1610                            if f == 0.5 {
1611                                if i & 1 == 1 {
1612                                    // -1.5, 1.5, 3.5, ...
1613                                    (x.abs() + 0.5).copysign(x)
1614                                } else {
1615                                    (x.abs() - 0.5).copysign(x)
1616                                }
1617                            } else {
1618                                x.round()
1619                            }
1620                        }
1621
1622                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1623                    }
1624                })
1625            }
1626            crate::MathFunction::Fract => {
1627                component_wise_float!(self, span, [arg], |e| {
1628                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1629                    // here.
1630                    Ok([e - e.floor()])
1631                })
1632            }
1633            crate::MathFunction::Trunc => {
1634                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1635            }
1636
1637            // exponent
1638            crate::MathFunction::Exp => {
1639                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1640            }
1641            crate::MathFunction::Exp2 => {
1642                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1643            }
1644            crate::MathFunction::Log => {
1645                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1646            }
1647            crate::MathFunction::Log2 => {
1648                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1649            }
1650            crate::MathFunction::Pow => {
1651                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1652                    Ok([e1.powf(e2)])
1653                })
1654            }
1655
1656            // computational
1657            crate::MathFunction::Sign => {
1658                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1659            }
1660            crate::MathFunction::Fma => {
1661                component_wise_float!(
1662                    self,
1663                    span,
1664                    [arg, arg1.unwrap(), arg2.unwrap()],
1665                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1666                )
1667            }
1668            crate::MathFunction::Step => {
1669                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1670                    Float::Abstract([edge, x]) => {
1671                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1672                    }
1673                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1674                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1675                        f16::one()
1676                    } else {
1677                        f16::zero()
1678                    }])),
1679                })
1680            }
1681            crate::MathFunction::Sqrt => {
1682                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1683            }
1684            crate::MathFunction::InverseSqrt => {
1685                component_wise_float(self, span, [arg], |e| match e {
1686                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1687                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1688                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1689                })
1690            }
1691
1692            // bits
1693            crate::MathFunction::CountTrailingZeros => {
1694                component_wise_concrete_int!(self, span, [arg], |e| {
1695                    #[allow(clippy::useless_conversion)]
1696                    Ok([e
1697                        .trailing_zeros()
1698                        .try_into()
1699                        .expect("bit count overflowed 32 bits, somehow!?")])
1700                })
1701            }
1702            crate::MathFunction::CountLeadingZeros => {
1703                component_wise_concrete_int!(self, span, [arg], |e| {
1704                    #[allow(clippy::useless_conversion)]
1705                    Ok([e
1706                        .leading_zeros()
1707                        .try_into()
1708                        .expect("bit count overflowed 32 bits, somehow!?")])
1709                })
1710            }
1711            crate::MathFunction::CountOneBits => {
1712                component_wise_concrete_int!(self, span, [arg], |e| {
1713                    #[allow(clippy::useless_conversion)]
1714                    Ok([e
1715                        .count_ones()
1716                        .try_into()
1717                        .expect("bit count overflowed 32 bits, somehow!?")])
1718                })
1719            }
1720            crate::MathFunction::ReverseBits => {
1721                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1722            }
1723            crate::MathFunction::FirstTrailingBit => {
1724                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1725            }
1726            crate::MathFunction::FirstLeadingBit => {
1727                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1728            }
1729
1730            // vector
1731            crate::MathFunction::Dot4I8Packed => {
1732                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1733            }
1734            crate::MathFunction::Dot4U8Packed => {
1735                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1736            }
1737            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1738            crate::MathFunction::Dot => {
1739                // https://www.w3.org/TR/WGSL/#dot-builtin
1740                let e1 = self.extract_vec(arg, false)?;
1741                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1742                if e1.len() != e2.len() {
1743                    return Err(ConstantEvaluatorError::InvalidMathArg);
1744                }
1745
1746                fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1747                where
1748                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1749                {
1750                    a.iter()
1751                        .zip(b.iter())
1752                        .map(|(&aa, bb)| aa.checked_mul(bb))
1753                        .try_fold(P::zero(), |acc, x| {
1754                            if let Some(x) = x {
1755                                acc.checked_add(&x)
1756                            } else {
1757                                None
1758                            }
1759                        })
1760                        .ok_or(ConstantEvaluatorError::Overflow(
1761                            "in dot built-in".to_string(),
1762                        ))
1763                }
1764
1765                let result = match_literal_vector!(match (e1, e2) => Literal {
1766                    Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1767                    Integer => |e1, e2| { int_dot(e1, e2)? },
1768                })?;
1769                self.register_evaluated_expr(Expression::Literal(result), span)
1770            }
1771            crate::MathFunction::Length => {
1772                // https://www.w3.org/TR/WGSL/#length-builtin
1773                let e1 = self.extract_vec(arg, true)?;
1774
1775                fn float_length<F>(e: &[F]) -> F
1776                where
1777                    F: core::ops::Mul<F>,
1778                    F: num_traits::Float + iter::Sum,
1779                {
1780                    e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1781                }
1782
1783                let result = match_literal_vector!(match e1 => Literal {
1784                    Float => |e1| { float_length(e1) },
1785                })?;
1786                self.register_evaluated_expr(Expression::Literal(result), span)
1787            }
1788            crate::MathFunction::Distance => {
1789                // https://www.w3.org/TR/WGSL/#distance-builtin
1790                let e1 = self.extract_vec(arg, true)?;
1791                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1792                if e1.len() != e2.len() {
1793                    return Err(ConstantEvaluatorError::InvalidMathArg);
1794                }
1795
1796                fn float_distance<F>(a: &[F], b: &[F]) -> F
1797                where
1798                    F: core::ops::Mul<F>,
1799                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1800                {
1801                    a.iter()
1802                        .zip(b.iter())
1803                        .map(|(&aa, &bb)| aa - bb)
1804                        .map(|ei| ei * ei)
1805                        .sum::<F>()
1806                        .sqrt()
1807                }
1808                let result = match_literal_vector!(match (e1, e2) => Literal {
1809                    Float => |e1, e2| { float_distance(e1, e2) },
1810                })?;
1811                self.register_evaluated_expr(Expression::Literal(result), span)
1812            }
1813            crate::MathFunction::Normalize => {
1814                // https://www.w3.org/TR/WGSL/#normalize-builtin
1815                let e1 = self.extract_vec(arg, true)?;
1816
1817                fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1818                where
1819                    F: core::ops::Mul<F>,
1820                    F: num_traits::Float + iter::Sum,
1821                {
1822                    let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1823                    e.iter().map(|&ei| ei / len).collect()
1824                }
1825
1826                let result = match_literal_vector!(match e1 => LiteralVector {
1827                    Float => |e1| { float_normalize(e1) },
1828                })?;
1829                result.register_as_evaluated_expr(self, span)
1830            }
1831
1832            // unimplemented
1833            crate::MathFunction::Modf
1834            | crate::MathFunction::Frexp
1835            | crate::MathFunction::Ldexp
1836            | crate::MathFunction::Outer
1837            | crate::MathFunction::FaceForward
1838            | crate::MathFunction::Reflect
1839            | crate::MathFunction::Refract
1840            | crate::MathFunction::Mix
1841            | crate::MathFunction::SmoothStep
1842            | crate::MathFunction::Inverse
1843            | crate::MathFunction::Transpose
1844            | crate::MathFunction::Determinant
1845            | crate::MathFunction::QuantizeToF16
1846            | crate::MathFunction::ExtractBits
1847            | crate::MathFunction::InsertBits
1848            | crate::MathFunction::Pack4x8snorm
1849            | crate::MathFunction::Pack4x8unorm
1850            | crate::MathFunction::Pack2x16snorm
1851            | crate::MathFunction::Pack2x16unorm
1852            | crate::MathFunction::Pack2x16float
1853            | crate::MathFunction::Pack4xI8
1854            | crate::MathFunction::Pack4xU8
1855            | crate::MathFunction::Pack4xI8Clamp
1856            | crate::MathFunction::Pack4xU8Clamp
1857            | crate::MathFunction::Unpack4x8snorm
1858            | crate::MathFunction::Unpack4x8unorm
1859            | crate::MathFunction::Unpack2x16snorm
1860            | crate::MathFunction::Unpack2x16unorm
1861            | crate::MathFunction::Unpack2x16float
1862            | crate::MathFunction::Unpack4xI8
1863            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1864                format!("{fun:?} built-in function"),
1865            )),
1866        }
1867    }
1868
1869    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1870    fn packed_dot_product(
1871        &mut self,
1872        a: Handle<Expression>,
1873        b: Handle<Expression>,
1874        span: Span,
1875        signed: bool,
1876    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1877        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1878            return Err(ConstantEvaluatorError::InvalidMathArg);
1879        };
1880        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1881            return Err(ConstantEvaluatorError::InvalidMathArg);
1882        };
1883
1884        let result = if signed {
1885            Literal::I32(
1886                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1887                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1888                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1889                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1890            )
1891        } else {
1892            Literal::U32(
1893                (a & 0xFF) * (b & 0xFF)
1894                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1895                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1896                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1897            )
1898        };
1899
1900        self.register_evaluated_expr(Expression::Literal(result), span)
1901    }
1902
1903    /// Vector cross product.
1904    fn cross_product(
1905        &mut self,
1906        a: Handle<Expression>,
1907        b: Handle<Expression>,
1908        span: Span,
1909    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1910        use Literal as Li;
1911
1912        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1913        let (b, _) = self.extract_vec_with_size::<3>(b)?;
1914
1915        let product = match (a, b) {
1916            (
1917                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1918                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1919            ) => {
1920                // `cross` has no overload for AbstractInt, so AbstractInt
1921                // arguments are automatically converted to AbstractFloat. Since
1922                // `f64` has a much wider range than `i64`, there's no danger of
1923                // overflow here.
1924                let p = cross_product(
1925                    [a0 as f64, a1 as f64, a2 as f64],
1926                    [b0 as f64, b1 as f64, b2 as f64],
1927                );
1928                [
1929                    Li::AbstractFloat(p[0]),
1930                    Li::AbstractFloat(p[1]),
1931                    Li::AbstractFloat(p[2]),
1932                ]
1933            }
1934            (
1935                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1936                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1937            ) => {
1938                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1939                [
1940                    Li::AbstractFloat(p[0]),
1941                    Li::AbstractFloat(p[1]),
1942                    Li::AbstractFloat(p[2]),
1943                ]
1944            }
1945            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1946                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1947                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1948            }
1949            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1950                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1951                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1952            }
1953            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1954                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1955                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1956            }
1957            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1958        };
1959
1960        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1961        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1962        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1963
1964        self.register_evaluated_expr(
1965            Expression::Compose {
1966                ty,
1967                components: vec![p0, p1, p2],
1968            },
1969            span,
1970        )
1971    }
1972
1973    /// Extract the values of a `vecN` from `expr`.
1974    ///
1975    /// Return the value of `expr`, whose type is `vecN<S>` for some
1976    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1977    /// values.
1978    ///
1979    /// Also return the type handle from the `Compose` expression.
1980    fn extract_vec_with_size<const N: usize>(
1981        &mut self,
1982        expr: Handle<Expression>,
1983    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1984        let span = self.expressions.get_span(expr);
1985        let expr = self.eval_zero_value_and_splat(expr, span)?;
1986        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1987            return Err(ConstantEvaluatorError::InvalidMathArg);
1988        };
1989
1990        let mut value = [Literal::Bool(false); N];
1991        for (component, elt) in
1992            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1993                .zip(value.iter_mut())
1994        {
1995            let Expression::Literal(literal) = self.expressions[component] else {
1996                return Err(ConstantEvaluatorError::InvalidMathArg);
1997            };
1998            *elt = literal;
1999        }
2000
2001        Ok((value, ty))
2002    }
2003
2004    /// Extract the values of a `vecN` from `expr`.
2005    ///
2006    /// Return the value of `expr`, whose type is `vecN<S>` for some
2007    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2008    /// values.
2009    ///
2010    /// Also return the type handle from the `Compose` expression.
2011    fn extract_vec(
2012        &mut self,
2013        expr: Handle<Expression>,
2014        allow_single: bool,
2015    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2016        let span = self.expressions.get_span(expr);
2017        let expr = self.eval_zero_value_and_splat(expr, span)?;
2018
2019        match self.expressions[expr] {
2020            Expression::Literal(literal) if allow_single => {
2021                Ok(LiteralVector::from_literal(literal))
2022            }
2023            Expression::Compose { ty, ref components } => {
2024                let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2025                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2026                        .map(|expr| match self.expressions[expr] {
2027                            Expression::Literal(l) => Ok(l),
2028                            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2029                        })
2030                        .collect::<Result<_, ConstantEvaluatorError>>()?;
2031                LiteralVector::from_literal_vec(components)
2032            }
2033            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2034        }
2035    }
2036
2037    fn array_length(
2038        &mut self,
2039        array: Handle<Expression>,
2040        span: Span,
2041    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2042        match self.expressions[array] {
2043            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2044                match self.types[ty].inner {
2045                    TypeInner::Array { size, .. } => match size {
2046                        ArraySize::Constant(len) => {
2047                            let expr = Expression::Literal(Literal::U32(len.get()));
2048                            self.register_evaluated_expr(expr, span)
2049                        }
2050                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2051                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2052                    },
2053                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2054                }
2055            }
2056            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2057        }
2058    }
2059
2060    fn access(
2061        &mut self,
2062        base: Handle<Expression>,
2063        index: usize,
2064        span: Span,
2065    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2066        match self.expressions[base] {
2067            Expression::ZeroValue(ty) => {
2068                let ty_inner = &self.types[ty].inner;
2069                let components = ty_inner
2070                    .components()
2071                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2072
2073                if index >= components as usize {
2074                    Err(ConstantEvaluatorError::InvalidAccessBase)
2075                } else {
2076                    let ty_res = ty_inner
2077                        .component_type(index)
2078                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2079                    let ty = match ty_res {
2080                        crate::proc::TypeResolution::Handle(ty) => ty,
2081                        crate::proc::TypeResolution::Value(inner) => {
2082                            self.types.insert(Type { name: None, inner }, span)
2083                        }
2084                    };
2085                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2086                }
2087            }
2088            Expression::Splat { size, value } => {
2089                if index >= size as usize {
2090                    Err(ConstantEvaluatorError::InvalidAccessBase)
2091                } else {
2092                    Ok(value)
2093                }
2094            }
2095            Expression::Compose { ty, ref components } => {
2096                let _ = self.types[ty]
2097                    .inner
2098                    .components()
2099                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2100
2101                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2102                    .nth(index)
2103                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2104            }
2105            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2106        }
2107    }
2108
2109    fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
2110        match self.expressions[expr] {
2111            Expression::ZeroValue(ty)
2112                if matches!(
2113                    self.types[ty].inner,
2114                    TypeInner::Scalar(crate::Scalar {
2115                        kind: ScalarKind::Uint,
2116                        ..
2117                    })
2118                ) =>
2119            {
2120                Ok(0)
2121            }
2122            Expression::Literal(Literal::U32(index)) => Ok(index as usize),
2123            _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
2124        }
2125    }
2126
2127    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2128    ///
2129    /// [`ZeroValue`]: Expression::ZeroValue
2130    /// [`Splat`]: Expression::Splat
2131    /// [`Literal`]: Expression::Literal
2132    /// [`Compose`]: Expression::Compose
2133    fn eval_zero_value_and_splat(
2134        &mut self,
2135        mut expr: Handle<Expression>,
2136        span: Span,
2137    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2138        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2139        // each of its components.
2140        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2141            let components = components
2142                .clone()
2143                .iter()
2144                .map(|component| self.eval_zero_value_and_splat(*component, span))
2145                .collect::<Result<_, _>>()?;
2146            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2147        }
2148
2149        // The result of the splat() for a Splat of a scalar ZeroValue is a
2150        // vector ZeroValue, so we must call eval_zero_value_impl() after
2151        // splat() in order to ensure we have no ZeroValues remaining.
2152        if let Expression::Splat { size, value } = self.expressions[expr] {
2153            expr = self.splat(value, size, span)?;
2154        }
2155        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2156            expr = self.eval_zero_value_impl(ty, span)?;
2157        }
2158        Ok(expr)
2159    }
2160
2161    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2162    ///
2163    /// [`ZeroValue`]: Expression::ZeroValue
2164    /// [`Literal`]: Expression::Literal
2165    /// [`Compose`]: Expression::Compose
2166    fn eval_zero_value(
2167        &mut self,
2168        expr: Handle<Expression>,
2169        span: Span,
2170    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2171        match self.expressions[expr] {
2172            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2173            _ => Ok(expr),
2174        }
2175    }
2176
2177    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2178    ///
2179    /// [`ZeroValue`]: Expression::ZeroValue
2180    /// [`Literal`]: Expression::Literal
2181    /// [`Compose`]: Expression::Compose
2182    fn eval_zero_value_impl(
2183        &mut self,
2184        ty: Handle<Type>,
2185        span: Span,
2186    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2187        match self.types[ty].inner {
2188            TypeInner::Scalar(scalar) => {
2189                let expr = Expression::Literal(
2190                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2191                );
2192                self.register_evaluated_expr(expr, span)
2193            }
2194            TypeInner::Vector { size, scalar } => {
2195                let scalar_ty = self.types.insert(
2196                    Type {
2197                        name: None,
2198                        inner: TypeInner::Scalar(scalar),
2199                    },
2200                    span,
2201                );
2202                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2203                let expr = Expression::Compose {
2204                    ty,
2205                    components: vec![el; size as usize],
2206                };
2207                self.register_evaluated_expr(expr, span)
2208            }
2209            TypeInner::Matrix {
2210                columns,
2211                rows,
2212                scalar,
2213            } => {
2214                let vec_ty = self.types.insert(
2215                    Type {
2216                        name: None,
2217                        inner: TypeInner::Vector { size: rows, scalar },
2218                    },
2219                    span,
2220                );
2221                let el = self.eval_zero_value_impl(vec_ty, span)?;
2222                let expr = Expression::Compose {
2223                    ty,
2224                    components: vec![el; columns as usize],
2225                };
2226                self.register_evaluated_expr(expr, span)
2227            }
2228            TypeInner::Array {
2229                base,
2230                size: ArraySize::Constant(size),
2231                ..
2232            } => {
2233                let el = self.eval_zero_value_impl(base, span)?;
2234                let expr = Expression::Compose {
2235                    ty,
2236                    components: vec![el; size.get() as usize],
2237                };
2238                self.register_evaluated_expr(expr, span)
2239            }
2240            TypeInner::Struct { ref members, .. } => {
2241                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2242                let mut components = Vec::with_capacity(members.len());
2243                for ty in types {
2244                    components.push(self.eval_zero_value_impl(ty, span)?);
2245                }
2246                let expr = Expression::Compose { ty, components };
2247                self.register_evaluated_expr(expr, span)
2248            }
2249            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2250        }
2251    }
2252
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// `expr` is first passed through `eval_zero_value`; the resulting
    /// expression is then handled according to its kind:
    ///
    /// - `Literal`: converted directly to a literal of the `target` scalar.
    /// - `Compose` of a vector or matrix type: every component is cast
    ///   recursively, and the result is composed with a type that has the
    ///   same shape but `target` as its scalar.
    /// - `Splat`: the splatted value is cast and wrapped in a `Splat` of the
    ///   same size.
    ///
    /// # Errors
    ///
    /// Returns [`ConstantEvaluatorError::InvalidCastArg`] when the expression
    /// kind, the composed type, or the (source literal, `target`) pairing is
    /// not handled below. Errors from `try_from_abstract`, from the recursive
    /// casts and from registering the result are propagated unchanged.
    ///
    /// # Panics
    ///
    /// Several `f16` conversions call `unwrap`; per the inline comments the
    /// `f16` to integer ones can only fail for NaN or infinite inputs.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error. The source is described by the
        // expression handle and the expression's `Debug` form; the target by
        // its WGSL spelling when the `wgsl-in` feature is enabled, and by its
        // `Debug` form otherwise.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        // Provides the `min_float`/`max_float` bounds used to clamp floats
        // before the float-to-integer `as` conversions below.
        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // One arm per target scalar; each inner `match` lists how every
                // source literal kind is converted, or rejects it.
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        // 64-bit sources are not converted to a 32-bit target.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        // 64-bit sources are not converted to a 32-bit target.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    // Unlike the 32-bit targets, `F16` accepts every concrete
                    // source kind, including the 64-bit ones.
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        // 64-bit sources are not converted to a 32-bit target.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        // 64-bit integer sources are rejected for `F64`.
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    // Numeric sources become `true` exactly when they are
                    // non-zero.
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract sources can be cast to an abstract target.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    // Any other target scalar is refused.
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices are cast component-wise: keep the
                // shape and swap in the target scalar.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Work on a copy of the component handles: `src_components`
                // borrows from `self.expressions`, while the recursive calls
                // need `&mut self`.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // The inner value is cast using its own span rather than `span`.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
2447
2448    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2449    ///
2450    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2451    /// matrix, or nested arrays of such.
2452    ///
2453    /// This is basically the same as the [`cast`] method, except that that
2454    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2455    ///
2456    /// Treat `span` as the location of the resulting expression.
2457    ///
2458    /// [`cast`]: ConstantEvaluator::cast
2459    /// [`As`]: crate::Expression::As
2460    pub fn cast_array(
2461        &mut self,
2462        expr: Handle<Expression>,
2463        target: crate::Scalar,
2464        span: Span,
2465    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2466        let expr = self.check_and_get(expr)?;
2467
2468        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2469            return self.cast(expr, target, span);
2470        };
2471
2472        let TypeInner::Array {
2473            base: _,
2474            size,
2475            stride: _,
2476        } = self.types[ty].inner
2477        else {
2478            return self.cast(expr, target, span);
2479        };
2480
2481        let mut components = components.clone();
2482        for component in &mut components {
2483            *component = self.cast_array(*component, target, span)?;
2484        }
2485
2486        let first = components.first().unwrap();
2487        let new_base = match self.resolve_type(*first)? {
2488            crate::proc::TypeResolution::Handle(ty) => ty,
2489            crate::proc::TypeResolution::Value(inner) => {
2490                self.types.insert(Type { name: None, inner }, span)
2491            }
2492        };
2493        let mut layouter = core::mem::take(self.layouter);
2494        layouter.update(self.to_ctx()).unwrap();
2495        *self.layouter = layouter;
2496
2497        let new_base_stride = self.layouter[new_base].to_stride();
2498        let new_array_ty = self.types.insert(
2499            Type {
2500                name: None,
2501                inner: TypeInner::Array {
2502                    base: new_base,
2503                    size,
2504                    stride: new_base_stride,
2505                },
2506            },
2507            span,
2508        );
2509
2510        let compose = Expression::Compose {
2511            ty: new_array_ty,
2512            components,
2513        };
2514        self.register_evaluated_expr(compose, span)
2515    }
2516
2517    fn unary_op(
2518        &mut self,
2519        op: UnaryOperator,
2520        expr: Handle<Expression>,
2521        span: Span,
2522    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2523        let expr = self.eval_zero_value_and_splat(expr, span)?;
2524
2525        let expr = match self.expressions[expr] {
2526            Expression::Literal(value) => Expression::Literal(match op {
2527                UnaryOperator::Negate => match value {
2528                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2529                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2530                    Literal::F32(v) => Literal::F32(-v),
2531                    Literal::F16(v) => Literal::F16(-v),
2532                    Literal::F64(v) => Literal::F64(-v),
2533                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2534                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2535                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2536                },
2537                UnaryOperator::LogicalNot => match value {
2538                    Literal::Bool(v) => Literal::Bool(!v),
2539                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2540                },
2541                UnaryOperator::BitwiseNot => match value {
2542                    Literal::I32(v) => Literal::I32(!v),
2543                    Literal::I64(v) => Literal::I64(!v),
2544                    Literal::U32(v) => Literal::U32(!v),
2545                    Literal::U64(v) => Literal::U64(!v),
2546                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2547                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2548                },
2549            }),
2550            Expression::Compose {
2551                ty,
2552                components: ref src_components,
2553            } => {
2554                match self.types[ty].inner {
2555                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2556                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2557                }
2558
2559                let mut components = src_components.clone();
2560                for component in &mut components {
2561                    *component = self.unary_op(op, *component, span)?;
2562                }
2563
2564                Expression::Compose { ty, components }
2565            }
2566            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2567        };
2568
2569        self.register_evaluated_expr(expr, span)
2570    }
2571
    /// Evaluate the binary operation `op` on the constant operands `left` and
    /// `right`.
    ///
    /// Both operands are first passed through `eval_zero_value_and_splat`.
    /// The resulting pair is then handled by shape:
    ///
    /// - literal and literal: computed directly;
    /// - `Compose` and literal (in either order): the operation is applied
    ///   between each component and the literal, and the result reuses the
    ///   `Compose`'s type;
    /// - `Compose` and `Compose`: supported only for two vectors of equal
    ///   size, and delegated to `binary_op_vector`.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// # Errors
    ///
    /// - [`ConstantEvaluatorError::InvalidBinaryOpArgs`] for operand shapes,
    ///   literal kind pairings, or operators not handled below.
    /// - [`ConstantEvaluatorError::DivisionByZero`] and
    ///   [`ConstantEvaluatorError::RemainderByZero`] for integer division or
    ///   remainder by zero.
    /// - [`ConstantEvaluatorError::Overflow`] for overflowing `u32` and
    ///   abstract-int arithmetic and for overflowing left shifts.
    /// - [`ConstantEvaluatorError::ShiftedMoreThan32Bits`] when a 32-bit shift
    ///   is rejected by `checked_shl`/`checked_shr`.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                let literal = match op {
                    // Comparisons use `Literal`'s own comparison operators and
                    // always produce a `Bool`, whatever the operand kinds.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    // Every other operator is dispatched on the pair of
                    // operand kinds.
                    _ => match (left_value, right_value) {
                        // `i32` arithmetic wraps on overflow; only division
                        // and remainder by zero are reported as errors.
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::DivisionByZero);
                                } else {
                                    a.wrapping_div(b)
                                }
                            }
                            BinaryOperator::Modulo => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::RemainderByZero);
                                } else {
                                    a.wrapping_rem(b)
                                }
                            }
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Shifts of an `i32` by a `u32` amount.
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // `leading_zeros` of `a` (or of `!a` when `a`
                                // is negative) counts the leading bits equal
                                // to the sign bit. Shifting by that many bits
                                // or more is reported as overflow.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // `u32` arithmetic is checked: overflow is an error
                        // rather than wrapping.
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // Written as a multiplication by `1 << b` so that
                            // shifting set bits out of the value is caught as
                            // an overflow by `checked_mul`.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Plain `f32` arithmetic, with no checks at this point.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Shifts of an abstract int by a `u32` amount.
                        (Literal::AbstractInt(a), Literal::U32(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::ShiftLeft => {
                                    // Same sign-bit run check as the `i32`
                                    // case. `leading_zeros` is at most 64, so
                                    // any `b >= 64` is rejected here and the
                                    // `unwrap_or(0)` fallback below is never
                                    // taken.
                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                        return Err(ConstantEvaluatorError::Overflow(
                                            "<<".to_string(),
                                        ));
                                    }
                                    a.checked_shl(b).unwrap_or(0)
                                }
                                // `checked_shr` is `None` for `b >= 64`, which
                                // yields 0 here.
                                // NOTE(review): for negative `a` an arithmetic
                                // shift by 64 or more bits would give -1, not
                                // 0 - confirm this is the intended result.
                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        // Plain `f16` arithmetic, with no checks at this point.
                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Abstract-int arithmetic is checked; overflow is an
                        // error.
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                // `checked_div`/`checked_rem` fail both for a
                                // zero divisor and for overflow; the closures
                                // tell the two cases apart.
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        // Plain abstract-float arithmetic, with no checks at
                        // this point.
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        // Booleans support only the logical operators.
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Any other pairing of literal kinds is rejected.
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // `Compose` on the left, literal on the right: combine each
            // component with the literal.
            // NOTE(review): the result reuses the operand's type `ty` even for
            // comparison operators, whose component results are `Bool` -
            // presumably such inputs never reach this arm; confirm.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                // Copy the handles so the borrow of `self.expressions` ends
                // before the recursive `&mut self` calls.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            // Literal on the left, `Compose` on the right: the mirror image of
            // the arm above, keeping the literal as the left operand.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // We have to make a copy of the component lists, because the
                // call to `binary_op_vector` needs `&mut self`, but `self` owns
                // the component lists.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
                // make a reasonable guess of the capacity we'll need.
                let mut flattened = Vec::with_capacity(left_components.len());
                // Pair up corresponding components; `zip` stops at the end of
                // the shorter side.
                flattened.extend(left_flattened.zip(right_flattened));

                // Only two vectors of the same size are handled here.
                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
2816
2817    fn binary_op_vector(
2818        &mut self,
2819        op: BinaryOperator,
2820        size: crate::VectorSize,
2821        components: &[(Handle<Expression>, Handle<Expression>)],
2822        left_ty: Handle<Type>,
2823        span: Span,
2824    ) -> Result<Expression, ConstantEvaluatorError> {
2825        let ty = match op {
2826            // Relational operators produce vectors of booleans.
2827            BinaryOperator::Equal
2828            | BinaryOperator::NotEqual
2829            | BinaryOperator::Less
2830            | BinaryOperator::LessEqual
2831            | BinaryOperator::Greater
2832            | BinaryOperator::GreaterEqual => self.types.insert(
2833                Type {
2834                    name: None,
2835                    inner: TypeInner::Vector {
2836                        size,
2837                        scalar: crate::Scalar::BOOL,
2838                    },
2839                },
2840                span,
2841            ),
2842
2843            // Other operators produce the same type as their left
2844            // operand.
2845            BinaryOperator::Add
2846            | BinaryOperator::Subtract
2847            | BinaryOperator::Multiply
2848            | BinaryOperator::Divide
2849            | BinaryOperator::Modulo
2850            | BinaryOperator::And
2851            | BinaryOperator::ExclusiveOr
2852            | BinaryOperator::InclusiveOr
2853            | BinaryOperator::ShiftLeft
2854            | BinaryOperator::ShiftRight => left_ty,
2855
2856            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2857                // Not supported on vectors
2858                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2859            }
2860        };
2861
2862        let components = components
2863            .iter()
2864            .map(|&(left, right)| self.binary_op(op, left, right, span))
2865            .collect::<Result<Vec<_>, _>>()?;
2866
2867        Ok(Expression::Compose { ty, components })
2868    }
2869
2870    fn relational(
2871        &mut self,
2872        fun: RelationalFunction,
2873        arg: Handle<Expression>,
2874        span: Span,
2875    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2876        let arg = self.eval_zero_value_and_splat(arg, span)?;
2877        match fun {
2878            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2879                Expression::Literal(Literal::Bool(_)) => Ok(arg),
2880                Expression::Compose { ty, ref components }
2881                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2882                {
2883                    let components =
2884                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2885                            .map(|component| match self.expressions[component] {
2886                                Expression::Literal(Literal::Bool(val)) => Ok(val),
2887                                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2888                            })
2889                            .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2890                    let result = match fun {
2891                        RelationalFunction::All => components.iter().all(|c| *c),
2892                        RelationalFunction::Any => components.iter().any(|c| *c),
2893                        _ => unreachable!(),
2894                    };
2895                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2896                }
2897                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2898            },
2899            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2900                "{fun:?} built-in function"
2901            ))),
2902        }
2903    }
2904
2905    /// Deep copy `expr` from `expressions` into `self.expressions`.
2906    ///
2907    /// Return the root of the new copy.
2908    ///
2909    /// This is used when we're evaluating expressions in a function's
2910    /// expression arena that refer to a constant: we need to copy the
2911    /// constant's value into the function's arena so we can operate on it.
2912    fn copy_from(
2913        &mut self,
2914        expr: Handle<Expression>,
2915        expressions: &Arena<Expression>,
2916    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2917        let span = expressions.get_span(expr);
2918        match expressions[expr] {
2919            ref expr @ (Expression::Literal(_)
2920            | Expression::Constant(_)
2921            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2922            Expression::Compose { ty, ref components } => {
2923                let mut components = components.clone();
2924                for component in &mut components {
2925                    *component = self.copy_from(*component, expressions)?;
2926                }
2927                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2928            }
2929            Expression::Splat { size, value } => {
2930                let value = self.copy_from(value, expressions)?;
2931                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2932            }
2933            _ => {
2934                log::debug!("copy_from: SubexpressionsAreNotConstant");
2935                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2936            }
2937        }
2938    }
2939
2940    /// Returns the total number of components, after flattening, of a vector compose expression.
2941    fn vector_compose_flattened_size(
2942        &self,
2943        components: &[Handle<Expression>],
2944    ) -> Result<usize, ConstantEvaluatorError> {
2945        components
2946            .iter()
2947            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2948                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2949                    TypeInner::Scalar(_) => 1,
2950                    // We trust that the vector size of `component` is correct,
2951                    // as it will have already been validated when `component`
2952                    // was registered.
2953                    TypeInner::Vector { size, .. } => size as usize,
2954                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2955                };
2956                Ok(acc + size)
2957            })
2958    }
2959
2960    fn register_evaluated_expr(
2961        &mut self,
2962        expr: Expression,
2963        span: Span,
2964    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2965        // It suffices to only check_literal_value() for `Literal` expressions,
2966        // since we only register one expression at a time, `Compose`
2967        // expressions can only refer to other expressions, and `ZeroValue`
2968        // expressions are always okay.
2969        if let Expression::Literal(literal) = expr {
2970            crate::valid::check_literal_value(literal)?;
2971        }
2972
2973        // Ensure vector composes contain the correct number of components. We
2974        // do so here when each compose is registered to avoid having to deal
2975        // with the mess each time the compose is used in another expression.
2976        if let Expression::Compose { ty, ref components } = expr {
2977            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2978                let expected = size as usize;
2979                let actual = self.vector_compose_flattened_size(components)?;
2980                if expected != actual {
2981                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2982                        expected,
2983                        actual,
2984                    });
2985                }
2986            }
2987        }
2988
2989        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2990    }
2991
2992    fn append_expr(
2993        &mut self,
2994        expr: Expression,
2995        span: Span,
2996        expr_type: ExpressionKind,
2997    ) -> Handle<Expression> {
2998        let h = match self.behavior {
2999            Behavior::Wgsl(
3000                WgslRestrictions::Runtime(ref mut function_local_data)
3001                | WgslRestrictions::Const(Some(ref mut function_local_data)),
3002            )
3003            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3004                let is_running = function_local_data.emitter.is_running();
3005                let needs_pre_emit = expr.needs_pre_emit();
3006                if is_running && needs_pre_emit {
3007                    function_local_data
3008                        .block
3009                        .extend(function_local_data.emitter.finish(self.expressions));
3010                    let h = self.expressions.append(expr, span);
3011                    function_local_data.emitter.start(self.expressions);
3012                    h
3013                } else {
3014                    self.expressions.append(expr, span)
3015                }
3016            }
3017            _ => self.expressions.append(expr, span),
3018        };
3019        self.expression_kind_tracker.insert(h, expr_type);
3020        h
3021    }
3022
3023    fn resolve_type(
3024        &self,
3025        expr: Handle<Expression>,
3026    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3027        use crate::proc::TypeResolution as Tr;
3028        use crate::Expression as Ex;
3029        let resolution = match self.expressions[expr] {
3030            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3031            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3032            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3033            Ex::Splat { size, value } => {
3034                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3035                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3036                };
3037                Tr::Value(TypeInner::Vector { scalar, size })
3038            }
3039            _ => {
3040                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3041                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3042            }
3043        };
3044
3045        Ok(resolution)
3046    }
3047
    /// Evaluate a `select(reject, accept, condition)` expression.
    ///
    /// All three arguments are first passed through
    /// `eval_zero_value_and_splat`. Two shapes are then handled:
    ///
    /// - `reject` and `accept` are both literals: `condition` must be a
    ///   `bool` literal and picks one of them.
    /// - `reject` and `accept` are both `Compose` expressions: the selection
    ///   is made component by component. `condition` is either a `bool`
    ///   literal, applied to every component, or a `Compose` of `bool`s of
    ///   the same vector size. The result is a new `Compose` of `reject`'s
    ///   type.
    ///
    /// In both shapes `accept` is cast to the scalar type of `reject`, and
    /// that cast happens whether or not `accept` ends up being chosen.
    ///
    /// # Errors
    ///
    /// - `SelectScalarConditionNotABool`: literal operands with a condition
    ///   that is not a `bool` literal.
    /// - `SelectVecRejectAcceptSizeMismatch`: `reject` and `accept` vectors
    ///   differ in size.
    /// - `SelectConditionNotAVecBool`: the condition for vector operands is
    ///   neither a `bool` literal nor a compose with a `bool` scalar kind.
    /// - `SelectConditionVecSizeMismatch`: the condition vector's size differs
    ///   from the operands' size.
    /// - `SelectAcceptRejectTypeMismatch`: `reject` and `accept` are not both
    ///   literals or both composes.
    /// - Any error from `eval_zero_value_and_splat` or `cast`.
    fn select(
        &mut self,
        reject: Handle<Expression>,
        accept: Handle<Expression>,
        condition: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

        let reject = arg(reject)?;
        let accept = arg(accept)?;
        let condition = arg(condition)?;

        // Picks between one `reject` / `accept` pair. `accept` is cast to
        // `reject_scalar` before the condition is looked at, so a failing
        // cast is reported even when `reject` would have been chosen.
        let select_single_component =
            |this: &mut Self, reject_scalar, reject, accept, condition| {
                let accept = this.cast(accept, reject_scalar, span)?;
                if condition {
                    Ok(accept)
                } else {
                    Ok(reject)
                }
            };

        match (&self.expressions[reject], &self.expressions[accept]) {
            // Scalar case: both operands are literals.
            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
                let reject_scalar = reject_lit.scalar();
                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
                else {
                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
                };
                select_single_component(self, reject_scalar, reject, accept, condition)
            }
            // Vector case: both operands are composes.
            (
                &Expression::Compose {
                    ty: reject_ty,
                    components: ref reject_components,
                },
                &Expression::Compose {
                    ty: accept_ty,
                    components: ref accept_components,
                },
            ) => {
                // Vector size and scalar of a compose type.
                // NOTE(review): both `unwrap`s panic if `ty` has no vector
                // size and scalar — presumably only vector composes reach
                // this point; confirm against callers.
                let ty_deets = |ty| {
                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                    (size.unwrap(), scalar)
                };

                let expected_vec_size = {
                    let [(reject_vec_size, _), (accept_vec_size, _)] =
                        [reject_ty, accept_ty].map(ty_deets);

                    if reject_vec_size != accept_vec_size {
                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                            reject: reject_vec_size,
                            accept: accept_vec_size,
                        });
                    }
                    reject_vec_size
                };

                // One `bool` per component of the operands.
                let condition_components = match self.expressions[condition] {
                    // A scalar condition applies to every component.
                    Expression::Literal(Literal::Bool(condition)) => {
                        vec![condition; (expected_vec_size as u8).into()]
                    }
                    Expression::Compose {
                        ty: condition_ty,
                        components: ref condition_components,
                    } => {
                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                        if condition_scalar.kind != ScalarKind::Bool {
                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                        }
                        if condition_vec_size != expected_vec_size {
                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                        }
                        condition_components
                            .iter()
                            .copied()
                            .map(|component| match &self.expressions[component] {
                                &Expression::Literal(Literal::Bool(condition)) => condition,
                                // NOTE(review): assumes every component of the
                                // condition is already a `bool` literal (no
                                // nested composes) — confirm.
                                _ => unreachable!(),
                            })
                            .collect()
                    }

                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
                };

                // The operand component lists are cloned so that `self` can be
                // borrowed mutably while they are walked.
                let evaluated = Expression::Compose {
                    ty: reject_ty,
                    components: reject_components
                        .clone()
                        .into_iter()
                        .zip(accept_components.clone().into_iter())
                        .zip(condition_components.into_iter())
                        .map(|((reject, accept), condition)| {
                            let reject_scalar = match &self.expressions[reject] {
                                &Expression::Literal(lit) => lit.scalar(),
                                // NOTE(review): assumes each `reject`
                                // component is a literal — confirm.
                                _ => unreachable!(),
                            };
                            select_single_component(self, reject_scalar, reject, accept, condition)
                        })
                        .collect::<Result<_, _>>()?,
                };
                self.register_evaluated_expr(evaluated, span)
            }
            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
        }
    }
3157}
3158
3159fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3160    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3161    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3162    // return a right-to-left bit index of 0.
3163    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3164        match e {
3165            idx @ 0..=31 => idx,
3166            32 => u32::MAX,
3167            _ => unreachable!(),
3168        }
3169    };
3170    match concrete_int {
3171        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3172        ConcreteInt::I32([e]) => {
3173            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3174        }
3175    }
3176}
3177
/// Spot checks and a sweep over every single-bit input for
/// `first_trailing_bit`, for both signed and unsigned integers.
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected index); `-1` means "no bit set".
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    // (input, expected index); `u32::MAX` means "no bit set".
    let unsigned_cases = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3238
3239fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3240    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3241    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3242    // index of 0.
3243    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3244        match e {
3245            idx @ 0..=31 => 31 - idx,
3246            32 => u32::MAX,
3247            _ => unreachable!(),
3248        }
3249    };
3250    match concrete_int {
3251        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3252            let rtl_bit_index = if e.is_negative() {
3253                e.leading_ones()
3254            } else {
3255                e.leading_zeros()
3256            };
3257            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3258        }]),
3259        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3260    }
3261}
3262
/// Spot checks and single-bit sweeps for `first_leading_bit`, for both
/// signed and unsigned integers.
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index); `-1` means "no bit differs from the sign bit".
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // NOTE: The sign bit itself is left out here; `i32::MIN` is covered by
    // the table above.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected index); `u32::MAX` means "no bit set".
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3326
/// Trait for conversions of abstract values to concrete types.
///
/// `T` is the Rust representation of the abstract value being converted
/// (the implementations in this file use `i64` and `f64`), and `Self` is the
/// concrete destination type.
trait TryFromAbstract<T>: Sized {
    /// Convert an abstract literal `value` to `Self`.
    ///
    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
    /// WGSL, we follow WGSL's conversion rules here:
    ///
    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
    ///   from [`AbstractInt`] to an integer type are either lossless or an
    ///   error.
    ///
    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
    ///   to floating point in constant expressions and override
    ///   expressions are errors if the value is out of range for the
    ///   destination type, but rounding is okay.
    ///
    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
    ///   other floating point type, following the scalar floating point to
    ///   integral conversion algorithm (§15.7.6). There is no automatic
    ///   conversion from AbstractFloat to integer types.
    ///
    /// # Errors
    ///
    /// Implementations that can fail report
    /// [`ConstantEvaluatorError::AutomaticConversionLossy`], carrying the
    /// offending value and the name of the destination type.
    ///
    /// [`AbstractInt`]: crate::Literal::AbstractInt
    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3352
3353impl TryFromAbstract<i64> for i32 {
3354    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3355        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3356            value: format!("{value:?}"),
3357            to_type: "i32",
3358        })
3359    }
3360}
3361
3362impl TryFromAbstract<i64> for u32 {
3363    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3364        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3365            value: format!("{value:?}"),
3366            to_type: "u32",
3367        })
3368    }
3369}
3370
3371impl TryFromAbstract<i64> for u64 {
3372    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3373        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3374            value: format!("{value:?}"),
3375            to_type: "u64",
3376        })
3377    }
3378}
3379
3380impl TryFromAbstract<i64> for i64 {
3381    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3382        Ok(value)
3383    }
3384}
3385
3386impl TryFromAbstract<i64> for f32 {
3387    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3388        let f = value as f32;
3389        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3390        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3391        // overflow here.
3392        Ok(f)
3393    }
3394}
3395
3396impl TryFromAbstract<f64> for f32 {
3397    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3398        let f = value as f32;
3399        if f.is_infinite() {
3400            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3401                value: format!("{value:?}"),
3402                to_type: "f32",
3403            });
3404        }
3405        Ok(f)
3406    }
3407}
3408
3409impl TryFromAbstract<i64> for f64 {
3410    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3411        let f = value as f64;
3412        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3413        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3414        // overflow here.
3415        Ok(f)
3416    }
3417}
3418
3419impl TryFromAbstract<f64> for f64 {
3420    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3421        Ok(value)
3422    }
3423}
3424
3425impl TryFromAbstract<f64> for i32 {
3426    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3427        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3428        // To convert a floating point scalar value X to an integer scalar type T:
3429        // * If X is a NaN, the result is an indeterminate value in T.
3430        // * If X is exactly representable in the target type T, then the
3431        //   result is that value.
3432        // * Otherwise, the result is the value in T closest to truncate(X) and
3433        //   also exactly representable in the original floating point type.
3434        //
3435        // A rust cast satisfies these requirements apart from "the result
3436        // is... exactly representable in the original floating point type".
3437        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3438        // we're all good.
3439        Ok(value as i32)
3440    }
3441}
3442
3443impl TryFromAbstract<f64> for u32 {
3444    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3445        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3446        // so a simple rust cast is sufficient.
3447        Ok(value as u32)
3448    }
3449}
3450
3451impl TryFromAbstract<f64> for i64 {
3452    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3453        // As above, except we clamp to the minimum and maximum values
3454        // representable by both f64 and i64.
3455        use crate::proc::type_methods::IntFloatLimits;
3456        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3457    }
3458}
3459
3460impl TryFromAbstract<f64> for u64 {
3461    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3462        // As above, this time clamping to the minimum and maximum values
3463        // representable by both f64 and u64.
3464        use crate::proc::type_methods::IntFloatLimits;
3465        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3466    }
3467}
3468
3469impl TryFromAbstract<f64> for f16 {
3470    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3471        let f = f16::from_f64(value);
3472        if f.is_infinite() {
3473            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3474                value: format!("{value:?}"),
3475                to_type: "f16",
3476            });
3477        }
3478        Ok(f)
3479    }
3480}
3481
3482impl TryFromAbstract<i64> for f16 {
3483    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3484        let f = f16::from_i64(value);
3485        if f.is_none() {
3486            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3487                value: format!("{value:?}"),
3488                to_type: "f16",
3489            });
3490        }
3491        Ok(f.unwrap())
3492    }
3493}
3494
/// Compute the cross product `a × b` of two 3-component vectors.
///
/// Each output component is a difference of two products of input
/// components, using only `T`'s `Mul` and `Sub` implementations.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
3507
3508#[cfg(test)]
3509mod tests {
3510    use alloc::{vec, vec::Vec};
3511
3512    use crate::{
3513        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
3514        UniqueArena, VectorSize,
3515    };
3516
3517    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3518
3519    #[test]
3520    fn unary_op() {
3521        let mut types = UniqueArena::new();
3522        let mut constants = Arena::new();
3523        let overrides = Arena::new();
3524        let mut global_expressions = Arena::new();
3525
3526        let scalar_ty = types.insert(
3527            Type {
3528                name: None,
3529                inner: TypeInner::Scalar(crate::Scalar::I32),
3530            },
3531            Default::default(),
3532        );
3533
3534        let vec_ty = types.insert(
3535            Type {
3536                name: None,
3537                inner: TypeInner::Vector {
3538                    size: VectorSize::Bi,
3539                    scalar: crate::Scalar::I32,
3540                },
3541            },
3542            Default::default(),
3543        );
3544
3545        let h = constants.append(
3546            Constant {
3547                name: None,
3548                ty: scalar_ty,
3549                init: global_expressions
3550                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3551            },
3552            Default::default(),
3553        );
3554
3555        let h1 = constants.append(
3556            Constant {
3557                name: None,
3558                ty: scalar_ty,
3559                init: global_expressions
3560                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
3561            },
3562            Default::default(),
3563        );
3564
3565        let vec_h = constants.append(
3566            Constant {
3567                name: None,
3568                ty: vec_ty,
3569                init: global_expressions.append(
3570                    Expression::Compose {
3571                        ty: vec_ty,
3572                        components: vec![constants[h].init, constants[h1].init],
3573                    },
3574                    Default::default(),
3575                ),
3576            },
3577            Default::default(),
3578        );
3579
3580        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3581        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3582
3583        let expr2 = Expression::Unary {
3584            op: UnaryOperator::Negate,
3585            expr,
3586        };
3587
3588        let expr3 = Expression::Unary {
3589            op: UnaryOperator::BitwiseNot,
3590            expr,
3591        };
3592
3593        let expr4 = Expression::Unary {
3594            op: UnaryOperator::BitwiseNot,
3595            expr: expr1,
3596        };
3597
3598        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3599        let mut solver = ConstantEvaluator {
3600            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3601            types: &mut types,
3602            constants: &constants,
3603            overrides: &overrides,
3604            expressions: &mut global_expressions,
3605            expression_kind_tracker,
3606            layouter: &mut crate::proc::Layouter::default(),
3607        };
3608
3609        let res1 = solver
3610            .try_eval_and_append(expr2, Default::default())
3611            .unwrap();
3612        let res2 = solver
3613            .try_eval_and_append(expr3, Default::default())
3614            .unwrap();
3615        let res3 = solver
3616            .try_eval_and_append(expr4, Default::default())
3617            .unwrap();
3618
3619        assert_eq!(
3620            global_expressions[res1],
3621            Expression::Literal(Literal::I32(-4))
3622        );
3623
3624        assert_eq!(
3625            global_expressions[res2],
3626            Expression::Literal(Literal::I32(!4))
3627        );
3628
3629        let res3_inner = &global_expressions[res3];
3630
3631        match *res3_inner {
3632            Expression::Compose {
3633                ref ty,
3634                ref components,
3635            } => {
3636                assert_eq!(*ty, vec_ty);
3637                let mut components_iter = components.iter().copied();
3638                assert_eq!(
3639                    global_expressions[components_iter.next().unwrap()],
3640                    Expression::Literal(Literal::I32(!4))
3641                );
3642                assert_eq!(
3643                    global_expressions[components_iter.next().unwrap()],
3644                    Expression::Literal(Literal::I32(!8))
3645                );
3646                assert!(components_iter.next().is_none());
3647            }
3648            _ => panic!("Expected vector"),
3649        }
3650    }
3651
3652    #[test]
3653    fn cast() {
3654        let mut types = UniqueArena::new();
3655        let mut constants = Arena::new();
3656        let overrides = Arena::new();
3657        let mut global_expressions = Arena::new();
3658
3659        let scalar_ty = types.insert(
3660            Type {
3661                name: None,
3662                inner: TypeInner::Scalar(crate::Scalar::I32),
3663            },
3664            Default::default(),
3665        );
3666
3667        let h = constants.append(
3668            Constant {
3669                name: None,
3670                ty: scalar_ty,
3671                init: global_expressions
3672                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3673            },
3674            Default::default(),
3675        );
3676
3677        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3678
3679        let root = Expression::As {
3680            expr,
3681            kind: ScalarKind::Bool,
3682            convert: Some(crate::BOOL_WIDTH),
3683        };
3684
3685        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3686        let mut solver = ConstantEvaluator {
3687            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3688            types: &mut types,
3689            constants: &constants,
3690            overrides: &overrides,
3691            expressions: &mut global_expressions,
3692            expression_kind_tracker,
3693            layouter: &mut crate::proc::Layouter::default(),
3694        };
3695
3696        let res = solver
3697            .try_eval_and_append(root, Default::default())
3698            .unwrap();
3699
3700        assert_eq!(
3701            global_expressions[res],
3702            Expression::Literal(Literal::Bool(true))
3703        );
3704    }
3705
3706    #[test]
3707    fn access() {
3708        let mut types = UniqueArena::new();
3709        let mut constants = Arena::new();
3710        let overrides = Arena::new();
3711        let mut global_expressions = Arena::new();
3712
3713        let matrix_ty = types.insert(
3714            Type {
3715                name: None,
3716                inner: TypeInner::Matrix {
3717                    columns: VectorSize::Bi,
3718                    rows: VectorSize::Tri,
3719                    scalar: crate::Scalar::F32,
3720                },
3721            },
3722            Default::default(),
3723        );
3724
3725        let vec_ty = types.insert(
3726            Type {
3727                name: None,
3728                inner: TypeInner::Vector {
3729                    size: VectorSize::Tri,
3730                    scalar: crate::Scalar::F32,
3731                },
3732            },
3733            Default::default(),
3734        );
3735
3736        let mut vec1_components = Vec::with_capacity(3);
3737        let mut vec2_components = Vec::with_capacity(3);
3738
3739        for i in 0..3 {
3740            let h = global_expressions.append(
3741                Expression::Literal(Literal::F32(i as f32)),
3742                Default::default(),
3743            );
3744
3745            vec1_components.push(h)
3746        }
3747
3748        for i in 3..6 {
3749            let h = global_expressions.append(
3750                Expression::Literal(Literal::F32(i as f32)),
3751                Default::default(),
3752            );
3753
3754            vec2_components.push(h)
3755        }
3756
3757        let vec1 = constants.append(
3758            Constant {
3759                name: None,
3760                ty: vec_ty,
3761                init: global_expressions.append(
3762                    Expression::Compose {
3763                        ty: vec_ty,
3764                        components: vec1_components,
3765                    },
3766                    Default::default(),
3767                ),
3768            },
3769            Default::default(),
3770        );
3771
3772        let vec2 = constants.append(
3773            Constant {
3774                name: None,
3775                ty: vec_ty,
3776                init: global_expressions.append(
3777                    Expression::Compose {
3778                        ty: vec_ty,
3779                        components: vec2_components,
3780                    },
3781                    Default::default(),
3782                ),
3783            },
3784            Default::default(),
3785        );
3786
3787        let h = constants.append(
3788            Constant {
3789                name: None,
3790                ty: matrix_ty,
3791                init: global_expressions.append(
3792                    Expression::Compose {
3793                        ty: matrix_ty,
3794                        components: vec![constants[vec1].init, constants[vec2].init],
3795                    },
3796                    Default::default(),
3797                ),
3798            },
3799            Default::default(),
3800        );
3801
3802        let base = global_expressions.append(Expression::Constant(h), Default::default());
3803
3804        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3805        let mut solver = ConstantEvaluator {
3806            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3807            types: &mut types,
3808            constants: &constants,
3809            overrides: &overrides,
3810            expressions: &mut global_expressions,
3811            expression_kind_tracker,
3812            layouter: &mut crate::proc::Layouter::default(),
3813        };
3814
3815        let root1 = Expression::AccessIndex { base, index: 1 };
3816
3817        let res1 = solver
3818            .try_eval_and_append(root1, Default::default())
3819            .unwrap();
3820
3821        let root2 = Expression::AccessIndex {
3822            base: res1,
3823            index: 2,
3824        };
3825
3826        let res2 = solver
3827            .try_eval_and_append(root2, Default::default())
3828            .unwrap();
3829
3830        match global_expressions[res1] {
3831            Expression::Compose {
3832                ref ty,
3833                ref components,
3834            } => {
3835                assert_eq!(*ty, vec_ty);
3836                let mut components_iter = components.iter().copied();
3837                assert_eq!(
3838                    global_expressions[components_iter.next().unwrap()],
3839                    Expression::Literal(Literal::F32(3.))
3840                );
3841                assert_eq!(
3842                    global_expressions[components_iter.next().unwrap()],
3843                    Expression::Literal(Literal::F32(4.))
3844                );
3845                assert_eq!(
3846                    global_expressions[components_iter.next().unwrap()],
3847                    Expression::Literal(Literal::F32(5.))
3848                );
3849                assert!(components_iter.next().is_none());
3850            }
3851            _ => panic!("Expected vector"),
3852        }
3853
3854        assert_eq!(
3855            global_expressions[res2],
3856            Expression::Literal(Literal::F32(5.))
3857        );
3858    }
3859
3860    #[test]
3861    fn compose_of_constants() {
3862        let mut types = UniqueArena::new();
3863        let mut constants = Arena::new();
3864        let overrides = Arena::new();
3865        let mut global_expressions = Arena::new();
3866
3867        let i32_ty = types.insert(
3868            Type {
3869                name: None,
3870                inner: TypeInner::Scalar(crate::Scalar::I32),
3871            },
3872            Default::default(),
3873        );
3874
3875        let vec2_i32_ty = types.insert(
3876            Type {
3877                name: None,
3878                inner: TypeInner::Vector {
3879                    size: VectorSize::Bi,
3880                    scalar: crate::Scalar::I32,
3881                },
3882            },
3883            Default::default(),
3884        );
3885
3886        let h = constants.append(
3887            Constant {
3888                name: None,
3889                ty: i32_ty,
3890                init: global_expressions
3891                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3892            },
3893            Default::default(),
3894        );
3895
3896        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3897
3898        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3899        let mut solver = ConstantEvaluator {
3900            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3901            types: &mut types,
3902            constants: &constants,
3903            overrides: &overrides,
3904            expressions: &mut global_expressions,
3905            expression_kind_tracker,
3906            layouter: &mut crate::proc::Layouter::default(),
3907        };
3908
3909        let solved_compose = solver
3910            .try_eval_and_append(
3911                Expression::Compose {
3912                    ty: vec2_i32_ty,
3913                    components: vec![h_expr, h_expr],
3914                },
3915                Default::default(),
3916            )
3917            .unwrap();
3918        let solved_negate = solver
3919            .try_eval_and_append(
3920                Expression::Unary {
3921                    op: UnaryOperator::Negate,
3922                    expr: solved_compose,
3923                },
3924                Default::default(),
3925            )
3926            .unwrap();
3927
3928        let pass = match global_expressions[solved_negate] {
3929            Expression::Compose { ty, ref components } => {
3930                ty == vec2_i32_ty
3931                    && components.iter().all(|&component| {
3932                        let component = &global_expressions[component];
3933                        matches!(*component, Expression::Literal(Literal::I32(-4)))
3934                    })
3935            }
3936            _ => false,
3937        };
3938        if !pass {
3939            panic!("unexpected evaluation result")
3940        }
3941    }
3942
3943    #[test]
3944    fn splat_of_constant() {
3945        let mut types = UniqueArena::new();
3946        let mut constants = Arena::new();
3947        let overrides = Arena::new();
3948        let mut global_expressions = Arena::new();
3949
3950        let i32_ty = types.insert(
3951            Type {
3952                name: None,
3953                inner: TypeInner::Scalar(crate::Scalar::I32),
3954            },
3955            Default::default(),
3956        );
3957
3958        let vec2_i32_ty = types.insert(
3959            Type {
3960                name: None,
3961                inner: TypeInner::Vector {
3962                    size: VectorSize::Bi,
3963                    scalar: crate::Scalar::I32,
3964                },
3965            },
3966            Default::default(),
3967        );
3968
3969        let h = constants.append(
3970            Constant {
3971                name: None,
3972                ty: i32_ty,
3973                init: global_expressions
3974                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3975            },
3976            Default::default(),
3977        );
3978
3979        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3980
3981        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3982        let mut solver = ConstantEvaluator {
3983            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3984            types: &mut types,
3985            constants: &constants,
3986            overrides: &overrides,
3987            expressions: &mut global_expressions,
3988            expression_kind_tracker,
3989            layouter: &mut crate::proc::Layouter::default(),
3990        };
3991
3992        let solved_compose = solver
3993            .try_eval_and_append(
3994                Expression::Splat {
3995                    size: VectorSize::Bi,
3996                    value: h_expr,
3997                },
3998                Default::default(),
3999            )
4000            .unwrap();
4001        let solved_negate = solver
4002            .try_eval_and_append(
4003                Expression::Unary {
4004                    op: UnaryOperator::Negate,
4005                    expr: solved_compose,
4006                },
4007                Default::default(),
4008            )
4009            .unwrap();
4010
4011        let pass = match global_expressions[solved_negate] {
4012            Expression::Compose { ty, ref components } => {
4013                ty == vec2_i32_ty
4014                    && components.iter().all(|&component| {
4015                        let component = &global_expressions[component];
4016                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4017                    })
4018            }
4019            _ => false,
4020        };
4021        if !pass {
4022            panic!("unexpected evaluation result")
4023        }
4024    }
4025
4026    #[test]
4027    fn splat_of_zero_value() {
4028        let mut types = UniqueArena::new();
4029        let constants = Arena::new();
4030        let overrides = Arena::new();
4031        let mut global_expressions = Arena::new();
4032
4033        let f32_ty = types.insert(
4034            Type {
4035                name: None,
4036                inner: TypeInner::Scalar(crate::Scalar::F32),
4037            },
4038            Default::default(),
4039        );
4040
4041        let vec2_f32_ty = types.insert(
4042            Type {
4043                name: None,
4044                inner: TypeInner::Vector {
4045                    size: VectorSize::Bi,
4046                    scalar: crate::Scalar::F32,
4047                },
4048            },
4049            Default::default(),
4050        );
4051
4052        let five =
4053            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4054        let five_splat = global_expressions.append(
4055            Expression::Splat {
4056                size: VectorSize::Bi,
4057                value: five,
4058            },
4059            Default::default(),
4060        );
4061        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4062        let zero_splat = global_expressions.append(
4063            Expression::Splat {
4064                size: VectorSize::Bi,
4065                value: zero,
4066            },
4067            Default::default(),
4068        );
4069
4070        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4071        let mut solver = ConstantEvaluator {
4072            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4073            types: &mut types,
4074            constants: &constants,
4075            overrides: &overrides,
4076            expressions: &mut global_expressions,
4077            expression_kind_tracker,
4078            layouter: &mut crate::proc::Layouter::default(),
4079        };
4080
4081        let solved_add = solver
4082            .try_eval_and_append(
4083                Expression::Binary {
4084                    op: crate::BinaryOperator::Add,
4085                    left: zero_splat,
4086                    right: five_splat,
4087                },
4088                Default::default(),
4089            )
4090            .unwrap();
4091
4092        let pass = match global_expressions[solved_add] {
4093            Expression::Compose { ty, ref components } => {
4094                ty == vec2_f32_ty
4095                    && components.iter().all(|&component| {
4096                        let component = &global_expressions[component];
4097                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4098                    })
4099            }
4100            _ => false,
4101        };
4102        if !pass {
4103            panic!("unexpected evaluation result")
4104        }
4105    }
4106}