// naga/proc/constant_evaluator.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14    arena::{Arena, Handle, HandleVec, UniqueArena},
15    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16    ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Lets a macro emit literal dollar signs (`$`), which `macro_rules!` cannot otherwise write
/// into its own output. This is what allows one macro to generate another `macro_rules!` item
/// that declares its own metavariables and repetitions.
///
/// The input is a single `macro_rules!` arm taking one `tt` parameter (this file calls it
/// `$d`). That arm is immediately invoked with `$` as the argument, so every `$d` in its body
/// becomes a real `$`.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($tokens:tt)*) => {
        macro_rules! __with_dollar_sign {
            $($tokens)*
        }
        __with_dollar_sign!($);
    };
}
33
/// Generates a component-wise evaluation helper together with its supporting items.
///
/// An invocation `gen_component_wise_extractor! { $ident -> $target, literals: [...],
/// scalar_kinds: [...], }` expands to:
///
/// - `enum $target<const N: usize>`, with one variant per `literals` entry
///   (`$literal => $mapping: $ty`). Variant `$mapping` holds `N` values of type `$ty` taken
///   from the corresponding [`Literal`] variant.
/// - `impl From<$target<1>> for Expression`, which turns a single value back into an
///   [`Expression::Literal`] of the original literal variant.
/// - `fn $ident`, which applies a handler to `N` argument expressions (see its docs).
/// - `macro_rules! $ident`, a convenience wrapper that reuses one closure body for every
///   `$target` variant.
///
/// `scalar_kinds` lists the [`ScalarKind`]s accepted when the arguments are vectors.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        ///
        /// # Errors
        ///
        /// Returns [`ConstantEvaluatorError::InvalidMathArg`] when the arguments are neither all
        /// literals of the same variant nor all vectors of one type with an accepted scalar
        /// kind. Errors from resolving the arguments, from `handler`, or from registering the
        /// result are propagated.
        ///
        /// # Panics
        ///
        /// Panics if `N` is zero.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            // Every shape mismatch below reports this same error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Passes `$expr` through `eval_zero_value_and_splat`, then borrows the resulting
            // expression. NOTE(review): that helper is defined elsewhere. It presumably
            // rewrites `ZeroValue`/`Splat` into explicit `Literal`/`Compose` form — confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar case: every argument must be a literal of the same variant. The
                // values are gathered into a `[_; N]` and `handler` is called once.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: every argument must be a `Compose` of the same vector type.
                // Each argument is flattened into its components, `$ident` is applied
                // recursively to the `idx`th component of every argument, and the results are
                // recombined into a `Compose` of the first argument's type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One group of flattened components per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    // These are the remaining `N - 1` argument groups, so
                                    // they are bounded by the argument count `N`, not by the
                                    // number of vector components.
                                    .collect::<Result<ArrayVec<_, N>, _>>()?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                // Gather the `idx`th component of every argument.
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` becomes a literal `$` (see `with_dollar_sign`), so the generated macro can
        // declare its own repetitions. `$eval`, `$span` and `$tt` are not metavariables of
        // this outer macro, so they pass through its expansion unchanged.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Component-wise helpers over every numeric literal type. `AbstractInt` and `I64` both carry
// `i64` values but remain separate `Scalar` variants. Converting back to an `Expression`
// therefore preserves the original literal variant. `Bool` is not listed, so boolean
// arguments are rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}

// Component-wise helpers restricted to floating-point literals. Note that the
// `AbstractFloat` literal maps to the `Float::Abstract` variant.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}

// Component-wise helpers restricted to 32-bit concrete integers. `U64`/`I64` literals are
// not listed, so `component_wise_concrete_int` rejects them with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}

// Component-wise helpers for types that can carry a sign. Among concrete integers only `I32`
// is listed. An `I64` literal, including a component of a `Sint` vector, falls through to
// `InvalidMathArg`. NOTE(review): presumably intentional — confirm before adding `I64`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
271#[derive(Debug)]
272enum Behavior<'a> {
273    Wgsl(WgslRestrictions<'a>),
274    Glsl(GlslRestrictions<'a>),
275}
276
277impl Behavior<'_> {
278    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
279    const fn has_runtime_restrictions(&self) -> bool {
280        matches!(
281            self,
282            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
283                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
284        )
285    }
286}
287
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow, and which kinds of
    /// expressions (const, override, runtime) the current context accepts.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions: either the module's
    /// global expression arena or a function's expression arena.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`].
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// The module's type layouter. NOTE(review): presumably consulted for type
    /// sizes/alignments during evaluation — confirm against its uses.
    layouter: &'a mut crate::proc::Layouter,
}
332
/// Which expressions a WGSL [`ConstantEvaluator`] accepts in its current context.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// `None` means the evaluator targets the module's global expression arena.
    /// `Some` carries the function-local state when it targets a function's arena.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    ///
    /// Only used for the module's global expression arena.
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    ///
    /// Only used for a function's expression arena.
    Runtime(FunctionLocalData<'a>),
}
345
/// Which expressions a GLSL [`ConstantEvaluator`] accepts in its current context.
///
/// GLSL has no override-expressions: `try_eval_and_append` treats an
/// override-kind expression under GLSL rules as unreachable.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// Used for the module's global expression arena.
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    ///
    /// Used for a function's expression arena.
    Runtime(FunctionLocalData<'a>),
}
355
/// State carried when a [`ConstantEvaluator`] appends to a function's expression
/// arena rather than to the module's global expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions: the module's global expression arena. Constant
    /// initializers found here are deep-copied into the function's arena.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used to emit newly appended runtime expressions
    /// into `block` — confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    /// The function body block being built. NOTE(review): assumed to receive
    /// emitted statements — confirm.
    block: &'a mut crate::Block,
}
363
/// How constant an expression is.
///
/// Variants are declared from most to least constant, and the derived [`Ord`]
/// follows that order. [`ExpressionKindTracker`] combines operand kinds with
/// `max`, so an expression is only as constant as its least constant operand.
/// Do not reorder the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression, whose value can be computed at compile time.
    Const,
    /// An override-expression, which depends on a pipeline-overridable constant.
    Override,
    /// A runtime-expression.
    Runtime,
}
370
/// Records the [`ExpressionKind`] of each expression in an expression arena,
/// indexed by the expression's handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
375
376impl ExpressionKindTracker {
377    pub const fn new() -> Self {
378        Self {
379            inner: HandleVec::new(),
380        }
381    }
382
383    /// Forces the the expression to not be const
384    pub fn force_non_const(&mut self, value: Handle<Expression>) {
385        self.inner[value] = ExpressionKind::Runtime;
386    }
387
388    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
389        self.inner.insert(value, expr_type);
390    }
391
392    pub fn is_const(&self, h: Handle<Expression>) -> bool {
393        matches!(self.type_of(h), ExpressionKind::Const)
394    }
395
396    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
397        matches!(
398            self.type_of(h),
399            ExpressionKind::Const | ExpressionKind::Override
400        )
401    }
402
403    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
404        self.inner[value]
405    }
406
407    pub fn from_arena(arena: &Arena<Expression>) -> Self {
408        let mut tracker = Self {
409            inner: HandleVec::with_capacity(arena.len()),
410        };
411        for (handle, expr) in arena.iter() {
412            tracker
413                .inner
414                .insert(handle, tracker.type_of_with_expr(expr));
415        }
416        tracker
417    }
418
419    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
420        match *expr {
421            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
422                ExpressionKind::Const
423            }
424            Expression::Override(_) => ExpressionKind::Override,
425            Expression::Compose { ref components, .. } => {
426                let mut expr_type = ExpressionKind::Const;
427                for component in components {
428                    expr_type = expr_type.max(self.type_of(*component))
429                }
430                expr_type
431            }
432            Expression::Splat { value, .. } => self.type_of(value),
433            Expression::AccessIndex { base, .. } => self.type_of(base),
434            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
435            Expression::Swizzle { vector, .. } => self.type_of(vector),
436            Expression::Unary { expr, .. } => self.type_of(expr),
437            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
438            Expression::Math {
439                arg,
440                arg1,
441                arg2,
442                arg3,
443                ..
444            } => self
445                .type_of(arg)
446                .max(
447                    arg1.map(|arg| self.type_of(arg))
448                        .unwrap_or(ExpressionKind::Const),
449                )
450                .max(
451                    arg2.map(|arg| self.type_of(arg))
452                        .unwrap_or(ExpressionKind::Const),
453                )
454                .max(
455                    arg3.map(|arg| self.type_of(arg))
456                        .unwrap_or(ExpressionKind::Const),
457                ),
458            Expression::As { expr, .. } => self.type_of(expr),
459            Expression::Select {
460                condition,
461                accept,
462                reject,
463            } => self
464                .type_of(condition)
465                .max(self.type_of(accept))
466                .max(self.type_of(reject)),
467            Expression::Relational { argument, .. } => self.type_of(argument),
468            Expression::ArrayLength(expr) => self.type_of(expr),
469            _ => ExpressionKind::Runtime,
470        }
471    }
472}
473
474#[derive(Clone, Debug, thiserror::Error)]
475#[cfg_attr(test, derive(PartialEq))]
476pub enum ConstantEvaluatorError {
477    #[error("Constants cannot access function arguments")]
478    FunctionArg,
479    #[error("Constants cannot access global variables")]
480    GlobalVariable,
481    #[error("Constants cannot access local variables")]
482    LocalVariable,
483    #[error("Cannot get the array length of a non array type")]
484    InvalidArrayLengthArg,
485    #[error("Constants cannot get the array length of a dynamically sized array")]
486    ArrayLengthDynamic,
487    #[error("Cannot call arrayLength on array sized by override-expression")]
488    ArrayLengthOverridden,
489    #[error("Constants cannot call functions")]
490    Call,
491    #[error("Constants don't support workGroupUniformLoad")]
492    WorkGroupUniformLoadResult,
493    #[error("Constants don't support atomic functions")]
494    Atomic,
495    #[error("Constants don't support derivative functions")]
496    Derivative,
497    #[error("Constants don't support load expressions")]
498    Load,
499    #[error("Constants don't support image expressions")]
500    ImageExpression,
501    #[error("Constants don't support ray query expressions")]
502    RayQueryExpression,
503    #[error("Constants don't support subgroup expressions")]
504    SubgroupExpression,
505    #[error("Cannot access the type")]
506    InvalidAccessBase,
507    #[error("Cannot access at the index")]
508    InvalidAccessIndex,
509    #[error("Cannot access with index of type")]
510    InvalidAccessIndexTy,
511    #[error("Constants don't support array length expressions")]
512    ArrayLength,
513    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
514    InvalidCastArg { from: String, to: String },
515    #[error("Cannot apply the unary op to the argument")]
516    InvalidUnaryOpArg,
517    #[error("Cannot apply the binary op to the arguments")]
518    InvalidBinaryOpArgs,
519    #[error("Cannot apply math function to type")]
520    InvalidMathArg,
521    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
522    InvalidMathArgCount(crate::MathFunction, usize, usize),
523    #[error("Cannot apply relational function to type")]
524    InvalidRelationalArg(RelationalFunction),
525    #[error("value of `low` is greater than `high` for clamp built-in function")]
526    InvalidClamp,
527    #[error("Constructor expects {expected} components, found {actual}")]
528    InvalidVectorComposeLength { expected: usize, actual: usize },
529    #[error("Constructor must only contain vector or scalar arguments")]
530    InvalidVectorComposeComponent,
531    #[error("Splat is defined only on scalar values")]
532    SplatScalarOnly,
533    #[error("Can only swizzle vector constants")]
534    SwizzleVectorOnly,
535    #[error("swizzle component not present in source expression")]
536    SwizzleOutOfBounds,
537    #[error("Type is not constructible")]
538    TypeNotConstructible,
539    #[error("Subexpression(s) are not constant")]
540    SubexpressionsAreNotConstant,
541    #[error("Not implemented as constant expression: {0}")]
542    NotImplemented(String),
543    #[error("{0} operation overflowed")]
544    Overflow(String),
545    #[error(
546        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
547    )]
548    AutomaticConversionLossy {
549        value: String,
550        to_type: &'static str,
551    },
552    #[error("Division by zero")]
553    DivisionByZero,
554    #[error("Remainder by zero")]
555    RemainderByZero,
556    #[error("RHS of shift operation is greater than or equal to 32")]
557    ShiftedMoreThan32Bits,
558    #[error(transparent)]
559    Literal(#[from] crate::valid::LiteralError),
560    #[error("Can't use pipeline-overridable constants in const-expressions")]
561    Override,
562    #[error("Unexpected runtime-expression")]
563    RuntimeExpr,
564    #[error("Unexpected override-expression")]
565    OverrideExpr,
566    #[error("Expected boolean expression for condition argument of `select`, got something else")]
567    SelectScalarConditionNotABool,
568    #[error(
569        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
570        reject,
571        accept
572    )]
573    SelectVecRejectAcceptSizeMismatch {
574        reject: crate::VectorSize,
575        accept: crate::VectorSize,
576    },
577    #[error("Expected boolean vector for condition arg., got something else")]
578    SelectConditionNotAVecBool,
579    #[error(
580        "Expected same number of vector components between condition, accept, and reject args., got something else",
581    )]
582    SelectConditionVecSizeMismatch,
583    #[error(
584        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
585    )]
586    SelectAcceptRejectTypeMismatch,
587}
588
589impl<'a> ConstantEvaluator<'a> {
590    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
591    /// constant expression arena.
592    ///
593    /// Report errors according to WGSL's rules for constant evaluation.
594    pub fn for_wgsl_module(
595        module: &'a mut crate::Module,
596        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
597        layouter: &'a mut crate::proc::Layouter,
598        in_override_ctx: bool,
599    ) -> Self {
600        Self::for_module(
601            Behavior::Wgsl(if in_override_ctx {
602                WgslRestrictions::Override
603            } else {
604                WgslRestrictions::Const(None)
605            }),
606            module,
607            global_expression_kind_tracker,
608            layouter,
609        )
610    }
611
612    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
613    /// constant expression arena.
614    ///
615    /// Report errors according to GLSL's rules for constant evaluation.
616    pub fn for_glsl_module(
617        module: &'a mut crate::Module,
618        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
619        layouter: &'a mut crate::proc::Layouter,
620    ) -> Self {
621        Self::for_module(
622            Behavior::Glsl(GlslRestrictions::Const),
623            module,
624            global_expression_kind_tracker,
625            layouter,
626        )
627    }
628
629    fn for_module(
630        behavior: Behavior<'a>,
631        module: &'a mut crate::Module,
632        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
633        layouter: &'a mut crate::proc::Layouter,
634    ) -> Self {
635        Self {
636            behavior,
637            types: &mut module.types,
638            constants: &module.constants,
639            overrides: &module.overrides,
640            expressions: &mut module.global_expressions,
641            expression_kind_tracker: global_expression_kind_tracker,
642            layouter,
643        }
644    }
645
646    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
647    /// expression arena.
648    ///
649    /// Report errors according to WGSL's rules for constant evaluation.
650    pub fn for_wgsl_function(
651        module: &'a mut crate::Module,
652        expressions: &'a mut Arena<Expression>,
653        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
654        layouter: &'a mut crate::proc::Layouter,
655        emitter: &'a mut super::Emitter,
656        block: &'a mut crate::Block,
657        is_const: bool,
658    ) -> Self {
659        let local_data = FunctionLocalData {
660            global_expressions: &module.global_expressions,
661            emitter,
662            block,
663        };
664        Self {
665            behavior: Behavior::Wgsl(if is_const {
666                WgslRestrictions::Const(Some(local_data))
667            } else {
668                WgslRestrictions::Runtime(local_data)
669            }),
670            types: &mut module.types,
671            constants: &module.constants,
672            overrides: &module.overrides,
673            expressions,
674            expression_kind_tracker: local_expression_kind_tracker,
675            layouter,
676        }
677    }
678
679    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
680    /// expression arena.
681    ///
682    /// Report errors according to GLSL's rules for constant evaluation.
683    pub fn for_glsl_function(
684        module: &'a mut crate::Module,
685        expressions: &'a mut Arena<Expression>,
686        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
687        layouter: &'a mut crate::proc::Layouter,
688        emitter: &'a mut super::Emitter,
689        block: &'a mut crate::Block,
690    ) -> Self {
691        Self {
692            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
693                global_expressions: &module.global_expressions,
694                emitter,
695                block,
696            })),
697            types: &mut module.types,
698            constants: &module.constants,
699            overrides: &module.overrides,
700            expressions,
701            expression_kind_tracker: local_expression_kind_tracker,
702            layouter,
703        }
704    }
705
706    pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
707        crate::proc::GlobalCtx {
708            types: self.types,
709            constants: self.constants,
710            overrides: self.overrides,
711            global_expressions: match self.function_local_data() {
712                Some(data) => data.global_expressions,
713                None => self.expressions,
714            },
715        }
716    }
717
718    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
719        if !self.expression_kind_tracker.is_const(expr) {
720            log::debug!("check: SubexpressionsAreNotConstant");
721            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
722        }
723        Ok(())
724    }
725
726    fn check_and_get(
727        &mut self,
728        expr: Handle<Expression>,
729    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
730        match self.expressions[expr] {
731            Expression::Constant(c) => {
732                // Are we working in a function's expression arena, or the
733                // module's constant expression arena?
734                if let Some(function_local_data) = self.function_local_data() {
735                    // Deep-copy the constant's value into our arena.
736                    self.copy_from(
737                        self.constants[c].init,
738                        function_local_data.global_expressions,
739                    )
740                } else {
741                    // "See through" the constant and use its initializer.
742                    Ok(self.constants[c].init)
743                }
744            }
745            _ => {
746                self.check(expr)?;
747                Ok(expr)
748            }
749        }
750    }
751
752    /// Try to evaluate `expr` at compile time.
753    ///
754    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
755    /// we can determine its value at compile time, we append an expression
756    /// representing its value - a tree of [`Literal`], [`Compose`],
757    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
758    /// `self` contributes to.
759    ///
760    /// If `expr`'s value cannot be determined at compile time, and `self` is
761    /// contributing to some function's expression arena, then append `expr` to
762    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
763    /// contributing to the module's constant expression arena; since `expr`'s
764    /// value is not a constant, return an error.
765    ///
766    /// We only consider `expr` itself, without recursing into its operands. Its
767    /// operands must all have been produced by prior calls to
768    /// `try_eval_and_append`, to ensure that they have already been reduced to
769    /// an evaluated form if possible.
770    ///
771    /// [`Literal`]: Expression::Literal
772    /// [`Compose`]: Expression::Compose
773    /// [`ZeroValue`]: Expression::ZeroValue
774    /// [`Swizzle`]: Expression::Swizzle
775    pub fn try_eval_and_append(
776        &mut self,
777        expr: Expression,
778        span: Span,
779    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
780        match self.expression_kind_tracker.type_of_with_expr(&expr) {
781            ExpressionKind::Const => {
782                let eval_result = self.try_eval_and_append_impl(&expr, span);
783                // We should be able to evaluate `Const` expressions at this
784                // point. If we failed to, then that probably means we just
785                // haven't implemented that part of constant evaluation. Work
786                // around this by simply emitting it as a run-time expression.
787                if self.behavior.has_runtime_restrictions()
788                    && matches!(
789                        eval_result,
790                        Err(ConstantEvaluatorError::NotImplemented(_)
791                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
792                    )
793                {
794                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
795                } else {
796                    eval_result
797                }
798            }
799            ExpressionKind::Override => match self.behavior {
800                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
801                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
802                }
803                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
804                    Err(ConstantEvaluatorError::OverrideExpr)
805                }
806                Behavior::Glsl(_) => {
807                    unreachable!()
808                }
809            },
810            ExpressionKind::Runtime => {
811                if self.behavior.has_runtime_restrictions() {
812                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
813                } else {
814                    Err(ConstantEvaluatorError::RuntimeExpr)
815                }
816            }
817        }
818    }
819
820    /// Is the [`Self::expressions`] arena the global module expression arena?
821    const fn is_global_arena(&self) -> bool {
822        matches!(
823            self.behavior,
824            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
825                | Behavior::Glsl(GlslRestrictions::Const)
826        )
827    }
828
829    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
830        match self.behavior {
831            Behavior::Wgsl(
832                WgslRestrictions::Runtime(ref function_local_data)
833                | WgslRestrictions::Const(Some(ref function_local_data)),
834            )
835            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
836                Some(function_local_data)
837            }
838            _ => None,
839        }
840    }
841
    /// Evaluate `expr` as a constant expression and return a handle to its
    /// evaluated form.
    ///
    /// Operand handles are passed through `check_and_get` before use. The
    /// value is then computed by a per-expression helper (`access`, `swizzle`,
    /// `unary_op`, `binary_op`, `math`, `cast`, `select`, `relational`,
    /// `array_length`). Expressions that are already in evaluated form
    /// (`Literal`, `ZeroValue`, non-global `Constant`, `Compose`, `Splat`) are
    /// stored through `register_evaluated_expr`.
    ///
    /// Expressions that can never be constant (loads, locals, derivatives,
    /// calls, atomics, function arguments, globals, image, ray-query and
    /// subgroup expressions) return a dedicated [`ConstantEvaluatorError`]
    /// variant. `ArrayLength` is evaluated only under GLSL behavior; under WGSL
    /// it is always an error.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer.
                // This is mainly done to avoid having constants pointing to other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Already-evaluated leaves. A `Constant` only reaches this arm when
            // the guard above failed, i.e. outside the global arena.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The dynamic index must itself reduce to a constant integer.
                self.access(base, self.constant_index(index)?, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: Some(width)` is a value conversion; `None` is a
                // bitcast, which this evaluator does not implement yet.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below can never be a constant expression.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
        }
    }
976
977    /// Splat `value` to `size`, without using [`Splat`] expressions.
978    ///
979    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
980    /// build a vector with the given `size` whose components are all
981    /// `value`.
982    ///
983    /// Use `span` as the span of the inserted expressions and
984    /// resulting types.
985    ///
986    /// [`Splat`]: Expression::Splat
987    /// [`Compose`]: Expression::Compose
988    /// [`ZeroValue`]: Expression::ZeroValue
989    fn splat(
990        &mut self,
991        value: Handle<Expression>,
992        size: crate::VectorSize,
993        span: Span,
994    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
995        match self.expressions[value] {
996            Expression::Literal(literal) => {
997                let scalar = literal.scalar();
998                let ty = self.types.insert(
999                    Type {
1000                        name: None,
1001                        inner: TypeInner::Vector { size, scalar },
1002                    },
1003                    span,
1004                );
1005                let expr = Expression::Compose {
1006                    ty,
1007                    components: vec![value; size as usize],
1008                };
1009                self.register_evaluated_expr(expr, span)
1010            }
1011            Expression::ZeroValue(ty) => {
1012                let inner = match self.types[ty].inner {
1013                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1014                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1015                };
1016                let res_ty = self.types.insert(Type { name: None, inner }, span);
1017                let expr = Expression::ZeroValue(res_ty);
1018                self.register_evaluated_expr(expr, span)
1019            }
1020            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1021        }
1022    }
1023
1024    fn swizzle(
1025        &mut self,
1026        size: crate::VectorSize,
1027        span: Span,
1028        src_constant: Handle<Expression>,
1029        pattern: [crate::SwizzleComponent; 4],
1030    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1031        let mut get_dst_ty = |ty| match self.types[ty].inner {
1032            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1033                Type {
1034                    name: None,
1035                    inner: TypeInner::Vector { size, scalar },
1036                },
1037                span,
1038            )),
1039            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1040        };
1041
1042        match self.expressions[src_constant] {
1043            Expression::ZeroValue(ty) => {
1044                let dst_ty = get_dst_ty(ty)?;
1045                let expr = Expression::ZeroValue(dst_ty);
1046                self.register_evaluated_expr(expr, span)
1047            }
1048            Expression::Splat { value, .. } => {
1049                let expr = Expression::Splat { size, value };
1050                self.register_evaluated_expr(expr, span)
1051            }
1052            Expression::Compose { ty, ref components } => {
1053                let dst_ty = get_dst_ty(ty)?;
1054
1055                let mut flattened = [src_constant; 4]; // dummy value
1056                let len =
1057                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1058                        .zip(flattened.iter_mut())
1059                        .map(|(component, elt)| *elt = component)
1060                        .count();
1061                let flattened = &flattened[..len];
1062
1063                let swizzled_components = pattern[..size as usize]
1064                    .iter()
1065                    .map(|&sc| {
1066                        let sc = sc as usize;
1067                        if let Some(elt) = flattened.get(sc) {
1068                            Ok(*elt)
1069                        } else {
1070                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1071                        }
1072                    })
1073                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1074                let expr = Expression::Compose {
1075                    ty: dst_ty,
1076                    components: swizzled_components,
1077                };
1078                self.register_evaluated_expr(expr, span)
1079            }
1080            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1081        }
1082    }
1083
1084    fn math(
1085        &mut self,
1086        arg: Handle<Expression>,
1087        arg1: Option<Handle<Expression>>,
1088        arg2: Option<Handle<Expression>>,
1089        arg3: Option<Handle<Expression>>,
1090        fun: crate::MathFunction,
1091        span: Span,
1092    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1093        let expected = fun.argument_count();
1094        let given = Some(arg)
1095            .into_iter()
1096            .chain(arg1)
1097            .chain(arg2)
1098            .chain(arg3)
1099            .count();
1100        if expected != given {
1101            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1102                fun, expected, given,
1103            ));
1104        }
1105
1106        // NOTE: We try to match the declaration order of `MathFunction` here.
1107        match fun {
1108            // comparison
1109            crate::MathFunction::Abs => {
1110                component_wise_scalar(self, span, [arg], |args| match args {
1111                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1112                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1113                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1114                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1115                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1116                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1117                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1118                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1119                })
1120            }
1121            crate::MathFunction::Min => {
1122                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1123                    Ok([e1.min(e2)])
1124                })
1125            }
1126            crate::MathFunction::Max => {
1127                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1128                    Ok([e1.max(e2)])
1129                })
1130            }
1131            crate::MathFunction::Clamp => {
1132                component_wise_scalar!(
1133                    self,
1134                    span,
1135                    [arg, arg1.unwrap(), arg2.unwrap()],
1136                    |e, low, high| {
1137                        if low > high {
1138                            Err(ConstantEvaluatorError::InvalidClamp)
1139                        } else {
1140                            Ok([e.clamp(low, high)])
1141                        }
1142                    }
1143                )
1144            }
1145            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1146                Float::F16([e]) => Ok(Float::F16(
1147                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1148                )),
1149                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1150                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1151            }),
1152
1153            // trigonometry
1154            crate::MathFunction::Cos => {
1155                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1156            }
1157            crate::MathFunction::Cosh => {
1158                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1159            }
1160            crate::MathFunction::Sin => {
1161                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1162            }
1163            crate::MathFunction::Sinh => {
1164                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1165            }
1166            crate::MathFunction::Tan => {
1167                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1168            }
1169            crate::MathFunction::Tanh => {
1170                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1171            }
1172            crate::MathFunction::Acos => {
1173                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1174            }
1175            crate::MathFunction::Asin => {
1176                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1177            }
1178            crate::MathFunction::Atan => {
1179                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1180            }
1181            crate::MathFunction::Atan2 => {
1182                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1183                    Ok([y.atan2(x)])
1184                })
1185            }
1186            crate::MathFunction::Asinh => {
1187                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1188            }
1189            crate::MathFunction::Acosh => {
1190                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1191            }
1192            crate::MathFunction::Atanh => {
1193                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1194            }
1195            crate::MathFunction::Radians => {
1196                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1197            }
1198            crate::MathFunction::Degrees => {
1199                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1200            }
1201
1202            // decomposition
1203            crate::MathFunction::Ceil => {
1204                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1205            }
1206            crate::MathFunction::Floor => {
1207                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1208            }
1209            crate::MathFunction::Round => {
1210                component_wise_float(self, span, [arg], |e| match e {
1211                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1212                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1213                    Float::F16([e]) => {
1214                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1215                        //
1216                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1217                        // which has licensing compatible with ours. See also
1218                        // <https://github.com/rust-lang/rust/issues/96710>.
1219                        //
1220                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1221                        fn round_ties_even(x: f64) -> f64 {
1222                            let i = x as i64;
1223                            let f = (x - i as f64).abs();
1224                            if f == 0.5 {
1225                                if i & 1 == 1 {
1226                                    // -1.5, 1.5, 3.5, ...
1227                                    (x.abs() + 0.5).copysign(x)
1228                                } else {
1229                                    (x.abs() - 0.5).copysign(x)
1230                                }
1231                            } else {
1232                                x.round()
1233                            }
1234                        }
1235
1236                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1237                    }
1238                })
1239            }
1240            crate::MathFunction::Fract => {
1241                component_wise_float!(self, span, [arg], |e| {
1242                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1243                    // here.
1244                    Ok([e - e.floor()])
1245                })
1246            }
1247            crate::MathFunction::Trunc => {
1248                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1249            }
1250
1251            // exponent
1252            crate::MathFunction::Exp => {
1253                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1254            }
1255            crate::MathFunction::Exp2 => {
1256                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1257            }
1258            crate::MathFunction::Log => {
1259                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1260            }
1261            crate::MathFunction::Log2 => {
1262                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1263            }
1264            crate::MathFunction::Pow => {
1265                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1266                    Ok([e1.powf(e2)])
1267                })
1268            }
1269
1270            // computational
1271            crate::MathFunction::Sign => {
1272                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1273            }
1274            crate::MathFunction::Fma => {
1275                component_wise_float!(
1276                    self,
1277                    span,
1278                    [arg, arg1.unwrap(), arg2.unwrap()],
1279                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1280                )
1281            }
1282            crate::MathFunction::Step => {
1283                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1284                    Float::Abstract([edge, x]) => {
1285                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1286                    }
1287                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1288                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1289                        f16::one()
1290                    } else {
1291                        f16::zero()
1292                    }])),
1293                })
1294            }
1295            crate::MathFunction::Sqrt => {
1296                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1297            }
1298            crate::MathFunction::InverseSqrt => {
1299                component_wise_float(self, span, [arg], |e| match e {
1300                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1301                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1302                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1303                })
1304            }
1305
1306            // bits
1307            crate::MathFunction::CountTrailingZeros => {
1308                component_wise_concrete_int!(self, span, [arg], |e| {
1309                    #[allow(clippy::useless_conversion)]
1310                    Ok([e
1311                        .trailing_zeros()
1312                        .try_into()
1313                        .expect("bit count overflowed 32 bits, somehow!?")])
1314                })
1315            }
1316            crate::MathFunction::CountLeadingZeros => {
1317                component_wise_concrete_int!(self, span, [arg], |e| {
1318                    #[allow(clippy::useless_conversion)]
1319                    Ok([e
1320                        .leading_zeros()
1321                        .try_into()
1322                        .expect("bit count overflowed 32 bits, somehow!?")])
1323                })
1324            }
1325            crate::MathFunction::CountOneBits => {
1326                component_wise_concrete_int!(self, span, [arg], |e| {
1327                    #[allow(clippy::useless_conversion)]
1328                    Ok([e
1329                        .count_ones()
1330                        .try_into()
1331                        .expect("bit count overflowed 32 bits, somehow!?")])
1332                })
1333            }
1334            crate::MathFunction::ReverseBits => {
1335                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1336            }
1337            crate::MathFunction::FirstTrailingBit => {
1338                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1339            }
1340            crate::MathFunction::FirstLeadingBit => {
1341                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1342            }
1343
1344            // vector
1345            crate::MathFunction::Dot4I8Packed => {
1346                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1347            }
1348            crate::MathFunction::Dot4U8Packed => {
1349                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1350            }
1351            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1352
1353            // unimplemented
1354            crate::MathFunction::Modf
1355            | crate::MathFunction::Frexp
1356            | crate::MathFunction::Ldexp
1357            | crate::MathFunction::Dot
1358            | crate::MathFunction::Outer
1359            | crate::MathFunction::Distance
1360            | crate::MathFunction::Length
1361            | crate::MathFunction::Normalize
1362            | crate::MathFunction::FaceForward
1363            | crate::MathFunction::Reflect
1364            | crate::MathFunction::Refract
1365            | crate::MathFunction::Mix
1366            | crate::MathFunction::SmoothStep
1367            | crate::MathFunction::Inverse
1368            | crate::MathFunction::Transpose
1369            | crate::MathFunction::Determinant
1370            | crate::MathFunction::QuantizeToF16
1371            | crate::MathFunction::ExtractBits
1372            | crate::MathFunction::InsertBits
1373            | crate::MathFunction::Pack4x8snorm
1374            | crate::MathFunction::Pack4x8unorm
1375            | crate::MathFunction::Pack2x16snorm
1376            | crate::MathFunction::Pack2x16unorm
1377            | crate::MathFunction::Pack2x16float
1378            | crate::MathFunction::Pack4xI8
1379            | crate::MathFunction::Pack4xU8
1380            | crate::MathFunction::Pack4xI8Clamp
1381            | crate::MathFunction::Pack4xU8Clamp
1382            | crate::MathFunction::Unpack4x8snorm
1383            | crate::MathFunction::Unpack4x8unorm
1384            | crate::MathFunction::Unpack2x16snorm
1385            | crate::MathFunction::Unpack2x16unorm
1386            | crate::MathFunction::Unpack2x16float
1387            | crate::MathFunction::Unpack4xI8
1388            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1389                format!("{fun:?} built-in function"),
1390            )),
1391        }
1392    }
1393
1394    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1395    fn packed_dot_product(
1396        &mut self,
1397        a: Handle<Expression>,
1398        b: Handle<Expression>,
1399        span: Span,
1400        signed: bool,
1401    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1402        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1403            return Err(ConstantEvaluatorError::InvalidMathArg);
1404        };
1405        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1406            return Err(ConstantEvaluatorError::InvalidMathArg);
1407        };
1408
1409        let result = if signed {
1410            Literal::I32(
1411                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1412                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1413                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1414                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1415            )
1416        } else {
1417            Literal::U32(
1418                (a & 0xFF) * (b & 0xFF)
1419                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1420                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1421                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1422            )
1423        };
1424
1425        self.register_evaluated_expr(Expression::Literal(result), span)
1426    }
1427
1428    /// Vector cross product.
1429    fn cross_product(
1430        &mut self,
1431        a: Handle<Expression>,
1432        b: Handle<Expression>,
1433        span: Span,
1434    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1435        use Literal as Li;
1436
1437        let (a, ty) = self.extract_vec::<3>(a)?;
1438        let (b, _) = self.extract_vec::<3>(b)?;
1439
1440        let product = match (a, b) {
1441            (
1442                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1443                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1444            ) => {
1445                // `cross` has no overload for AbstractInt, so AbstractInt
1446                // arguments are automatically converted to AbstractFloat. Since
1447                // `f64` has a much wider range than `i64`, there's no danger of
1448                // overflow here.
1449                let p = cross_product(
1450                    [a0 as f64, a1 as f64, a2 as f64],
1451                    [b0 as f64, b1 as f64, b2 as f64],
1452                );
1453                [
1454                    Li::AbstractFloat(p[0]),
1455                    Li::AbstractFloat(p[1]),
1456                    Li::AbstractFloat(p[2]),
1457                ]
1458            }
1459            (
1460                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1461                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1462            ) => {
1463                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1464                [
1465                    Li::AbstractFloat(p[0]),
1466                    Li::AbstractFloat(p[1]),
1467                    Li::AbstractFloat(p[2]),
1468                ]
1469            }
1470            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1471                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1472                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1473            }
1474            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1475                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1476                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1477            }
1478            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1479                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1480                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1481            }
1482            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1483        };
1484
1485        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1486        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1487        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1488
1489        self.register_evaluated_expr(
1490            Expression::Compose {
1491                ty,
1492                components: vec![p0, p1, p2],
1493            },
1494            span,
1495        )
1496    }
1497
1498    /// Extract the values of a `vecN` from `expr`.
1499    ///
1500    /// Return the value of `expr`, whose type is `vecN<S>` for some
1501    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1502    /// values.
1503    ///
1504    /// Also return the type handle from the `Compose` expression.
1505    fn extract_vec<const N: usize>(
1506        &mut self,
1507        expr: Handle<Expression>,
1508    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1509        let span = self.expressions.get_span(expr);
1510        let expr = self.eval_zero_value_and_splat(expr, span)?;
1511        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1512            return Err(ConstantEvaluatorError::InvalidMathArg);
1513        };
1514
1515        let mut value = [Literal::Bool(false); N];
1516        for (component, elt) in
1517            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1518                .zip(value.iter_mut())
1519        {
1520            let Expression::Literal(literal) = self.expressions[component] else {
1521                return Err(ConstantEvaluatorError::InvalidMathArg);
1522            };
1523            *elt = literal;
1524        }
1525
1526        Ok((value, ty))
1527    }
1528
1529    fn array_length(
1530        &mut self,
1531        array: Handle<Expression>,
1532        span: Span,
1533    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1534        match self.expressions[array] {
1535            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1536                match self.types[ty].inner {
1537                    TypeInner::Array { size, .. } => match size {
1538                        ArraySize::Constant(len) => {
1539                            let expr = Expression::Literal(Literal::U32(len.get()));
1540                            self.register_evaluated_expr(expr, span)
1541                        }
1542                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1543                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1544                    },
1545                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1546                }
1547            }
1548            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1549        }
1550    }
1551
1552    fn access(
1553        &mut self,
1554        base: Handle<Expression>,
1555        index: usize,
1556        span: Span,
1557    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1558        match self.expressions[base] {
1559            Expression::ZeroValue(ty) => {
1560                let ty_inner = &self.types[ty].inner;
1561                let components = ty_inner
1562                    .components()
1563                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1564
1565                if index >= components as usize {
1566                    Err(ConstantEvaluatorError::InvalidAccessBase)
1567                } else {
1568                    let ty_res = ty_inner
1569                        .component_type(index)
1570                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1571                    let ty = match ty_res {
1572                        crate::proc::TypeResolution::Handle(ty) => ty,
1573                        crate::proc::TypeResolution::Value(inner) => {
1574                            self.types.insert(Type { name: None, inner }, span)
1575                        }
1576                    };
1577                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1578                }
1579            }
1580            Expression::Splat { size, value } => {
1581                if index >= size as usize {
1582                    Err(ConstantEvaluatorError::InvalidAccessBase)
1583                } else {
1584                    Ok(value)
1585                }
1586            }
1587            Expression::Compose { ty, ref components } => {
1588                let _ = self.types[ty]
1589                    .inner
1590                    .components()
1591                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1592
1593                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1594                    .nth(index)
1595                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1596            }
1597            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1598        }
1599    }
1600
1601    fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1602        match self.expressions[expr] {
1603            Expression::ZeroValue(ty)
1604                if matches!(
1605                    self.types[ty].inner,
1606                    TypeInner::Scalar(crate::Scalar {
1607                        kind: ScalarKind::Uint,
1608                        ..
1609                    })
1610                ) =>
1611            {
1612                Ok(0)
1613            }
1614            Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1615            _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1616        }
1617    }
1618
1619    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
1620    ///
1621    /// [`ZeroValue`]: Expression::ZeroValue
1622    /// [`Splat`]: Expression::Splat
1623    /// [`Literal`]: Expression::Literal
1624    /// [`Compose`]: Expression::Compose
1625    fn eval_zero_value_and_splat(
1626        &mut self,
1627        mut expr: Handle<Expression>,
1628        span: Span,
1629    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1630        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
1631        // each of its components.
1632        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
1633            let components = components
1634                .clone()
1635                .iter()
1636                .map(|component| self.eval_zero_value_and_splat(*component, span))
1637                .collect::<Result<_, _>>()?;
1638            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
1639        }
1640
1641        // The result of the splat() for a Splat of a scalar ZeroValue is a
1642        // vector ZeroValue, so we must call eval_zero_value_impl() after
1643        // splat() in order to ensure we have no ZeroValues remaining.
1644        if let Expression::Splat { size, value } = self.expressions[expr] {
1645            expr = self.splat(value, size, span)?;
1646        }
1647        if let Expression::ZeroValue(ty) = self.expressions[expr] {
1648            expr = self.eval_zero_value_impl(ty, span)?;
1649        }
1650        Ok(expr)
1651    }
1652
1653    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1654    ///
1655    /// [`ZeroValue`]: Expression::ZeroValue
1656    /// [`Literal`]: Expression::Literal
1657    /// [`Compose`]: Expression::Compose
1658    fn eval_zero_value(
1659        &mut self,
1660        expr: Handle<Expression>,
1661        span: Span,
1662    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1663        match self.expressions[expr] {
1664            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1665            _ => Ok(expr),
1666        }
1667    }
1668
1669    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1670    ///
1671    /// [`ZeroValue`]: Expression::ZeroValue
1672    /// [`Literal`]: Expression::Literal
1673    /// [`Compose`]: Expression::Compose
1674    fn eval_zero_value_impl(
1675        &mut self,
1676        ty: Handle<Type>,
1677        span: Span,
1678    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1679        match self.types[ty].inner {
1680            TypeInner::Scalar(scalar) => {
1681                let expr = Expression::Literal(
1682                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1683                );
1684                self.register_evaluated_expr(expr, span)
1685            }
1686            TypeInner::Vector { size, scalar } => {
1687                let scalar_ty = self.types.insert(
1688                    Type {
1689                        name: None,
1690                        inner: TypeInner::Scalar(scalar),
1691                    },
1692                    span,
1693                );
1694                let el = self.eval_zero_value_impl(scalar_ty, span)?;
1695                let expr = Expression::Compose {
1696                    ty,
1697                    components: vec![el; size as usize],
1698                };
1699                self.register_evaluated_expr(expr, span)
1700            }
1701            TypeInner::Matrix {
1702                columns,
1703                rows,
1704                scalar,
1705            } => {
1706                let vec_ty = self.types.insert(
1707                    Type {
1708                        name: None,
1709                        inner: TypeInner::Vector { size: rows, scalar },
1710                    },
1711                    span,
1712                );
1713                let el = self.eval_zero_value_impl(vec_ty, span)?;
1714                let expr = Expression::Compose {
1715                    ty,
1716                    components: vec![el; columns as usize],
1717                };
1718                self.register_evaluated_expr(expr, span)
1719            }
1720            TypeInner::Array {
1721                base,
1722                size: ArraySize::Constant(size),
1723                ..
1724            } => {
1725                let el = self.eval_zero_value_impl(base, span)?;
1726                let expr = Expression::Compose {
1727                    ty,
1728                    components: vec![el; size.get() as usize],
1729                };
1730                self.register_evaluated_expr(expr, span)
1731            }
1732            TypeInner::Struct { ref members, .. } => {
1733                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1734                let mut components = Vec::with_capacity(members.len());
1735                for ty in types {
1736                    components.push(self.eval_zero_value_impl(ty, span)?);
1737                }
1738                let expr = Expression::Compose { ty, components };
1739                self.register_evaluated_expr(expr, span)
1740            }
1741            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1742        }
1743    }
1744
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// A top-level `ZeroValue` is lowered first. After that, `expr` must be one
    /// of the following:
    /// - a `Literal`, which is converted directly;
    /// - a `Compose` of vector or matrix type, whose components are converted
    ///   recursively and which gets a new type with `target` as its scalar;
    /// - a `Splat`, whose splatted value is converted.
    ///
    /// # Errors
    ///
    /// [`ConstantEvaluatorError::InvalidCastArg`] when the source/target
    /// combination is not handled below, or when `expr` is some other kind of
    /// expression. Errors from `try_from_abstract` (abstract-to-concrete
    /// conversions) and from registering the result are propagated.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        // Lower a top-level `ZeroValue` so the match below sees a `Literal` or `Compose`.
        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error, describing the source expression and target type.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        use crate::proc::type_methods::IntFloatLimits;

        // Conversion table notes:
        // - Float -> integer conversions clamp to the target's
        //   `IntFloatLimits::{min_float, max_float}` before the `as` truncation.
        // - Same-width signed <-> unsigned integer conversions use `as`, which
        //   reinterprets the bits.
        // - NOTE(review): the `unwrap()`s on `f16` conversions assume the literal
        //   is finite; confirm that non-finite literals are rejected upstream.
        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    // Conversion to bool tests against zero.
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Abstract targets accept only abstract sources.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Keep the vector/matrix shape but swap in `target` as the scalar.
                // Compose expressions of any other type (arrays, structs) are
                // rejected here; see `cast_array` for arrays.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Convert each component recursively, in order.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Only the splatted scalar needs converting; it keeps its own span.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1939
1940    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
1941    ///
1942    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
1943    /// matrix, or nested arrays of such.
1944    ///
1945    /// This is basically the same as the [`cast`] method, except that that
1946    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
1947    ///
1948    /// Treat `span` as the location of the resulting expression.
1949    ///
1950    /// [`cast`]: ConstantEvaluator::cast
1951    /// [`As`]: crate::Expression::As
1952    pub fn cast_array(
1953        &mut self,
1954        expr: Handle<Expression>,
1955        target: crate::Scalar,
1956        span: Span,
1957    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1958        let expr = self.check_and_get(expr)?;
1959
1960        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1961            return self.cast(expr, target, span);
1962        };
1963
1964        let TypeInner::Array {
1965            base: _,
1966            size,
1967            stride: _,
1968        } = self.types[ty].inner
1969        else {
1970            return self.cast(expr, target, span);
1971        };
1972
1973        let mut components = components.clone();
1974        for component in &mut components {
1975            *component = self.cast_array(*component, target, span)?;
1976        }
1977
1978        let first = components.first().unwrap();
1979        let new_base = match self.resolve_type(*first)? {
1980            crate::proc::TypeResolution::Handle(ty) => ty,
1981            crate::proc::TypeResolution::Value(inner) => {
1982                self.types.insert(Type { name: None, inner }, span)
1983            }
1984        };
1985        let mut layouter = core::mem::take(self.layouter);
1986        layouter.update(self.to_ctx()).unwrap();
1987        *self.layouter = layouter;
1988
1989        let new_base_stride = self.layouter[new_base].to_stride();
1990        let new_array_ty = self.types.insert(
1991            Type {
1992                name: None,
1993                inner: TypeInner::Array {
1994                    base: new_base,
1995                    size,
1996                    stride: new_base_stride,
1997                },
1998            },
1999            span,
2000        );
2001
2002        let compose = Expression::Compose {
2003            ty: new_array_ty,
2004            components,
2005        };
2006        self.register_evaluated_expr(compose, span)
2007    }
2008
2009    fn unary_op(
2010        &mut self,
2011        op: UnaryOperator,
2012        expr: Handle<Expression>,
2013        span: Span,
2014    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2015        let expr = self.eval_zero_value_and_splat(expr, span)?;
2016
2017        let expr = match self.expressions[expr] {
2018            Expression::Literal(value) => Expression::Literal(match op {
2019                UnaryOperator::Negate => match value {
2020                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2021                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2022                    Literal::F32(v) => Literal::F32(-v),
2023                    Literal::F16(v) => Literal::F16(-v),
2024                    Literal::F64(v) => Literal::F64(-v),
2025                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2026                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2027                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2028                },
2029                UnaryOperator::LogicalNot => match value {
2030                    Literal::Bool(v) => Literal::Bool(!v),
2031                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2032                },
2033                UnaryOperator::BitwiseNot => match value {
2034                    Literal::I32(v) => Literal::I32(!v),
2035                    Literal::I64(v) => Literal::I64(!v),
2036                    Literal::U32(v) => Literal::U32(!v),
2037                    Literal::U64(v) => Literal::U64(!v),
2038                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2039                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2040                },
2041            }),
2042            Expression::Compose {
2043                ty,
2044                components: ref src_components,
2045            } => {
2046                match self.types[ty].inner {
2047                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2048                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2049                }
2050
2051                let mut components = src_components.clone();
2052                for component in &mut components {
2053                    *component = self.unary_op(op, *component, span)?;
2054                }
2055
2056                Expression::Compose { ty, components }
2057            }
2058            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2059        };
2060
2061        self.register_evaluated_expr(expr, span)
2062    }
2063
2064    fn binary_op(
2065        &mut self,
2066        op: BinaryOperator,
2067        left: Handle<Expression>,
2068        right: Handle<Expression>,
2069        span: Span,
2070    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2071        let left = self.eval_zero_value_and_splat(left, span)?;
2072        let right = self.eval_zero_value_and_splat(right, span)?;
2073
2074        let expr = match (&self.expressions[left], &self.expressions[right]) {
2075            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2076                let literal = match op {
2077                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2078                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2079                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2080                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2081                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2082                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2083
2084                    _ => match (left_value, right_value) {
2085                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2086                            BinaryOperator::Add => a.wrapping_add(b),
2087                            BinaryOperator::Subtract => a.wrapping_sub(b),
2088                            BinaryOperator::Multiply => a.wrapping_mul(b),
2089                            BinaryOperator::Divide => {
2090                                if b == 0 {
2091                                    return Err(ConstantEvaluatorError::DivisionByZero);
2092                                } else {
2093                                    a.wrapping_div(b)
2094                                }
2095                            }
2096                            BinaryOperator::Modulo => {
2097                                if b == 0 {
2098                                    return Err(ConstantEvaluatorError::RemainderByZero);
2099                                } else {
2100                                    a.wrapping_rem(b)
2101                                }
2102                            }
2103                            BinaryOperator::And => a & b,
2104                            BinaryOperator::ExclusiveOr => a ^ b,
2105                            BinaryOperator::InclusiveOr => a | b,
2106                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2107                        }),
2108                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2109                            BinaryOperator::ShiftLeft => {
2110                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2111                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2112                                }
2113                                a.checked_shl(b)
2114                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2115                            }
2116                            BinaryOperator::ShiftRight => a
2117                                .checked_shr(b)
2118                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2119                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2120                        }),
2121                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2122                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2123                                ConstantEvaluatorError::Overflow("addition".into())
2124                            })?,
2125                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2126                                ConstantEvaluatorError::Overflow("subtraction".into())
2127                            })?,
2128                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2129                                ConstantEvaluatorError::Overflow("multiplication".into())
2130                            })?,
2131                            BinaryOperator::Divide => a
2132                                .checked_div(b)
2133                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2134                            BinaryOperator::Modulo => a
2135                                .checked_rem(b)
2136                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2137                            BinaryOperator::And => a & b,
2138                            BinaryOperator::ExclusiveOr => a ^ b,
2139                            BinaryOperator::InclusiveOr => a | b,
2140                            BinaryOperator::ShiftLeft => a
2141                                .checked_mul(
2142                                    1u32.checked_shl(b)
2143                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2144                                )
2145                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2146                            BinaryOperator::ShiftRight => a
2147                                .checked_shr(b)
2148                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2149                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2150                        }),
2151                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2152                            BinaryOperator::Add => a + b,
2153                            BinaryOperator::Subtract => a - b,
2154                            BinaryOperator::Multiply => a * b,
2155                            BinaryOperator::Divide => a / b,
2156                            BinaryOperator::Modulo => a % b,
2157                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2158                        }),
2159                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2160                            Literal::AbstractInt(match op {
2161                                BinaryOperator::ShiftLeft => {
2162                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2163                                        return Err(ConstantEvaluatorError::Overflow(
2164                                            "<<".to_string(),
2165                                        ));
2166                                    }
2167                                    a.checked_shl(b).unwrap_or(0)
2168                                }
2169                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2170                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2171                            })
2172                        }
2173                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2174                            BinaryOperator::Add => a + b,
2175                            BinaryOperator::Subtract => a - b,
2176                            BinaryOperator::Multiply => a * b,
2177                            BinaryOperator::Divide => a / b,
2178                            BinaryOperator::Modulo => a % b,
2179                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2180                        }),
2181                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2182                            Literal::AbstractInt(match op {
2183                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2184                                    ConstantEvaluatorError::Overflow("addition".into())
2185                                })?,
2186                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2187                                    ConstantEvaluatorError::Overflow("subtraction".into())
2188                                })?,
2189                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2190                                    ConstantEvaluatorError::Overflow("multiplication".into())
2191                                })?,
2192                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2193                                    if b == 0 {
2194                                        ConstantEvaluatorError::DivisionByZero
2195                                    } else {
2196                                        ConstantEvaluatorError::Overflow("division".into())
2197                                    }
2198                                })?,
2199                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2200                                    if b == 0 {
2201                                        ConstantEvaluatorError::RemainderByZero
2202                                    } else {
2203                                        ConstantEvaluatorError::Overflow("remainder".into())
2204                                    }
2205                                })?,
2206                                BinaryOperator::And => a & b,
2207                                BinaryOperator::ExclusiveOr => a ^ b,
2208                                BinaryOperator::InclusiveOr => a | b,
2209                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2210                            })
2211                        }
2212                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2213                            Literal::AbstractFloat(match op {
2214                                BinaryOperator::Add => a + b,
2215                                BinaryOperator::Subtract => a - b,
2216                                BinaryOperator::Multiply => a * b,
2217                                BinaryOperator::Divide => a / b,
2218                                BinaryOperator::Modulo => a % b,
2219                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2220                            })
2221                        }
2222                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2223                            BinaryOperator::LogicalAnd => a && b,
2224                            BinaryOperator::LogicalOr => a || b,
2225                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2226                        }),
2227                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2228                    },
2229                };
2230                Expression::Literal(literal)
2231            }
2232            (
2233                &Expression::Compose {
2234                    components: ref src_components,
2235                    ty,
2236                },
2237                &Expression::Literal(_),
2238            ) => {
2239                let mut components = src_components.clone();
2240                for component in &mut components {
2241                    *component = self.binary_op(op, *component, right, span)?;
2242                }
2243                Expression::Compose { ty, components }
2244            }
2245            (
2246                &Expression::Literal(_),
2247                &Expression::Compose {
2248                    components: ref src_components,
2249                    ty,
2250                },
2251            ) => {
2252                let mut components = src_components.clone();
2253                for component in &mut components {
2254                    *component = self.binary_op(op, left, *component, span)?;
2255                }
2256                Expression::Compose { ty, components }
2257            }
2258            (
2259                &Expression::Compose {
2260                    components: ref left_components,
2261                    ty: left_ty,
2262                },
2263                &Expression::Compose {
2264                    components: ref right_components,
2265                    ty: right_ty,
2266                },
2267            ) => {
2268                // We have to make a copy of the component lists, because the
2269                // call to `binary_op_vector` needs `&mut self`, but `self` owns
2270                // the component lists.
2271                let left_flattened = crate::proc::flatten_compose(
2272                    left_ty,
2273                    left_components,
2274                    self.expressions,
2275                    self.types,
2276                );
2277                let right_flattened = crate::proc::flatten_compose(
2278                    right_ty,
2279                    right_components,
2280                    self.expressions,
2281                    self.types,
2282                );
2283
2284                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
2285                // make a reasonable guess of the capacity we'll need.
2286                let mut flattened = Vec::with_capacity(left_components.len());
2287                flattened.extend(left_flattened.zip(right_flattened));
2288
2289                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2290                    (
2291                        &TypeInner::Vector {
2292                            size: left_size, ..
2293                        },
2294                        &TypeInner::Vector {
2295                            size: right_size, ..
2296                        },
2297                    ) if left_size == right_size => {
2298                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2299                    }
2300                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2301                }
2302            }
2303            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2304        };
2305
2306        self.register_evaluated_expr(expr, span)
2307    }
2308
2309    fn binary_op_vector(
2310        &mut self,
2311        op: BinaryOperator,
2312        size: crate::VectorSize,
2313        components: &[(Handle<Expression>, Handle<Expression>)],
2314        left_ty: Handle<Type>,
2315        span: Span,
2316    ) -> Result<Expression, ConstantEvaluatorError> {
2317        let ty = match op {
2318            // Relational operators produce vectors of booleans.
2319            BinaryOperator::Equal
2320            | BinaryOperator::NotEqual
2321            | BinaryOperator::Less
2322            | BinaryOperator::LessEqual
2323            | BinaryOperator::Greater
2324            | BinaryOperator::GreaterEqual => self.types.insert(
2325                Type {
2326                    name: None,
2327                    inner: TypeInner::Vector {
2328                        size,
2329                        scalar: crate::Scalar::BOOL,
2330                    },
2331                },
2332                span,
2333            ),
2334
2335            // Other operators produce the same type as their left
2336            // operand.
2337            BinaryOperator::Add
2338            | BinaryOperator::Subtract
2339            | BinaryOperator::Multiply
2340            | BinaryOperator::Divide
2341            | BinaryOperator::Modulo
2342            | BinaryOperator::And
2343            | BinaryOperator::ExclusiveOr
2344            | BinaryOperator::InclusiveOr
2345            | BinaryOperator::ShiftLeft
2346            | BinaryOperator::ShiftRight => left_ty,
2347
2348            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2349                // Not supported on vectors
2350                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2351            }
2352        };
2353
2354        let components = components
2355            .iter()
2356            .map(|&(left, right)| self.binary_op(op, left, right, span))
2357            .collect::<Result<Vec<_>, _>>()?;
2358
2359        Ok(Expression::Compose { ty, components })
2360    }
2361
2362    fn relational(
2363        &mut self,
2364        fun: RelationalFunction,
2365        arg: Handle<Expression>,
2366        span: Span,
2367    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2368        let arg = self.eval_zero_value_and_splat(arg, span)?;
2369        match fun {
2370            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2371                Expression::Literal(Literal::Bool(_)) => Ok(arg),
2372                Expression::Compose { ty, ref components }
2373                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2374                {
2375                    let components =
2376                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2377                            .map(|component| match self.expressions[component] {
2378                                Expression::Literal(Literal::Bool(val)) => Ok(val),
2379                                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2380                            })
2381                            .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2382                    let result = match fun {
2383                        RelationalFunction::All => components.iter().all(|c| *c),
2384                        RelationalFunction::Any => components.iter().any(|c| *c),
2385                        _ => unreachable!(),
2386                    };
2387                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2388                }
2389                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2390            },
2391            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2392                "{fun:?} built-in function"
2393            ))),
2394        }
2395    }
2396
2397    /// Deep copy `expr` from `expressions` into `self.expressions`.
2398    ///
2399    /// Return the root of the new copy.
2400    ///
2401    /// This is used when we're evaluating expressions in a function's
2402    /// expression arena that refer to a constant: we need to copy the
2403    /// constant's value into the function's arena so we can operate on it.
2404    fn copy_from(
2405        &mut self,
2406        expr: Handle<Expression>,
2407        expressions: &Arena<Expression>,
2408    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2409        let span = expressions.get_span(expr);
2410        match expressions[expr] {
2411            ref expr @ (Expression::Literal(_)
2412            | Expression::Constant(_)
2413            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2414            Expression::Compose { ty, ref components } => {
2415                let mut components = components.clone();
2416                for component in &mut components {
2417                    *component = self.copy_from(*component, expressions)?;
2418                }
2419                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2420            }
2421            Expression::Splat { size, value } => {
2422                let value = self.copy_from(value, expressions)?;
2423                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2424            }
2425            _ => {
2426                log::debug!("copy_from: SubexpressionsAreNotConstant");
2427                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2428            }
2429        }
2430    }
2431
2432    /// Returns the total number of components, after flattening, of a vector compose expression.
2433    fn vector_compose_flattened_size(
2434        &self,
2435        components: &[Handle<Expression>],
2436    ) -> Result<usize, ConstantEvaluatorError> {
2437        components
2438            .iter()
2439            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2440                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2441                    TypeInner::Scalar(_) => 1,
2442                    // We trust that the vector size of `component` is correct,
2443                    // as it will have already been validated when `component`
2444                    // was registered.
2445                    TypeInner::Vector { size, .. } => size as usize,
2446                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2447                };
2448                Ok(acc + size)
2449            })
2450    }
2451
2452    fn register_evaluated_expr(
2453        &mut self,
2454        expr: Expression,
2455        span: Span,
2456    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2457        // It suffices to only check_literal_value() for `Literal` expressions,
2458        // since we only register one expression at a time, `Compose`
2459        // expressions can only refer to other expressions, and `ZeroValue`
2460        // expressions are always okay.
2461        if let Expression::Literal(literal) = expr {
2462            crate::valid::check_literal_value(literal)?;
2463        }
2464
2465        // Ensure vector composes contain the correct number of components. We
2466        // do so here when each compose is registered to avoid having to deal
2467        // with the mess each time the compose is used in another expression.
2468        if let Expression::Compose { ty, ref components } = expr {
2469            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2470                let expected = size as usize;
2471                let actual = self.vector_compose_flattened_size(components)?;
2472                if expected != actual {
2473                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2474                        expected,
2475                        actual,
2476                    });
2477                }
2478            }
2479        }
2480
2481        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2482    }
2483
2484    fn append_expr(
2485        &mut self,
2486        expr: Expression,
2487        span: Span,
2488        expr_type: ExpressionKind,
2489    ) -> Handle<Expression> {
2490        let h = match self.behavior {
2491            Behavior::Wgsl(
2492                WgslRestrictions::Runtime(ref mut function_local_data)
2493                | WgslRestrictions::Const(Some(ref mut function_local_data)),
2494            )
2495            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
2496                let is_running = function_local_data.emitter.is_running();
2497                let needs_pre_emit = expr.needs_pre_emit();
2498                if is_running && needs_pre_emit {
2499                    function_local_data
2500                        .block
2501                        .extend(function_local_data.emitter.finish(self.expressions));
2502                    let h = self.expressions.append(expr, span);
2503                    function_local_data.emitter.start(self.expressions);
2504                    h
2505                } else {
2506                    self.expressions.append(expr, span)
2507                }
2508            }
2509            _ => self.expressions.append(expr, span),
2510        };
2511        self.expression_kind_tracker.insert(h, expr_type);
2512        h
2513    }
2514
2515    fn resolve_type(
2516        &self,
2517        expr: Handle<Expression>,
2518    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2519        use crate::proc::TypeResolution as Tr;
2520        use crate::Expression as Ex;
2521        let resolution = match self.expressions[expr] {
2522            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2523            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2524            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2525            Ex::Splat { size, value } => {
2526                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2527                    return Err(ConstantEvaluatorError::SplatScalarOnly);
2528                };
2529                Tr::Value(TypeInner::Vector { scalar, size })
2530            }
2531            _ => {
2532                log::debug!("resolve_type: SubexpressionsAreNotConstant");
2533                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2534            }
2535        };
2536
2537        Ok(resolution)
2538    }
2539
2540    fn select(
2541        &mut self,
2542        reject: Handle<Expression>,
2543        accept: Handle<Expression>,
2544        condition: Handle<Expression>,
2545        span: Span,
2546    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2547        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
2548
2549        let reject = arg(reject)?;
2550        let accept = arg(accept)?;
2551        let condition = arg(condition)?;
2552
2553        let select_single_component =
2554            |this: &mut Self, reject_scalar, reject, accept, condition| {
2555                let accept = this.cast(accept, reject_scalar, span)?;
2556                if condition {
2557                    Ok(accept)
2558                } else {
2559                    Ok(reject)
2560                }
2561            };
2562
2563        match (&self.expressions[reject], &self.expressions[accept]) {
2564            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
2565                let reject_scalar = reject_lit.scalar();
2566                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
2567                else {
2568                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
2569                };
2570                select_single_component(self, reject_scalar, reject, accept, condition)
2571            }
2572            (
2573                &Expression::Compose {
2574                    ty: reject_ty,
2575                    components: ref reject_components,
2576                },
2577                &Expression::Compose {
2578                    ty: accept_ty,
2579                    components: ref accept_components,
2580                },
2581            ) => {
2582                let ty_deets = |ty| {
2583                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
2584                    (size.unwrap(), scalar)
2585                };
2586
2587                let expected_vec_size = {
2588                    let [(reject_vec_size, _), (accept_vec_size, _)] =
2589                        [reject_ty, accept_ty].map(ty_deets);
2590
2591                    if reject_vec_size != accept_vec_size {
2592                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
2593                            reject: reject_vec_size,
2594                            accept: accept_vec_size,
2595                        });
2596                    }
2597                    reject_vec_size
2598                };
2599
2600                let condition_components = match self.expressions[condition] {
2601                    Expression::Literal(Literal::Bool(condition)) => {
2602                        vec![condition; (expected_vec_size as u8).into()]
2603                    }
2604                    Expression::Compose {
2605                        ty: condition_ty,
2606                        components: ref condition_components,
2607                    } => {
2608                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
2609                        if condition_scalar.kind != ScalarKind::Bool {
2610                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
2611                        }
2612                        if condition_vec_size != expected_vec_size {
2613                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
2614                        }
2615                        condition_components
2616                            .iter()
2617                            .copied()
2618                            .map(|component| match &self.expressions[component] {
2619                                &Expression::Literal(Literal::Bool(condition)) => condition,
2620                                _ => unreachable!(),
2621                            })
2622                            .collect()
2623                    }
2624
2625                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
2626                };
2627
2628                let evaluated = Expression::Compose {
2629                    ty: reject_ty,
2630                    components: reject_components
2631                        .clone()
2632                        .into_iter()
2633                        .zip(accept_components.clone().into_iter())
2634                        .zip(condition_components.into_iter())
2635                        .map(|((reject, accept), condition)| {
2636                            let reject_scalar = match &self.expressions[reject] {
2637                                &Expression::Literal(lit) => lit.scalar(),
2638                                _ => unreachable!(),
2639                            };
2640                            select_single_component(self, reject_scalar, reject, accept, condition)
2641                        })
2642                        .collect::<Result<_, _>>()?,
2643                };
2644                self.register_evaluated_expr(evaluated, span)
2645            }
2646            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
2647        }
2648    }
2649}
2650
2651fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2652    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
2653    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
2654    // return a right-to-left bit index of 0.
2655    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2656        match e {
2657            idx @ 0..=31 => idx,
2658            32 => u32::MAX,
2659            _ => unreachable!(),
2660        }
2661    };
2662    match concrete_int {
2663        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2664        ConcreteInt::I32([e]) => {
2665            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2666        }
2667    }
2668}
2669
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected bit index) for signed inputs.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for &(input, expected) in &signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // A single set bit is its own first trailing bit.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // (input, expected bit index) for unsigned inputs.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for &(input, expected) in &unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2730
2731fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2732    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
2733    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
2734    // index of 0.
2735    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2736        match e {
2737            idx @ 0..=31 => 31 - idx,
2738            32 => u32::MAX,
2739            _ => unreachable!(),
2740        }
2741    };
2742    match concrete_int {
2743        ConcreteInt::I32([e]) => ConcreteInt::I32([{
2744            let rtl_bit_index = if e.is_negative() {
2745                e.leading_ones()
2746            } else {
2747                e.leading_zeros()
2748            };
2749            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2750        }]),
2751        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2752    }
2753}
2754
#[test]
fn first_leading_bit_smoke() {
    // (input, expected bit index) for signed inputs.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for &(input, expected) in &signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; the sign bit is covered by the cases above.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negative powers of two: the first bit differing from the sign is one lower.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected bit index) for unsigned inputs.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for &(input, expected) in &unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2818
2819/// Trait for conversions of abstract values to concrete types.
2820trait TryFromAbstract<T>: Sized {
2821    /// Convert an abstract literal `value` to `Self`.
2822    ///
2823    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
2824    /// WGSL, we follow WGSL's conversion rules here:
2825    ///
2826    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
2827    ///   from [`AbstractInt`] to an integer type are either lossless or an
2828    ///   error.
2829    ///
2830    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
2831    ///   to floating point in constant expressions and override
2832    ///   expressions are errors if the value is out of range for the
2833    ///   destination type, but rounding is okay.
2834    ///
2835    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
2836    ///   other floating point type, following the scalar floating point to
2837    ///   integral conversion algorithm (§15.7.6). There is no automatic
2838    ///   conversion from AbstractFloat to integer types.
2839    ///
2840    /// [`AbstractInt`]: crate::Literal::AbstractInt
2841    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
2842    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2843}
2844
2845impl TryFromAbstract<i64> for i32 {
2846    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2847        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2848            value: format!("{value:?}"),
2849            to_type: "i32",
2850        })
2851    }
2852}
2853
2854impl TryFromAbstract<i64> for u32 {
2855    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2856        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2857            value: format!("{value:?}"),
2858            to_type: "u32",
2859        })
2860    }
2861}
2862
2863impl TryFromAbstract<i64> for u64 {
2864    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2865        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2866            value: format!("{value:?}"),
2867            to_type: "u64",
2868        })
2869    }
2870}
2871
2872impl TryFromAbstract<i64> for i64 {
2873    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2874        Ok(value)
2875    }
2876}
2877
2878impl TryFromAbstract<i64> for f32 {
2879    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2880        let f = value as f32;
2881        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2882        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
2883        // overflow here.
2884        Ok(f)
2885    }
2886}
2887
2888impl TryFromAbstract<f64> for f32 {
2889    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2890        let f = value as f32;
2891        if f.is_infinite() {
2892            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2893                value: format!("{value:?}"),
2894                to_type: "f32",
2895            });
2896        }
2897        Ok(f)
2898    }
2899}
2900
2901impl TryFromAbstract<i64> for f64 {
2902    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2903        let f = value as f64;
2904        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2905        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
2906        // overflow here.
2907        Ok(f)
2908    }
2909}
2910
2911impl TryFromAbstract<f64> for f64 {
2912    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2913        Ok(value)
2914    }
2915}
2916
2917impl TryFromAbstract<f64> for i32 {
2918    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2919        // https://www.w3.org/TR/WGSL/#floating-point-conversion
2920        // To convert a floating point scalar value X to an integer scalar type T:
2921        // * If X is a NaN, the result is an indeterminate value in T.
2922        // * If X is exactly representable in the target type T, then the
2923        //   result is that value.
2924        // * Otherwise, the result is the value in T closest to truncate(X) and
2925        //   also exactly representable in the original floating point type.
2926        //
2927        // A rust cast satisfies these requirements apart from "the result
2928        // is... exactly representable in the original floating point type".
2929        // However, i32::MIN and i32::MAX are exactly representable by f64, so
2930        // we're all good.
2931        Ok(value as i32)
2932    }
2933}
2934
2935impl TryFromAbstract<f64> for u32 {
2936    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2937        // As above, u32::MIN and u32::MAX are exactly representable by f64,
2938        // so a simple rust cast is sufficient.
2939        Ok(value as u32)
2940    }
2941}
2942
2943impl TryFromAbstract<f64> for i64 {
2944    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2945        // As above, except we clamp to the minimum and maximum values
2946        // representable by both f64 and i64.
2947        use crate::proc::type_methods::IntFloatLimits;
2948        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
2949    }
2950}
2951
2952impl TryFromAbstract<f64> for u64 {
2953    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2954        // As above, this time clamping to the minimum and maximum values
2955        // representable by both f64 and u64.
2956        use crate::proc::type_methods::IntFloatLimits;
2957        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
2958    }
2959}
2960
2961impl TryFromAbstract<f64> for f16 {
2962    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
2963        let f = f16::from_f64(value);
2964        if f.is_infinite() {
2965            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2966                value: format!("{value:?}"),
2967                to_type: "f16",
2968            });
2969        }
2970        Ok(f)
2971    }
2972}
2973
2974impl TryFromAbstract<i64> for f16 {
2975    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
2976        let f = f16::from_i64(value);
2977        if f.is_none() {
2978            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2979                value: format!("{value:?}"),
2980                to_type: "f16",
2981            });
2982        }
2983        Ok(f.unwrap())
2984    }
2985}
2986
/// Cross product of two 3-component vectors given as arrays.
///
/// Component `i` of the result is built from the other two axes of `a` and
/// `b`, e.g. `x = a.y * b.z - a.z * b.y`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
2999
#[cfg(test)]
mod tests {
    //! Unit tests for constant folding performed by `ConstantEvaluator`.
    //!
    //! Each test fills a small set of arenas (types, constants, global
    //! expressions), builds a `ConstantEvaluator` with
    //! `Behavior::Wgsl(WgslRestrictions::Const(None))`, calls
    //! `try_eval_and_append`, and inspects the appended expression.

    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// `-c` and `!c` fold on an `i32` constant; `!v` folds component-wise on
    /// a `vec2<i32>` constant.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        // vec2<i32>(4, 8), built from the two scalar constants' initializers.
        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        // `solver` borrows `global_expressions` mutably; the borrow ends after
        // its last use, which lets the assertions below index the arena.
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        // The vector result stays a `Compose` whose components are folded literals.
        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Casting the nonzero `i32` constant 4 to `bool` folds to `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// `AccessIndex` folds on a constant `mat2x3<f32>`. Column 1 yields the
    /// vector (3, 4, 5), and indexing that result at 2 yields 5.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        // First column: literals 0.0, 1.0, 2.0.
        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        // Second column: literals 3.0, 4.0, 5.0.
        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        // Index into the already-evaluated column to check chained folding.
        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `Compose` of two references to the same `i32` constant can itself be
    /// folded, and negating it yields `vec2<i32>(-4, -4)`.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// A `Splat` of an `i32` constant can be folded, and negating it yields
    /// `vec2<i32>(-4, -4)`. The result is checked as a `Compose`.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// Adding `splat(ZeroValue(f32))` to `splat(5.0)` folds to
    /// `vec2<f32>(5.0, 5.0)`.
    #[test]
    fn splat_of_zero_value() {
        let mut types = UniqueArena::new();
        let constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::F32),
            },
            Default::default(),
        );

        let vec2_f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let five =
            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
        let five_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: five,
            },
            Default::default(),
        );
        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
        let zero_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: zero,
            },
            Default::default(),
        );

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_add = solver
            .try_eval_and_append(
                Expression::Binary {
                    op: crate::BinaryOperator::Add,
                    left: zero_splat,
                    right: five_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_add] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_f32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }
}