naga/proc/
constant_evaluator.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14    arena::{Arena, Handle, HandleVec, UniqueArena},
15    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16    ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// The tokens passed in are pasted verbatim as the rules of a helper macro named
/// `__with_dollar_sign`, which is then invoked with a single literal `$` token. The caller is
/// therefore expected to pass one or more `macro_rules!` arms whose matcher captures that token
/// (for example `($d:tt) => { ... }`), so the arm can write `$d` wherever the code it generates
/// needs a dollar sign.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Build the helper from the caller-supplied arms, then hand it a bare `$`.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
33
/// Generates the plumbing for one family of component-wise numeric built-ins.
///
/// Input:
///
/// - `$ident -> $target`: the name of the generated function (and same-named convenience
///   macro), and the name of the generated enum.
/// - `literals`: entries of the form `LiteralVariant => EnumVariant: rust_type`; each one adds
///   a variant `EnumVariant([rust_type; N])` to the enum, matched against
///   `Literal::LiteralVariant`.
/// - `scalar_kinds`: the [`ScalarKind`]s a vector argument's element type may have.
///
/// Output:
///
/// - `enum $target<const N: usize>`, holding `N` values of one literal kind.
/// - `impl From<$target<1>> for Expression`, turning a single value back into a literal.
/// - `fn $ident`, which evaluates `N` argument expressions, feeds them to a handler (once for
///   scalars, once per component for vectors) and registers the result.
/// - `macro_rules! $ident`, which lets a caller write one closure-like body that is reused
///   for every variant of the enum.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // Only the single-element form converts to an expression: it becomes the matching
        // `Expression::Literal`.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            // `Clone` is needed because the vector path below hands a fresh copy of the
            // handler to each per-component recursive call.
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            // Every shape or kind mismatch in this function is reported with this one error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an argument through `eval_zero_value_and_splat` and borrows the resulting
            // expression from the arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides which path is taken; the `unwrap` cannot fail
            // because `N > 0` was asserted above.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar path: every remaining argument must be a literal of the same
                    // variant as the first. The values are gathered into `[ty; N]` and the
                    // handler's result is converted into an `Expression`.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector path: the first argument is a `Compose` of a vector type whose
                // scalar kind is one of the accepted `scalar_kinds`.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // Each remaining argument must be a `Compose` whose type is
                            // structurally equal to the first argument's type.
                            // NOTE(review): this intermediate buffer is sized by
                            // `VectorSize::MAX`, not by `N`.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, take that component from every
                            // argument and evaluate the group through this same function.
                            // A group missing the index yields `err`.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result reuses the type handle of the first argument.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` below stands for a literal `$` (supplied by `with_dollar_sign!`), so that the
        // repetitions and fragments written with it belong to the generated macro rather
        // than to this one.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        // One match arm per enum variant: bind the array elements to the
                        // caller's argument names, evaluate the caller's body, and wrap an
                        // `Ok` result back into the same variant.
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Generates `Scalar<N>`, `fn component_wise_scalar` and `component_wise_scalar!`.
//
// This is the broadest family: it accepts every float and integer literal kind listed below,
// abstract or concrete. `Bool` is not part of it.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// Generates `Float<N>`, `fn component_wise_float` and `component_wise_float!`.
//
// Floating-point literals only. Note that `Literal::AbstractFloat` maps to the enum variant
// named `Abstract` here, unlike in the other families.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// Generates `ConcreteInt<N>`, `fn component_wise_concrete_int` and
// `component_wise_concrete_int!`.
//
// Only the 32-bit concrete integer literals (`U32`, `I32`) are accepted; abstract and 64-bit
// integers are not listed.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// Generates `Signed<N>`, `fn component_wise_signed` and `component_wise_signed!`.
//
// Literal kinds that can hold negative values: floats, `AbstractInt` and `I32`. Unsigned
// integers are excluded; `I64` is not listed either.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// A vector of literal values that all share one element type.
///
/// Each variant stores its components in an [`ArrayVec`] whose capacity is
/// `crate::VectorSize::MAX`. Besides the concrete scalar types there are `AbstractInt` and
/// `AbstractFloat` variants, stored as `i64` and `f64` respectively.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288    fn len(&self) -> usize {
289        match *self {
290            LiteralVector::F64(ref v) => v.len(),
291            LiteralVector::F32(ref v) => v.len(),
292            LiteralVector::F16(ref v) => v.len(),
293            LiteralVector::U32(ref v) => v.len(),
294            LiteralVector::I32(ref v) => v.len(),
295            LiteralVector::U64(ref v) => v.len(),
296            LiteralVector::I64(ref v) => v.len(),
297            LiteralVector::Bool(ref v) => v.len(),
298            LiteralVector::AbstractInt(ref v) => v.len(),
299            LiteralVector::AbstractFloat(ref v) => v.len(),
300        }
301    }
302
303    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304    fn from_literal(literal: Literal) -> Self {
305        match literal {
306            Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307            Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308            Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309            Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310            Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311            Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312            Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313            Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314            Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315            Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316        }
317    }
318
319    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320    /// Returns error if components types do not match.
321    /// # Panics
322    /// Panics if vector is empty
323    fn from_literal_vec(
324        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325    ) -> Result<Self, ConstantEvaluatorError> {
326        assert!(!components.is_empty());
327        Ok(match components[0] {
328            Literal::I32(_) => Self::I32(
329                components
330                    .iter()
331                    .map(|l| match l {
332                        &Literal::I32(v) => Ok(v),
333                        // TODO: should we handle abstract int here?
334                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
335                    })
336                    .collect::<Result<_, _>>()?,
337            ),
338            Literal::U32(_) => Self::U32(
339                components
340                    .iter()
341                    .map(|l| match l {
342                        &Literal::U32(v) => Ok(v),
343                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
344                    })
345                    .collect::<Result<_, _>>()?,
346            ),
347            Literal::I64(_) => Self::I64(
348                components
349                    .iter()
350                    .map(|l| match l {
351                        &Literal::I64(v) => Ok(v),
352                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
353                    })
354                    .collect::<Result<_, _>>()?,
355            ),
356            Literal::U64(_) => Self::U64(
357                components
358                    .iter()
359                    .map(|l| match l {
360                        &Literal::U64(v) => Ok(v),
361                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
362                    })
363                    .collect::<Result<_, _>>()?,
364            ),
365            Literal::F32(_) => Self::F32(
366                components
367                    .iter()
368                    .map(|l| match l {
369                        &Literal::F32(v) => Ok(v),
370                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
371                    })
372                    .collect::<Result<_, _>>()?,
373            ),
374            Literal::F64(_) => Self::F64(
375                components
376                    .iter()
377                    .map(|l| match l {
378                        &Literal::F64(v) => Ok(v),
379                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
380                    })
381                    .collect::<Result<_, _>>()?,
382            ),
383            Literal::Bool(_) => Self::Bool(
384                components
385                    .iter()
386                    .map(|l| match l {
387                        &Literal::Bool(v) => Ok(v),
388                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
389                    })
390                    .collect::<Result<_, _>>()?,
391            ),
392            Literal::AbstractInt(_) => Self::AbstractInt(
393                components
394                    .iter()
395                    .map(|l| match l {
396                        &Literal::AbstractInt(v) => Ok(v),
397                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
398                    })
399                    .collect::<Result<_, _>>()?,
400            ),
401            Literal::AbstractFloat(_) => Self::AbstractFloat(
402                components
403                    .iter()
404                    .map(|l| match l {
405                        &Literal::AbstractFloat(v) => Ok(v),
406                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
407                    })
408                    .collect::<Result<_, _>>()?,
409            ),
410            Literal::F16(_) => Self::F16(
411                components
412                    .iter()
413                    .map(|l| match l {
414                        &Literal::F16(v) => Ok(v),
415                        _ => Err(ConstantEvaluatorError::InvalidMathArg),
416                    })
417                    .collect::<Result<_, _>>()?,
418            ),
419        })
420    }
421
422    #[allow(dead_code)]
423    /// Returns [`ArrayVec`] of [`Literal`]s
424    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425        match *self {
426            LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427            LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428            LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429            LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430            LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431            LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432            LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433            LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434            LiteralVector::AbstractInt(ref v) => {
435                v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436            }
437            LiteralVector::AbstractFloat(ref v) => {
438                v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439            }
440        }
441    }
442
443    #[allow(dead_code)]
444    /// Puts self into eval's expressions arena and returns handle to it
445    fn register_as_evaluated_expr(
446        &self,
447        eval: &mut ConstantEvaluator<'_>,
448        span: Span,
449    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450        let lit_vec = self.to_literal_vec();
451        assert!(!lit_vec.is_empty());
452        let expr = if lit_vec.len() == 1 {
453            Expression::Literal(lit_vec[0])
454        } else {
455            Expression::Compose {
456                ty: eval.types.insert(
457                    Type {
458                        name: None,
459                        inner: TypeInner::Vector {
460                            size: match lit_vec.len() {
461                                2 => crate::VectorSize::Bi,
462                                3 => crate::VectorSize::Tri,
463                                4 => crate::VectorSize::Quad,
464                                _ => unreachable!(),
465                            },
466                            scalar: lit_vec[0].scalar(),
467                        },
468                    },
469                    Span::UNDEFINED,
470                ),
471                components: lit_vec
472                    .iter()
473                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474                    .collect::<Result<_, _>>()?,
475            }
476        };
477        eval.register_evaluated_expr(expr, span)
478    }
479}
480
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// The macro expands to a `match` expression. Each arm binds the closure-style parameters by
/// reference to the payload of the named variant and evaluates to `Ok(Out::Variant(body))`,
/// where `Variant` is the matched variant unless overridden with `-> Other`. A final
/// catch-all arm evaluates to `Err(ConstantEvaluatorError::InvalidMathArg)`. When a tuple of
/// vectors is matched, an arm only applies if all of them are the same variant.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Public entry point: split the user's arms into a list of type names and a parallel
    // list of `{ vars ; optional return override ; body }` records.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start the expansion loop. The state is `[finished arms] <> [pending arms]`, with the
    // finished list initially empty.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // No "current type" yet: pop the first name off the type list and make it the current
    // type (the leading argument of the rules below).
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // Current type is the `Integer` group: replace it with the five integer variants and
    // duplicate its pending arm once per variant, then resume without a current type.
    // This rule must stay ahead of the generic `$ty:ident` rules so that it wins.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is the `Float` group: same treatment with the four float variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Current type is a plain variant and more types remain: move the first pending arm,
    // tagged with the current type, to the finished list and advance to the next type.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Current type is a plain variant and it is the last one: finish the list and emit.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the actual `match`: one arm per finished entry, plus the error fallback.
    // `unused_parens` is allowed because a single-variable arm produces a parenthesized
    // pattern.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the body in the output variant named after the matched type...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ...or in the explicitly requested variant when `-> $ret` was given.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
/// Selects which source language's rules the evaluator follows, together with that
/// language's restriction state.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL rules; see [`WgslRestrictions`].
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL rules; see [`GlslRestrictions`].
    Glsl(GlslRestrictions<'a>),
}
662
663impl Behavior<'_> {
664    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
665    const fn has_runtime_restrictions(&self) -> bool {
666        matches!(
667            self,
668            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670        )
671    }
672}
673
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// A mutably borrowed [`Layouter`](crate::proc::Layouter).
    ///
    /// NOTE(review): presumably used to obtain layout information for entries of
    /// [`Self::types`]; its use is not visible in this part of the file - confirm.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// The restriction levels carried by [`Behavior::Wgsl`].
///
/// Each variant's bullet list states which kinds of expression are evaluated and which are
/// merely inserted in the arena at that level.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// May optionally carry [`FunctionLocalData`].
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
731
/// The restriction levels carried by [`Behavior::Glsl`].
///
/// Each variant's bullet list states which kinds of expression are evaluated and which are
/// merely inserted in the arena at that level.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
741
/// Borrowed state carried by the `Runtime` restriction variants, and optionally by
/// [`WgslRestrictions::Const`].
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably the emitter used while appending expressions for the
    /// current function - usage is not visible in this part of the file; confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the block that receives emitted statements - confirm.
    block: &'a mut crate::Block,
}
749
/// How "constant" an expression is.
///
/// The declaration order is significant: the derived [`Ord`] ranks
/// `Const < Override < Runtime`, so taking the `max` of several operand kinds yields the
/// least constant one. [`ExpressionKindTracker`] relies on this when classifying
/// expressions with more than one operand.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExpressionKind {
    /// The lowest rank; given to literals, zero values and constants.
    Const,
    /// The middle rank; given to override expressions.
    Override,
    /// The highest rank; the fallback for every expression not otherwise classified.
    Runtime,
}
756
/// Records the [`ExpressionKind`] of each expression in an arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// One [`ExpressionKind`] per expression, indexed by expression handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763    pub const fn new() -> Self {
764        Self {
765            inner: HandleVec::new(),
766        }
767    }
768
769    /// Forces the the expression to not be const
770    pub fn force_non_const(&mut self, value: Handle<Expression>) {
771        self.inner[value] = ExpressionKind::Runtime;
772    }
773
774    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775        self.inner.insert(value, expr_type);
776    }
777
778    pub fn is_const(&self, h: Handle<Expression>) -> bool {
779        matches!(self.type_of(h), ExpressionKind::Const)
780    }
781
782    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783        matches!(
784            self.type_of(h),
785            ExpressionKind::Const | ExpressionKind::Override
786        )
787    }
788
789    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790        self.inner[value]
791    }
792
793    pub fn from_arena(arena: &Arena<Expression>) -> Self {
794        let mut tracker = Self {
795            inner: HandleVec::with_capacity(arena.len()),
796        };
797        for (handle, expr) in arena.iter() {
798            tracker
799                .inner
800                .insert(handle, tracker.type_of_with_expr(expr));
801        }
802        tracker
803    }
804
805    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806        match *expr {
807            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808                ExpressionKind::Const
809            }
810            Expression::Override(_) => ExpressionKind::Override,
811            Expression::Compose { ref components, .. } => {
812                let mut expr_type = ExpressionKind::Const;
813                for component in components {
814                    expr_type = expr_type.max(self.type_of(*component))
815                }
816                expr_type
817            }
818            Expression::Splat { value, .. } => self.type_of(value),
819            Expression::AccessIndex { base, .. } => self.type_of(base),
820            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821            Expression::Swizzle { vector, .. } => self.type_of(vector),
822            Expression::Unary { expr, .. } => self.type_of(expr),
823            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824            Expression::Math {
825                arg,
826                arg1,
827                arg2,
828                arg3,
829                ..
830            } => self
831                .type_of(arg)
832                .max(
833                    arg1.map(|arg| self.type_of(arg))
834                        .unwrap_or(ExpressionKind::Const),
835                )
836                .max(
837                    arg2.map(|arg| self.type_of(arg))
838                        .unwrap_or(ExpressionKind::Const),
839                )
840                .max(
841                    arg3.map(|arg| self.type_of(arg))
842                        .unwrap_or(ExpressionKind::Const),
843                ),
844            Expression::As { expr, .. } => self.type_of(expr),
845            Expression::Select {
846                condition,
847                accept,
848                reject,
849            } => self
850                .type_of(condition)
851                .max(self.type_of(accept))
852                .max(self.type_of(reject)),
853            Expression::Relational { argument, .. } => self.type_of(argument),
854            Expression::ArrayLength(expr) => self.type_of(expr),
855            _ => ExpressionKind::Runtime,
856        }
857    }
858}
859
/// Errors reported by [`ConstantEvaluator`] while trying to evaluate an
/// expression at compile time.
///
/// The user-facing message of each variant is given by its `#[error(...)]`
/// attribute.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    /// The index operand of an `Access` expression could not be read as a
    /// constant `u32` value.
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// An `ArrayLength` expression was encountered while evaluating under
    /// WGSL rules.
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// When evaluating inside a function, this error makes the evaluator fall
    /// back to emitting the expression as a run-time expression.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// Fields are, in order: the math function, the number of arguments it
    /// expects, and the number of arguments actually supplied.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("Cannot apply relational function to type")]
    InvalidRelationalArg(RelationalFunction),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Constructor expects {expected} components, found {actual}")]
    InvalidVectorComposeLength { expected: usize, actual: usize },
    #[error("Constructor must only contain vector or scalar arguments")]
    InvalidVectorComposeComponent,
    /// The value being splatted was neither a literal nor a zero value of
    /// scalar type.
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    /// The swizzle source was not a zero value, splat, or compose expression
    /// of vector type.
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    /// A swizzle pattern component selects an index beyond the components
    /// available in the source expression.
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand was not recorded as a const expression by the
    /// [`ExpressionKindTracker`].
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// The payload describes the unsupported operation. When evaluating
    /// inside a function, this error makes the evaluator fall back to
    /// emitting the expression as a run-time expression.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Override` expression reached const evaluation itself.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A run-time expression was supplied to an evaluator that is not allowed
    /// to emit run-time expressions.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression was supplied to an evaluator restricted to
    /// const-expressions.
    #[error("Unexpected override-expression")]
    OverrideExpr,
    #[error("Expected boolean expression for condition argument of `select`, got something else")]
    SelectScalarConditionNotABool,
    #[error(
        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
        reject,
        accept
    )]
    SelectVecRejectAcceptSizeMismatch {
        reject: crate::VectorSize,
        accept: crate::VectorSize,
    },
    #[error("Expected boolean vector for condition arg., got something else")]
    SelectConditionNotAVecBool,
    #[error(
        "Expected same number of vector components between condition, accept, and reject args., got something else",
    )]
    SelectConditionVecSizeMismatch,
    #[error(
        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
    )]
    SelectAcceptRejectTypeMismatch,
    #[error("Cooperative operations can't be constant")]
    CooperativeOperation,
}
976
977impl<'a> ConstantEvaluator<'a> {
978    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
979    /// constant expression arena.
980    ///
981    /// Report errors according to WGSL's rules for constant evaluation.
982    pub const fn for_wgsl_module(
983        module: &'a mut crate::Module,
984        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
985        layouter: &'a mut crate::proc::Layouter,
986        in_override_ctx: bool,
987    ) -> Self {
988        Self::for_module(
989            Behavior::Wgsl(if in_override_ctx {
990                WgslRestrictions::Override
991            } else {
992                WgslRestrictions::Const(None)
993            }),
994            module,
995            global_expression_kind_tracker,
996            layouter,
997        )
998    }
999
1000    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
1001    /// constant expression arena.
1002    ///
1003    /// Report errors according to GLSL's rules for constant evaluation.
1004    pub const fn for_glsl_module(
1005        module: &'a mut crate::Module,
1006        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1007        layouter: &'a mut crate::proc::Layouter,
1008    ) -> Self {
1009        Self::for_module(
1010            Behavior::Glsl(GlslRestrictions::Const),
1011            module,
1012            global_expression_kind_tracker,
1013            layouter,
1014        )
1015    }
1016
1017    const fn for_module(
1018        behavior: Behavior<'a>,
1019        module: &'a mut crate::Module,
1020        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1021        layouter: &'a mut crate::proc::Layouter,
1022    ) -> Self {
1023        Self {
1024            behavior,
1025            types: &mut module.types,
1026            constants: &module.constants,
1027            overrides: &module.overrides,
1028            expressions: &mut module.global_expressions,
1029            expression_kind_tracker: global_expression_kind_tracker,
1030            layouter,
1031        }
1032    }
1033
1034    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1035    /// expression arena.
1036    ///
1037    /// Report errors according to WGSL's rules for constant evaluation.
1038    pub const fn for_wgsl_function(
1039        module: &'a mut crate::Module,
1040        expressions: &'a mut Arena<Expression>,
1041        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1042        layouter: &'a mut crate::proc::Layouter,
1043        emitter: &'a mut super::Emitter,
1044        block: &'a mut crate::Block,
1045        is_const: bool,
1046    ) -> Self {
1047        let local_data = FunctionLocalData {
1048            global_expressions: &module.global_expressions,
1049            emitter,
1050            block,
1051        };
1052        Self {
1053            behavior: Behavior::Wgsl(if is_const {
1054                WgslRestrictions::Const(Some(local_data))
1055            } else {
1056                WgslRestrictions::Runtime(local_data)
1057            }),
1058            types: &mut module.types,
1059            constants: &module.constants,
1060            overrides: &module.overrides,
1061            expressions,
1062            expression_kind_tracker: local_expression_kind_tracker,
1063            layouter,
1064        }
1065    }
1066
1067    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1068    /// expression arena.
1069    ///
1070    /// Report errors according to GLSL's rules for constant evaluation.
1071    pub const fn for_glsl_function(
1072        module: &'a mut crate::Module,
1073        expressions: &'a mut Arena<Expression>,
1074        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1075        layouter: &'a mut crate::proc::Layouter,
1076        emitter: &'a mut super::Emitter,
1077        block: &'a mut crate::Block,
1078    ) -> Self {
1079        Self {
1080            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1081                global_expressions: &module.global_expressions,
1082                emitter,
1083                block,
1084            })),
1085            types: &mut module.types,
1086            constants: &module.constants,
1087            overrides: &module.overrides,
1088            expressions,
1089            expression_kind_tracker: local_expression_kind_tracker,
1090            layouter,
1091        }
1092    }
1093
1094    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1095        crate::proc::GlobalCtx {
1096            types: self.types,
1097            constants: self.constants,
1098            overrides: self.overrides,
1099            global_expressions: match self.function_local_data() {
1100                Some(data) => data.global_expressions,
1101                None => self.expressions,
1102            },
1103        }
1104    }
1105
1106    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1107        if !self.expression_kind_tracker.is_const(expr) {
1108            log::debug!("check: SubexpressionsAreNotConstant");
1109            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1110        }
1111        Ok(())
1112    }
1113
1114    fn check_and_get(
1115        &mut self,
1116        expr: Handle<Expression>,
1117    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1118        match self.expressions[expr] {
1119            Expression::Constant(c) => {
1120                // Are we working in a function's expression arena, or the
1121                // module's constant expression arena?
1122                if let Some(function_local_data) = self.function_local_data() {
1123                    // Deep-copy the constant's value into our arena.
1124                    self.copy_from(
1125                        self.constants[c].init,
1126                        function_local_data.global_expressions,
1127                    )
1128                } else {
1129                    // "See through" the constant and use its initializer.
1130                    Ok(self.constants[c].init)
1131                }
1132            }
1133            _ => {
1134                self.check(expr)?;
1135                Ok(expr)
1136            }
1137        }
1138    }
1139
1140    /// Try to evaluate `expr` at compile time.
1141    ///
1142    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1143    /// we can determine its value at compile time, we append an expression
1144    /// representing its value - a tree of [`Literal`], [`Compose`],
1145    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1146    /// `self` contributes to.
1147    ///
1148    /// If `expr`'s value cannot be determined at compile time, and `self` is
1149    /// contributing to some function's expression arena, then append `expr` to
1150    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1151    /// contributing to the module's constant expression arena; since `expr`'s
1152    /// value is not a constant, return an error.
1153    ///
1154    /// We only consider `expr` itself, without recursing into its operands. Its
1155    /// operands must all have been produced by prior calls to
1156    /// `try_eval_and_append`, to ensure that they have already been reduced to
1157    /// an evaluated form if possible.
1158    ///
1159    /// [`Literal`]: Expression::Literal
1160    /// [`Compose`]: Expression::Compose
1161    /// [`ZeroValue`]: Expression::ZeroValue
1162    /// [`Swizzle`]: Expression::Swizzle
1163    pub fn try_eval_and_append(
1164        &mut self,
1165        expr: Expression,
1166        span: Span,
1167    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1168        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1169            ExpressionKind::Const => {
1170                let eval_result = self.try_eval_and_append_impl(&expr, span);
1171                // We should be able to evaluate `Const` expressions at this
1172                // point. If we failed to, then that probably means we just
1173                // haven't implemented that part of constant evaluation. Work
1174                // around this by simply emitting it as a run-time expression.
1175                if self.behavior.has_runtime_restrictions()
1176                    && matches!(
1177                        eval_result,
1178                        Err(ConstantEvaluatorError::NotImplemented(_)
1179                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1180                    )
1181                {
1182                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1183                } else {
1184                    eval_result
1185                }
1186            }
1187            ExpressionKind::Override => match self.behavior {
1188                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1189                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1190                }
1191                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1192                    Err(ConstantEvaluatorError::OverrideExpr)
1193                }
1194
1195                // GLSL specialization constants (constant_id) become Override expressions
1196                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1197                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1198                }
1199                Behavior::Glsl(GlslRestrictions::Const) => {
1200                    Err(ConstantEvaluatorError::OverrideExpr)
1201                }
1202            },
1203            ExpressionKind::Runtime => {
1204                if self.behavior.has_runtime_restrictions() {
1205                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1206                } else {
1207                    Err(ConstantEvaluatorError::RuntimeExpr)
1208                }
1209            }
1210        }
1211    }
1212
1213    /// Is the [`Self::expressions`] arena the global module expression arena?
1214    const fn is_global_arena(&self) -> bool {
1215        matches!(
1216            self.behavior,
1217            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1218                | Behavior::Glsl(GlslRestrictions::Const)
1219        )
1220    }
1221
1222    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1223        match self.behavior {
1224            Behavior::Wgsl(
1225                WgslRestrictions::Runtime(ref function_local_data)
1226                | WgslRestrictions::Const(Some(ref function_local_data)),
1227            )
1228            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1229                Some(function_local_data)
1230            }
1231            _ => None,
1232        }
1233    }
1234
1235    fn try_eval_and_append_impl(
1236        &mut self,
1237        expr: &Expression,
1238        span: Span,
1239    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1240        log::trace!("try_eval_and_append: {expr:?}");
1241        match *expr {
1242            Expression::Constant(c) if self.is_global_arena() => {
1243                // "See through" the constant and use its initializer.
1244                // This is mainly done to avoid having constants pointing to other constants.
1245                Ok(self.constants[c].init)
1246            }
1247            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1248            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1249                self.register_evaluated_expr(expr.clone(), span)
1250            }
1251            Expression::Compose { ty, ref components } => {
1252                let components = components
1253                    .iter()
1254                    .map(|component| self.check_and_get(*component))
1255                    .collect::<Result<Vec<_>, _>>()?;
1256                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1257            }
1258            Expression::Splat { size, value } => {
1259                let value = self.check_and_get(value)?;
1260                self.register_evaluated_expr(Expression::Splat { size, value }, span)
1261            }
1262            Expression::AccessIndex { base, index } => {
1263                let base = self.check_and_get(base)?;
1264
1265                self.access(base, index as usize, span)
1266            }
1267            Expression::Access { base, index } => {
1268                let base = self.check_and_get(base)?;
1269                let index = self.check_and_get(index)?;
1270
1271                let index_val: u32 = self
1272                    .to_ctx()
1273                    .get_const_val_from(index, self.expressions)
1274                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1275                self.access(base, index_val as usize, span)
1276            }
1277            Expression::Swizzle {
1278                size,
1279                vector,
1280                pattern,
1281            } => {
1282                let vector = self.check_and_get(vector)?;
1283
1284                self.swizzle(size, span, vector, pattern)
1285            }
1286            Expression::Unary { expr, op } => {
1287                let expr = self.check_and_get(expr)?;
1288
1289                self.unary_op(op, expr, span)
1290            }
1291            Expression::Binary { left, right, op } => {
1292                let left = self.check_and_get(left)?;
1293                let right = self.check_and_get(right)?;
1294
1295                self.binary_op(op, left, right, span)
1296            }
1297            Expression::Math {
1298                fun,
1299                arg,
1300                arg1,
1301                arg2,
1302                arg3,
1303            } => {
1304                let arg = self.check_and_get(arg)?;
1305                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1306                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1307                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1308
1309                self.math(arg, arg1, arg2, arg3, fun, span)
1310            }
1311            Expression::As {
1312                convert,
1313                expr,
1314                kind,
1315            } => {
1316                let expr = self.check_and_get(expr)?;
1317
1318                match convert {
1319                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1320                    None => Err(ConstantEvaluatorError::NotImplemented(
1321                        "bitcast built-in function".into(),
1322                    )),
1323                }
1324            }
1325            Expression::Select {
1326                reject,
1327                accept,
1328                condition,
1329            } => {
1330                let mut arg = |expr| self.check_and_get(expr);
1331
1332                let reject = arg(reject)?;
1333                let accept = arg(accept)?;
1334                let condition = arg(condition)?;
1335
1336                self.select(reject, accept, condition, span)
1337            }
1338            Expression::Relational { fun, argument } => {
1339                let argument = self.check_and_get(argument)?;
1340                self.relational(fun, argument, span)
1341            }
1342            Expression::ArrayLength(expr) => match self.behavior {
1343                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1344                Behavior::Glsl(_) => {
1345                    let expr = self.check_and_get(expr)?;
1346                    self.array_length(expr, span)
1347                }
1348            },
1349            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1350            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1351            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1352            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1353            Expression::WorkGroupUniformLoadResult { .. } => {
1354                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1355            }
1356            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1357            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1358            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1359            Expression::ImageSample { .. }
1360            | Expression::ImageLoad { .. }
1361            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1362            Expression::RayQueryProceedResult
1363            | Expression::RayQueryGetIntersection { .. }
1364            | Expression::RayQueryVertexPositions { .. } => {
1365                Err(ConstantEvaluatorError::RayQueryExpression)
1366            }
1367            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1368            Expression::SubgroupOperationResult { .. } => {
1369                Err(ConstantEvaluatorError::SubgroupExpression)
1370            }
1371            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1372                Err(ConstantEvaluatorError::CooperativeOperation)
1373            }
1374        }
1375    }
1376
1377    /// Splat `value` to `size`, without using [`Splat`] expressions.
1378    ///
1379    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1380    /// build a vector with the given `size` whose components are all
1381    /// `value`.
1382    ///
1383    /// Use `span` as the span of the inserted expressions and
1384    /// resulting types.
1385    ///
1386    /// [`Splat`]: Expression::Splat
1387    /// [`Compose`]: Expression::Compose
1388    /// [`ZeroValue`]: Expression::ZeroValue
1389    fn splat(
1390        &mut self,
1391        value: Handle<Expression>,
1392        size: crate::VectorSize,
1393        span: Span,
1394    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1395        match self.expressions[value] {
1396            Expression::Literal(literal) => {
1397                let scalar = literal.scalar();
1398                let ty = self.types.insert(
1399                    Type {
1400                        name: None,
1401                        inner: TypeInner::Vector { size, scalar },
1402                    },
1403                    span,
1404                );
1405                let expr = Expression::Compose {
1406                    ty,
1407                    components: vec![value; size as usize],
1408                };
1409                self.register_evaluated_expr(expr, span)
1410            }
1411            Expression::ZeroValue(ty) => {
1412                let inner = match self.types[ty].inner {
1413                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1414                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1415                };
1416                let res_ty = self.types.insert(Type { name: None, inner }, span);
1417                let expr = Expression::ZeroValue(res_ty);
1418                self.register_evaluated_expr(expr, span)
1419            }
1420            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1421        }
1422    }
1423
1424    fn swizzle(
1425        &mut self,
1426        size: crate::VectorSize,
1427        span: Span,
1428        src_constant: Handle<Expression>,
1429        pattern: [crate::SwizzleComponent; 4],
1430    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1431        let mut get_dst_ty = |ty| match self.types[ty].inner {
1432            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1433                Type {
1434                    name: None,
1435                    inner: TypeInner::Vector { size, scalar },
1436                },
1437                span,
1438            )),
1439            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1440        };
1441
1442        match self.expressions[src_constant] {
1443            Expression::ZeroValue(ty) => {
1444                let dst_ty = get_dst_ty(ty)?;
1445                let expr = Expression::ZeroValue(dst_ty);
1446                self.register_evaluated_expr(expr, span)
1447            }
1448            Expression::Splat { value, .. } => {
1449                let expr = Expression::Splat { size, value };
1450                self.register_evaluated_expr(expr, span)
1451            }
1452            Expression::Compose { ty, ref components } => {
1453                let dst_ty = get_dst_ty(ty)?;
1454
1455                let mut flattened = [src_constant; 4]; // dummy value
1456                let len =
1457                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1458                        .zip(flattened.iter_mut())
1459                        .map(|(component, elt)| *elt = component)
1460                        .count();
1461                let flattened = &flattened[..len];
1462
1463                let swizzled_components = pattern[..size as usize]
1464                    .iter()
1465                    .map(|&sc| {
1466                        let sc = sc as usize;
1467                        if let Some(elt) = flattened.get(sc) {
1468                            Ok(*elt)
1469                        } else {
1470                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1471                        }
1472                    })
1473                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1474                let expr = Expression::Compose {
1475                    ty: dst_ty,
1476                    components: swizzled_components,
1477                };
1478                self.register_evaluated_expr(expr, span)
1479            }
1480            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1481        }
1482    }
1483
1484    fn math(
1485        &mut self,
1486        arg: Handle<Expression>,
1487        arg1: Option<Handle<Expression>>,
1488        arg2: Option<Handle<Expression>>,
1489        arg3: Option<Handle<Expression>>,
1490        fun: crate::MathFunction,
1491        span: Span,
1492    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1493        let expected = fun.argument_count();
1494        let given = Some(arg)
1495            .into_iter()
1496            .chain(arg1)
1497            .chain(arg2)
1498            .chain(arg3)
1499            .count();
1500        if expected != given {
1501            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1502                fun, expected, given,
1503            ));
1504        }
1505
1506        // NOTE: We try to match the declaration order of `MathFunction` here.
1507        match fun {
1508            // comparison
1509            crate::MathFunction::Abs => {
1510                component_wise_scalar(self, span, [arg], |args| match args {
1511                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1512                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1513                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1514                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1515                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1516                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1517                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1518                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1519                })
1520            }
1521            crate::MathFunction::Min => {
1522                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1523                    Ok([e1.min(e2)])
1524                })
1525            }
1526            crate::MathFunction::Max => {
1527                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1528                    Ok([e1.max(e2)])
1529                })
1530            }
1531            crate::MathFunction::Clamp => {
1532                component_wise_scalar!(
1533                    self,
1534                    span,
1535                    [arg, arg1.unwrap(), arg2.unwrap()],
1536                    |e, low, high| {
1537                        if low > high {
1538                            Err(ConstantEvaluatorError::InvalidClamp)
1539                        } else {
1540                            Ok([e.clamp(low, high)])
1541                        }
1542                    }
1543                )
1544            }
1545            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1546                Float::F16([e]) => Ok(Float::F16(
1547                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1548                )),
1549                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1550                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1551            }),
1552
1553            // trigonometry
1554            crate::MathFunction::Cos => {
1555                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1556            }
1557            crate::MathFunction::Cosh => {
1558                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1559            }
1560            crate::MathFunction::Sin => {
1561                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1562            }
1563            crate::MathFunction::Sinh => {
1564                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1565            }
1566            crate::MathFunction::Tan => {
1567                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1568            }
1569            crate::MathFunction::Tanh => {
1570                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1571            }
1572            crate::MathFunction::Acos => {
1573                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1574            }
1575            crate::MathFunction::Asin => {
1576                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1577            }
1578            crate::MathFunction::Atan => {
1579                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1580            }
1581            crate::MathFunction::Atan2 => {
1582                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1583                    Ok([y.atan2(x)])
1584                })
1585            }
1586            crate::MathFunction::Asinh => {
1587                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1588            }
1589            crate::MathFunction::Acosh => {
1590                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1591            }
1592            crate::MathFunction::Atanh => {
1593                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1594            }
1595            crate::MathFunction::Radians => {
1596                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1597            }
1598            crate::MathFunction::Degrees => {
1599                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1600            }
1601
1602            // decomposition
1603            crate::MathFunction::Ceil => {
1604                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1605            }
1606            crate::MathFunction::Floor => {
1607                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1608            }
1609            crate::MathFunction::Round => {
1610                component_wise_float(self, span, [arg], |e| match e {
1611                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1612                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1613                    Float::F16([e]) => {
1614                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1615                        //
1616                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1617                        // which has licensing compatible with ours. See also
1618                        // <https://github.com/rust-lang/rust/issues/96710>.
1619                        //
1620                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1621                        fn round_ties_even(x: f64) -> f64 {
1622                            let i = x as i64;
1623                            let f = (x - i as f64).abs();
1624                            if f == 0.5 {
1625                                if i & 1 == 1 {
1626                                    // -1.5, 1.5, 3.5, ...
1627                                    (x.abs() + 0.5).copysign(x)
1628                                } else {
1629                                    (x.abs() - 0.5).copysign(x)
1630                                }
1631                            } else {
1632                                x.round()
1633                            }
1634                        }
1635
1636                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1637                    }
1638                })
1639            }
1640            crate::MathFunction::Fract => {
1641                component_wise_float!(self, span, [arg], |e| {
1642                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1643                    // here.
1644                    Ok([e - e.floor()])
1645                })
1646            }
1647            crate::MathFunction::Trunc => {
1648                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1649            }
1650
1651            // exponent
1652            crate::MathFunction::Exp => {
1653                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1654            }
1655            crate::MathFunction::Exp2 => {
1656                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1657            }
1658            crate::MathFunction::Log => {
1659                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1660            }
1661            crate::MathFunction::Log2 => {
1662                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1663            }
1664            crate::MathFunction::Pow => {
1665                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1666                    Ok([e1.powf(e2)])
1667                })
1668            }
1669
1670            // computational
1671            crate::MathFunction::Sign => {
1672                component_wise_signed!(self, span, [arg], |e| {
1673                    Ok([if e.is_zero() {
1674                        Zero::zero()
1675                    } else {
1676                        e.signum()
1677                    }])
1678                })
1679            }
1680            crate::MathFunction::Fma => {
1681                component_wise_float!(
1682                    self,
1683                    span,
1684                    [arg, arg1.unwrap(), arg2.unwrap()],
1685                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1686                )
1687            }
1688            crate::MathFunction::Step => {
1689                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1690                    Float::Abstract([edge, x]) => {
1691                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1692                    }
1693                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1694                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1695                        f16::one()
1696                    } else {
1697                        f16::zero()
1698                    }])),
1699                })
1700            }
1701            crate::MathFunction::Sqrt => {
1702                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1703            }
1704            crate::MathFunction::InverseSqrt => {
1705                component_wise_float(self, span, [arg], |e| match e {
1706                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1707                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1708                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1709                })
1710            }
1711
1712            // bits
1713            crate::MathFunction::CountTrailingZeros => {
1714                component_wise_concrete_int!(self, span, [arg], |e| {
1715                    #[allow(clippy::useless_conversion)]
1716                    Ok([e
1717                        .trailing_zeros()
1718                        .try_into()
1719                        .expect("bit count overflowed 32 bits, somehow!?")])
1720                })
1721            }
1722            crate::MathFunction::CountLeadingZeros => {
1723                component_wise_concrete_int!(self, span, [arg], |e| {
1724                    #[allow(clippy::useless_conversion)]
1725                    Ok([e
1726                        .leading_zeros()
1727                        .try_into()
1728                        .expect("bit count overflowed 32 bits, somehow!?")])
1729                })
1730            }
1731            crate::MathFunction::CountOneBits => {
1732                component_wise_concrete_int!(self, span, [arg], |e| {
1733                    #[allow(clippy::useless_conversion)]
1734                    Ok([e
1735                        .count_ones()
1736                        .try_into()
1737                        .expect("bit count overflowed 32 bits, somehow!?")])
1738                })
1739            }
1740            crate::MathFunction::ReverseBits => {
1741                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1742            }
1743            crate::MathFunction::FirstTrailingBit => {
1744                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1745            }
1746            crate::MathFunction::FirstLeadingBit => {
1747                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1748            }
1749
1750            // vector
1751            crate::MathFunction::Dot4I8Packed => {
1752                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1753            }
1754            crate::MathFunction::Dot4U8Packed => {
1755                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1756            }
1757            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1758            crate::MathFunction::Dot => {
1759                // https://www.w3.org/TR/WGSL/#dot-builtin
1760                let e1 = self.extract_vec(arg, false)?;
1761                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1762                if e1.len() != e2.len() {
1763                    return Err(ConstantEvaluatorError::InvalidMathArg);
1764                }
1765
1766                fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1767                where
1768                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1769                {
1770                    a.iter()
1771                        .zip(b.iter())
1772                        .map(|(&aa, bb)| aa.checked_mul(bb))
1773                        .try_fold(P::zero(), |acc, x| {
1774                            if let Some(x) = x {
1775                                acc.checked_add(&x)
1776                            } else {
1777                                None
1778                            }
1779                        })
1780                        .ok_or(ConstantEvaluatorError::Overflow(
1781                            "in dot built-in".to_string(),
1782                        ))
1783                }
1784
1785                let result = match_literal_vector!(match (e1, e2) => Literal {
1786                    Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1787                    Integer => |e1, e2| { int_dot(e1, e2)? },
1788                })?;
1789                self.register_evaluated_expr(Expression::Literal(result), span)
1790            }
1791            crate::MathFunction::Length => {
1792                // https://www.w3.org/TR/WGSL/#length-builtin
1793                let e1 = self.extract_vec(arg, true)?;
1794
1795                fn float_length<F>(e: &[F]) -> F
1796                where
1797                    F: core::ops::Mul<F>,
1798                    F: num_traits::Float + iter::Sum,
1799                {
1800                    if e.len() == 1 {
1801                        // Avoids possible overflow in squaring
1802                        e[0].abs()
1803                    } else {
1804                        e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1805                    }
1806                }
1807
1808                let result = match_literal_vector!(match e1 => Literal {
1809                    Float => |e1| { float_length(e1) },
1810                })?;
1811                self.register_evaluated_expr(Expression::Literal(result), span)
1812            }
1813            crate::MathFunction::Distance => {
1814                // https://www.w3.org/TR/WGSL/#distance-builtin
1815                let e1 = self.extract_vec(arg, true)?;
1816                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1817                if e1.len() != e2.len() {
1818                    return Err(ConstantEvaluatorError::InvalidMathArg);
1819                }
1820
1821                fn float_distance<F>(a: &[F], b: &[F]) -> F
1822                where
1823                    F: core::ops::Mul<F>,
1824                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1825                {
1826                    if a.len() == 1 {
1827                        // Avoids possible overflow in squaring
1828                        (a[0] - b[0]).abs()
1829                    } else {
1830                        a.iter()
1831                            .zip(b.iter())
1832                            .map(|(&aa, &bb)| aa - bb)
1833                            .map(|ei| ei * ei)
1834                            .sum::<F>()
1835                            .sqrt()
1836                    }
1837                }
1838                let result = match_literal_vector!(match (e1, e2) => Literal {
1839                    Float => |e1, e2| { float_distance(e1, e2) },
1840                })?;
1841                self.register_evaluated_expr(Expression::Literal(result), span)
1842            }
1843            crate::MathFunction::Normalize => {
1844                // https://www.w3.org/TR/WGSL/#normalize-builtin
1845                let e1 = self.extract_vec(arg, true)?;
1846
1847                fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1848                where
1849                    F: core::ops::Mul<F>,
1850                    F: num_traits::Float + iter::Sum,
1851                {
1852                    let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1853                    e.iter().map(|&ei| ei / len).collect()
1854                }
1855
1856                let result = match_literal_vector!(match e1 => LiteralVector {
1857                    Float => |e1| { float_normalize(e1) },
1858                })?;
1859                result.register_as_evaluated_expr(self, span)
1860            }
1861
1862            // unimplemented
1863            crate::MathFunction::Modf
1864            | crate::MathFunction::Frexp
1865            | crate::MathFunction::Ldexp
1866            | crate::MathFunction::Outer
1867            | crate::MathFunction::FaceForward
1868            | crate::MathFunction::Reflect
1869            | crate::MathFunction::Refract
1870            | crate::MathFunction::Mix
1871            | crate::MathFunction::SmoothStep
1872            | crate::MathFunction::Inverse
1873            | crate::MathFunction::Transpose
1874            | crate::MathFunction::Determinant
1875            | crate::MathFunction::QuantizeToF16
1876            | crate::MathFunction::ExtractBits
1877            | crate::MathFunction::InsertBits
1878            | crate::MathFunction::Pack4x8snorm
1879            | crate::MathFunction::Pack4x8unorm
1880            | crate::MathFunction::Pack2x16snorm
1881            | crate::MathFunction::Pack2x16unorm
1882            | crate::MathFunction::Pack2x16float
1883            | crate::MathFunction::Pack4xI8
1884            | crate::MathFunction::Pack4xU8
1885            | crate::MathFunction::Pack4xI8Clamp
1886            | crate::MathFunction::Pack4xU8Clamp
1887            | crate::MathFunction::Unpack4x8snorm
1888            | crate::MathFunction::Unpack4x8unorm
1889            | crate::MathFunction::Unpack2x16snorm
1890            | crate::MathFunction::Unpack2x16unorm
1891            | crate::MathFunction::Unpack2x16float
1892            | crate::MathFunction::Unpack4xI8
1893            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1894                format!("{fun:?} built-in function"),
1895            )),
1896        }
1897    }
1898
1899    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1900    fn packed_dot_product(
1901        &mut self,
1902        a: Handle<Expression>,
1903        b: Handle<Expression>,
1904        span: Span,
1905        signed: bool,
1906    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1907        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1908            return Err(ConstantEvaluatorError::InvalidMathArg);
1909        };
1910        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1911            return Err(ConstantEvaluatorError::InvalidMathArg);
1912        };
1913
1914        let result = if signed {
1915            Literal::I32(
1916                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1917                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1918                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1919                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1920            )
1921        } else {
1922            Literal::U32(
1923                (a & 0xFF) * (b & 0xFF)
1924                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1925                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1926                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1927            )
1928        };
1929
1930        self.register_evaluated_expr(Expression::Literal(result), span)
1931    }
1932
1933    /// Vector cross product.
1934    fn cross_product(
1935        &mut self,
1936        a: Handle<Expression>,
1937        b: Handle<Expression>,
1938        span: Span,
1939    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1940        use Literal as Li;
1941
1942        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1943        let (b, _) = self.extract_vec_with_size::<3>(b)?;
1944
1945        let product = match (a, b) {
1946            (
1947                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1948                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1949            ) => {
1950                // `cross` has no overload for AbstractInt, so AbstractInt
1951                // arguments are automatically converted to AbstractFloat. Since
1952                // `f64` has a much wider range than `i64`, there's no danger of
1953                // overflow here.
1954                let p = cross_product(
1955                    [a0 as f64, a1 as f64, a2 as f64],
1956                    [b0 as f64, b1 as f64, b2 as f64],
1957                );
1958                [
1959                    Li::AbstractFloat(p[0]),
1960                    Li::AbstractFloat(p[1]),
1961                    Li::AbstractFloat(p[2]),
1962                ]
1963            }
1964            (
1965                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1966                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1967            ) => {
1968                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1969                [
1970                    Li::AbstractFloat(p[0]),
1971                    Li::AbstractFloat(p[1]),
1972                    Li::AbstractFloat(p[2]),
1973                ]
1974            }
1975            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1976                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1977                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1978            }
1979            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1980                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1981                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1982            }
1983            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1984                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1985                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1986            }
1987            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1988        };
1989
1990        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1991        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1992        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1993
1994        self.register_evaluated_expr(
1995            Expression::Compose {
1996                ty,
1997                components: vec![p0, p1, p2],
1998            },
1999            span,
2000        )
2001    }
2002
2003    /// Extract the values of a `vecN` from `expr`.
2004    ///
2005    /// Return the value of `expr`, whose type is `vecN<S>` for some
2006    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2007    /// values.
2008    ///
2009    /// Also return the type handle from the `Compose` expression.
2010    fn extract_vec_with_size<const N: usize>(
2011        &mut self,
2012        expr: Handle<Expression>,
2013    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2014        let span = self.expressions.get_span(expr);
2015        let expr = self.eval_zero_value_and_splat(expr, span)?;
2016        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2017            return Err(ConstantEvaluatorError::InvalidMathArg);
2018        };
2019
2020        let mut value = [Literal::Bool(false); N];
2021        for (component, elt) in
2022            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2023                .zip(value.iter_mut())
2024        {
2025            let Expression::Literal(literal) = self.expressions[component] else {
2026                return Err(ConstantEvaluatorError::InvalidMathArg);
2027            };
2028            *elt = literal;
2029        }
2030
2031        Ok((value, ty))
2032    }
2033
2034    /// Extract the values of a `vecN` from `expr`.
2035    ///
2036    /// Return the value of `expr`, whose type is `vecN<S>` for some
2037    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2038    /// values.
2039    ///
2040    /// Also return the type handle from the `Compose` expression.
2041    fn extract_vec(
2042        &mut self,
2043        expr: Handle<Expression>,
2044        allow_single: bool,
2045    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2046        let span = self.expressions.get_span(expr);
2047        let expr = self.eval_zero_value_and_splat(expr, span)?;
2048
2049        match self.expressions[expr] {
2050            Expression::Literal(literal) if allow_single => {
2051                Ok(LiteralVector::from_literal(literal))
2052            }
2053            Expression::Compose { ty, ref components } => {
2054                let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2055                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2056                        .map(|expr| match self.expressions[expr] {
2057                            Expression::Literal(l) => Ok(l),
2058                            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2059                        })
2060                        .collect::<Result<_, ConstantEvaluatorError>>()?;
2061                LiteralVector::from_literal_vec(components)
2062            }
2063            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2064        }
2065    }
2066
2067    fn array_length(
2068        &mut self,
2069        array: Handle<Expression>,
2070        span: Span,
2071    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2072        match self.expressions[array] {
2073            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2074                match self.types[ty].inner {
2075                    TypeInner::Array { size, .. } => match size {
2076                        ArraySize::Constant(len) => {
2077                            let expr = Expression::Literal(Literal::U32(len.get()));
2078                            self.register_evaluated_expr(expr, span)
2079                        }
2080                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2081                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2082                    },
2083                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2084                }
2085            }
2086            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2087        }
2088    }
2089
2090    fn access(
2091        &mut self,
2092        base: Handle<Expression>,
2093        index: usize,
2094        span: Span,
2095    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2096        match self.expressions[base] {
2097            Expression::ZeroValue(ty) => {
2098                let ty_inner = &self.types[ty].inner;
2099                let components = ty_inner
2100                    .components()
2101                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2102
2103                if index >= components as usize {
2104                    Err(ConstantEvaluatorError::InvalidAccessBase)
2105                } else {
2106                    let ty_res = ty_inner
2107                        .component_type(index)
2108                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2109                    let ty = match ty_res {
2110                        crate::proc::TypeResolution::Handle(ty) => ty,
2111                        crate::proc::TypeResolution::Value(inner) => {
2112                            self.types.insert(Type { name: None, inner }, span)
2113                        }
2114                    };
2115                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2116                }
2117            }
2118            Expression::Splat { size, value } => {
2119                if index >= size as usize {
2120                    Err(ConstantEvaluatorError::InvalidAccessBase)
2121                } else {
2122                    Ok(value)
2123                }
2124            }
2125            Expression::Compose { ty, ref components } => {
2126                let _ = self.types[ty]
2127                    .inner
2128                    .components()
2129                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2130
2131                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2132                    .nth(index)
2133                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2134            }
2135            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2136        }
2137    }
2138
2139    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2140    ///
2141    /// [`ZeroValue`]: Expression::ZeroValue
2142    /// [`Splat`]: Expression::Splat
2143    /// [`Literal`]: Expression::Literal
2144    /// [`Compose`]: Expression::Compose
2145    fn eval_zero_value_and_splat(
2146        &mut self,
2147        mut expr: Handle<Expression>,
2148        span: Span,
2149    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2150        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2151        // each of its components.
2152        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2153            let components = components
2154                .clone()
2155                .iter()
2156                .map(|component| self.eval_zero_value_and_splat(*component, span))
2157                .collect::<Result<_, _>>()?;
2158            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2159        }
2160
2161        // The result of the splat() for a Splat of a scalar ZeroValue is a
2162        // vector ZeroValue, so we must call eval_zero_value_impl() after
2163        // splat() in order to ensure we have no ZeroValues remaining.
2164        if let Expression::Splat { size, value } = self.expressions[expr] {
2165            expr = self.splat(value, size, span)?;
2166        }
2167        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2168            expr = self.eval_zero_value_impl(ty, span)?;
2169        }
2170        Ok(expr)
2171    }
2172
2173    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2174    ///
2175    /// [`ZeroValue`]: Expression::ZeroValue
2176    /// [`Literal`]: Expression::Literal
2177    /// [`Compose`]: Expression::Compose
2178    fn eval_zero_value(
2179        &mut self,
2180        expr: Handle<Expression>,
2181        span: Span,
2182    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2183        match self.expressions[expr] {
2184            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2185            _ => Ok(expr),
2186        }
2187    }
2188
2189    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2190    ///
2191    /// [`ZeroValue`]: Expression::ZeroValue
2192    /// [`Literal`]: Expression::Literal
2193    /// [`Compose`]: Expression::Compose
2194    fn eval_zero_value_impl(
2195        &mut self,
2196        ty: Handle<Type>,
2197        span: Span,
2198    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2199        match self.types[ty].inner {
2200            TypeInner::Scalar(scalar) => {
2201                let expr = Expression::Literal(
2202                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2203                );
2204                self.register_evaluated_expr(expr, span)
2205            }
2206            TypeInner::Vector { size, scalar } => {
2207                let scalar_ty = self.types.insert(
2208                    Type {
2209                        name: None,
2210                        inner: TypeInner::Scalar(scalar),
2211                    },
2212                    span,
2213                );
2214                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2215                let expr = Expression::Compose {
2216                    ty,
2217                    components: vec![el; size as usize],
2218                };
2219                self.register_evaluated_expr(expr, span)
2220            }
2221            TypeInner::Matrix {
2222                columns,
2223                rows,
2224                scalar,
2225            } => {
2226                let vec_ty = self.types.insert(
2227                    Type {
2228                        name: None,
2229                        inner: TypeInner::Vector { size: rows, scalar },
2230                    },
2231                    span,
2232                );
2233                let el = self.eval_zero_value_impl(vec_ty, span)?;
2234                let expr = Expression::Compose {
2235                    ty,
2236                    components: vec![el; columns as usize],
2237                };
2238                self.register_evaluated_expr(expr, span)
2239            }
2240            TypeInner::Array {
2241                base,
2242                size: ArraySize::Constant(size),
2243                ..
2244            } => {
2245                let el = self.eval_zero_value_impl(base, span)?;
2246                let expr = Expression::Compose {
2247                    ty,
2248                    components: vec![el; size.get() as usize],
2249                };
2250                self.register_evaluated_expr(expr, span)
2251            }
2252            TypeInner::Struct { ref members, .. } => {
2253                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2254                let mut components = Vec::with_capacity(members.len());
2255                for ty in types {
2256                    components.push(self.eval_zero_value_impl(ty, span)?);
2257                }
2258                let expr = Expression::Compose { ty, components };
2259                self.register_evaluated_expr(expr, span)
2260            }
2261            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2262        }
2263    }
2264
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// `expr` is first passed through `eval_zero_value`. The resulting
    /// expression must then be one of:
    ///
    /// - a [`Literal`], which is converted directly according to the table in
    ///   the body;
    /// - a [`Compose`] of vector or matrix type, whose components are cast
    ///   recursively and re-composed under a type that has `target` as its
    ///   scalar; or
    /// - a [`Splat`], whose value is cast and splatted again at the same size.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// # Errors
    ///
    /// Returns [`ConstantEvaluatorError::InvalidCastArg`] when `expr` is any
    /// other kind of expression, when a `Compose` is neither a vector nor a
    /// matrix, or when the source literal / `target` combination is rejected
    /// below. Conversions from abstract literals go through
    /// `try_from_abstract`, whose errors are propagated unchanged.
    ///
    /// # Panics
    ///
    /// The `f16` conversions below `unwrap` the results of the `f16::to_*` and
    /// `f16::from_*` helpers; per the inline notes, the `to_*` helpers only
    /// fail for NaN or infinite inputs.
    ///
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    /// [`Splat`]: Expression::Splat
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error for this call. `from` is the debug
        // rendering of the handle and of the expression it refers to; `to` is
        // the WGSL spelling of `target` when the `wgsl-in` feature is enabled,
        // and its debug rendering otherwise.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        // Supplies the `min_float` / `max_float` bounds used to clamp floats
        // before they are converted to integers.
        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // One arm per target scalar; each inner match lists the source
                // literals that may be converted to it.
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        // Clamp to the representable range before the `as` conversion.
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), // Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        // 64-bit sources are not converted to 32-bit integers.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        // 64-bit sources are not converted to 32-bit integers.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    // Every literal kind has a conversion to `i64`.
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(), // Only None on NaN or Inf
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    // Every literal kind has a conversion to `u64`.
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    // Every literal kind has a conversion to `f16`.
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        // 64-bit sources are not converted to `f32`.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        // 64-bit integer sources are not converted to `f64`.
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    // A value converts to `true` exactly when it is non-zero.
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        // 64-bit sources are not converted to `bool`.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract literals may become an abstract float.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    // Only an abstract int may "become" an abstract int.
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    // Any other target scalar is refused.
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // The result type keeps the shape of the source type but uses
                // `target` as its scalar. Only vectors and matrices are accepted.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Work on a copy of the component list: the recursive calls
                // need `&mut self`, which owns the original list.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // The converted operand is registered at the operand's own
                // span rather than at `span`.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
2459
2460    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2461    ///
2462    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2463    /// matrix, or nested arrays of such.
2464    ///
2465    /// This is basically the same as the [`cast`] method, except that that
2466    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2467    ///
2468    /// Treat `span` as the location of the resulting expression.
2469    ///
2470    /// [`cast`]: ConstantEvaluator::cast
2471    /// [`As`]: crate::Expression::As
2472    pub fn cast_array(
2473        &mut self,
2474        expr: Handle<Expression>,
2475        target: crate::Scalar,
2476        span: Span,
2477    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2478        let expr = self.check_and_get(expr)?;
2479
2480        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2481            return self.cast(expr, target, span);
2482        };
2483
2484        let TypeInner::Array {
2485            base: _,
2486            size,
2487            stride: _,
2488        } = self.types[ty].inner
2489        else {
2490            return self.cast(expr, target, span);
2491        };
2492
2493        let mut components = components.clone();
2494        for component in &mut components {
2495            *component = self.cast_array(*component, target, span)?;
2496        }
2497
2498        let first = components.first().unwrap();
2499        let new_base = match self.resolve_type(*first)? {
2500            crate::proc::TypeResolution::Handle(ty) => ty,
2501            crate::proc::TypeResolution::Value(inner) => {
2502                self.types.insert(Type { name: None, inner }, span)
2503            }
2504        };
2505        let mut layouter = core::mem::take(self.layouter);
2506        layouter.update(self.to_ctx()).unwrap();
2507        *self.layouter = layouter;
2508
2509        let new_base_stride = self.layouter[new_base].to_stride();
2510        let new_array_ty = self.types.insert(
2511            Type {
2512                name: None,
2513                inner: TypeInner::Array {
2514                    base: new_base,
2515                    size,
2516                    stride: new_base_stride,
2517                },
2518            },
2519            span,
2520        );
2521
2522        let compose = Expression::Compose {
2523            ty: new_array_ty,
2524            components,
2525        };
2526        self.register_evaluated_expr(compose, span)
2527    }
2528
2529    fn unary_op(
2530        &mut self,
2531        op: UnaryOperator,
2532        expr: Handle<Expression>,
2533        span: Span,
2534    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2535        let expr = self.eval_zero_value_and_splat(expr, span)?;
2536
2537        let expr = match self.expressions[expr] {
2538            Expression::Literal(value) => Expression::Literal(match op {
2539                UnaryOperator::Negate => match value {
2540                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2541                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2542                    Literal::F32(v) => Literal::F32(-v),
2543                    Literal::F16(v) => Literal::F16(-v),
2544                    Literal::F64(v) => Literal::F64(-v),
2545                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2546                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2547                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2548                },
2549                UnaryOperator::LogicalNot => match value {
2550                    Literal::Bool(v) => Literal::Bool(!v),
2551                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2552                },
2553                UnaryOperator::BitwiseNot => match value {
2554                    Literal::I32(v) => Literal::I32(!v),
2555                    Literal::I64(v) => Literal::I64(!v),
2556                    Literal::U32(v) => Literal::U32(!v),
2557                    Literal::U64(v) => Literal::U64(!v),
2558                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2559                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2560                },
2561            }),
2562            Expression::Compose {
2563                ty,
2564                components: ref src_components,
2565            } => {
2566                match self.types[ty].inner {
2567                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2568                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2569                }
2570
2571                let mut components = src_components.clone();
2572                for component in &mut components {
2573                    *component = self.unary_op(op, *component, span)?;
2574                }
2575
2576                Expression::Compose { ty, components }
2577            }
2578            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2579        };
2580
2581        self.register_evaluated_expr(expr, span)
2582    }
2583
    /// Evaluate the binary operator `op` on the constant operands `left` and
    /// `right`.
    ///
    /// Both operands are first passed through `eval_zero_value_and_splat`.
    /// The evaluated operands are then handled by shape:
    ///
    /// - two literals: comparison operators compare the two `Literal` values
    ///   directly and produce a `Bool`; all other operators are evaluated per
    ///   literal-kind pair in the table below;
    /// - a `Compose` and a literal (in either order): `op` is applied to each
    ///   component paired with the literal, and the result keeps the type of
    ///   the `Compose`; shifts are rejected in this form;
    /// - two `Compose`s: both are flattened, paired up, and, provided both
    ///   types are vectors of the same size, handed to `binary_op_vector`.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// # Errors
    ///
    /// Returns [`ConstantEvaluatorError::InvalidBinaryOpArgs`] for any
    /// operand shape, literal-kind pair or operator not listed below, and the
    /// `DivisionByZero`, `RemainderByZero`, `Overflow` and
    /// `ShiftedMoreThan32Bits` variants as noted on the individual arms.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        // Note: in most cases constant evaluation checks for overflow, but for
        // i32/u32, it uses wrapping arithmetic. See
        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                let literal = match op {
                    // Comparisons are evaluated on the `Literal` values
                    // themselves, whatever their kinds.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    // Everything else depends on the pair of literal kinds.
                    _ => match (left_value, right_value) {
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            // Only a zero divisor is an error; the division
                            // itself wraps.
                            BinaryOperator::Divide => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::DivisionByZero);
                                } else {
                                    a.wrapping_div(b)
                                }
                            }
                            BinaryOperator::Modulo => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::RemainderByZero);
                                } else {
                                    a.wrapping_rem(b)
                                }
                            }
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Signed value shifted by an unsigned amount.
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // Count the leading bits of `a` that equal its
                                // sign bit (for negative `a`, the leading ones).
                                // Unless that count exceeds `b`, the shift would
                                // discard a significant bit or change the sign.
                                // NOTE(review): the count is at most 32, so any
                                // `b >= 32` is reported as `Overflow` here, before
                                // `checked_shl` gets to run.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            // `checked_div` / `checked_rem` fail on a zero divisor.
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // The shift is computed as a checked multiplication
                            // by `1 << b`, so that shifting out any set bit is
                            // reported as an overflow.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Float arithmetic is performed without any checks here.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Abstract int shifted by an unsigned amount.
                        (Literal::AbstractInt(a), Literal::U32(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::ShiftLeft => {
                                    // Same overflow test as for `i32`, applied
                                    // to the abstract int's own width.
                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                        return Err(ConstantEvaluatorError::Overflow(
                                            "<<".to_string(),
                                        ));
                                    }
                                    a.checked_shl(b).unwrap_or(0)
                                }
                                // When `checked_shr` rejects the shift amount,
                                // the result is 0, whatever the sign of `a`.
                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Abstract int arithmetic is checked: overflow is an
                        // error rather than a wrap.
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                // `checked_div` / `checked_rem` fail both on a
                                // zero divisor and on overflow; tell the two
                                // apart by inspecting `b`.
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Any other pairing of literal kinds is rejected.
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // Composite on the left, literal on the right: apply `op` to each
            // component against the same right-hand literal.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => match op {
                BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                _ => {
                    let mut components = src_components.clone();
                    for component in &mut components {
                        *component = self.binary_op(op, *component, right, span)?;
                    }
                    Expression::Compose { ty, components }
                }
            },
            // Literal on the left, composite on the right: the mirror image of
            // the case above.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => match op {
                BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                _ => {
                    let mut components = src_components.clone();
                    for component in &mut components {
                        *component = self.binary_op(op, left, *component, span)?;
                    }
                    Expression::Compose { ty, components }
                }
            },
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // We have to make a copy of the component lists, because the
                // call to `binary_op_vector` needs `&mut self`, but `self` owns
                // the component lists.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
                // make a reasonable guess of the capacity we'll need.
                let mut flattened = Vec::with_capacity(left_components.len());
                flattened.extend(left_flattened.zip(right_flattened));

                // Only two vectors of the same size are evaluated
                // component-wise; every other pairing is rejected.
                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
2836
2837    fn binary_op_vector(
2838        &mut self,
2839        op: BinaryOperator,
2840        size: crate::VectorSize,
2841        components: &[(Handle<Expression>, Handle<Expression>)],
2842        left_ty: Handle<Type>,
2843        span: Span,
2844    ) -> Result<Expression, ConstantEvaluatorError> {
2845        let ty = match op {
2846            // Relational operators produce vectors of booleans.
2847            BinaryOperator::Equal
2848            | BinaryOperator::NotEqual
2849            | BinaryOperator::Less
2850            | BinaryOperator::LessEqual
2851            | BinaryOperator::Greater
2852            | BinaryOperator::GreaterEqual => self.types.insert(
2853                Type {
2854                    name: None,
2855                    inner: TypeInner::Vector {
2856                        size,
2857                        scalar: crate::Scalar::BOOL,
2858                    },
2859                },
2860                span,
2861            ),
2862
2863            // Other operators produce the same type as their left
2864            // operand.
2865            BinaryOperator::Add
2866            | BinaryOperator::Subtract
2867            | BinaryOperator::Multiply
2868            | BinaryOperator::Divide
2869            | BinaryOperator::Modulo
2870            | BinaryOperator::And
2871            | BinaryOperator::ExclusiveOr
2872            | BinaryOperator::InclusiveOr
2873            | BinaryOperator::ShiftLeft
2874            | BinaryOperator::ShiftRight => left_ty,
2875
2876            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2877                // Not supported on vectors
2878                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2879            }
2880        };
2881
2882        let components = components
2883            .iter()
2884            .map(|&(left, right)| self.binary_op(op, left, right, span))
2885            .collect::<Result<Vec<_>, _>>()?;
2886
2887        Ok(Expression::Compose { ty, components })
2888    }
2889
2890    fn relational(
2891        &mut self,
2892        fun: RelationalFunction,
2893        arg: Handle<Expression>,
2894        span: Span,
2895    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2896        let arg = self.eval_zero_value_and_splat(arg, span)?;
2897        match fun {
2898            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2899                Expression::Literal(Literal::Bool(_)) => Ok(arg),
2900                Expression::Compose { ty, ref components }
2901                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2902                {
2903                    let components =
2904                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2905                            .map(|component| match self.expressions[component] {
2906                                Expression::Literal(Literal::Bool(val)) => Ok(val),
2907                                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2908                            })
2909                            .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2910                    let result = match fun {
2911                        RelationalFunction::All => components.iter().all(|c| *c),
2912                        RelationalFunction::Any => components.iter().any(|c| *c),
2913                        _ => unreachable!(),
2914                    };
2915                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2916                }
2917                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2918            },
2919            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2920                "{fun:?} built-in function"
2921            ))),
2922        }
2923    }
2924
2925    /// Deep copy `expr` from `expressions` into `self.expressions`.
2926    ///
2927    /// Return the root of the new copy.
2928    ///
2929    /// This is used when we're evaluating expressions in a function's
2930    /// expression arena that refer to a constant: we need to copy the
2931    /// constant's value into the function's arena so we can operate on it.
2932    fn copy_from(
2933        &mut self,
2934        expr: Handle<Expression>,
2935        expressions: &Arena<Expression>,
2936    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2937        let span = expressions.get_span(expr);
2938        match expressions[expr] {
2939            ref expr @ (Expression::Literal(_)
2940            | Expression::Constant(_)
2941            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2942            Expression::Compose { ty, ref components } => {
2943                let mut components = components.clone();
2944                for component in &mut components {
2945                    *component = self.copy_from(*component, expressions)?;
2946                }
2947                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2948            }
2949            Expression::Splat { size, value } => {
2950                let value = self.copy_from(value, expressions)?;
2951                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2952            }
2953            _ => {
2954                log::debug!("copy_from: SubexpressionsAreNotConstant");
2955                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2956            }
2957        }
2958    }
2959
2960    /// Returns the total number of components, after flattening, of a vector compose expression.
2961    fn vector_compose_flattened_size(
2962        &self,
2963        components: &[Handle<Expression>],
2964    ) -> Result<usize, ConstantEvaluatorError> {
2965        components
2966            .iter()
2967            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2968                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2969                    TypeInner::Scalar(_) => 1,
2970                    // We trust that the vector size of `component` is correct,
2971                    // as it will have already been validated when `component`
2972                    // was registered.
2973                    TypeInner::Vector { size, .. } => size as usize,
2974                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2975                };
2976                Ok(acc + size)
2977            })
2978    }
2979
2980    fn register_evaluated_expr(
2981        &mut self,
2982        expr: Expression,
2983        span: Span,
2984    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2985        // It suffices to only check_literal_value() for `Literal` expressions,
2986        // since we only register one expression at a time, `Compose`
2987        // expressions can only refer to other expressions, and `ZeroValue`
2988        // expressions are always okay.
2989        if let Expression::Literal(literal) = expr {
2990            crate::valid::check_literal_value(literal)?;
2991        }
2992
2993        // Ensure vector composes contain the correct number of components. We
2994        // do so here when each compose is registered to avoid having to deal
2995        // with the mess each time the compose is used in another expression.
2996        if let Expression::Compose { ty, ref components } = expr {
2997            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2998                let expected = size as usize;
2999                let actual = self.vector_compose_flattened_size(components)?;
3000                if expected != actual {
3001                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3002                        expected,
3003                        actual,
3004                    });
3005                }
3006            }
3007        }
3008
3009        Ok(self.append_expr(expr, span, ExpressionKind::Const))
3010    }
3011
3012    fn append_expr(
3013        &mut self,
3014        expr: Expression,
3015        span: Span,
3016        expr_type: ExpressionKind,
3017    ) -> Handle<Expression> {
3018        let h = match self.behavior {
3019            Behavior::Wgsl(
3020                WgslRestrictions::Runtime(ref mut function_local_data)
3021                | WgslRestrictions::Const(Some(ref mut function_local_data)),
3022            )
3023            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3024                let is_running = function_local_data.emitter.is_running();
3025                let needs_pre_emit = expr.needs_pre_emit();
3026                if is_running && needs_pre_emit {
3027                    function_local_data
3028                        .block
3029                        .extend(function_local_data.emitter.finish(self.expressions));
3030                    let h = self.expressions.append(expr, span);
3031                    function_local_data.emitter.start(self.expressions);
3032                    h
3033                } else {
3034                    self.expressions.append(expr, span)
3035                }
3036            }
3037            _ => self.expressions.append(expr, span),
3038        };
3039        self.expression_kind_tracker.insert(h, expr_type);
3040        h
3041    }
3042
3043    /// Resolve the type of `expr` if it is a constant expression.
3044    ///
3045    /// If `expr` was evaluated to a constant, returns its type.
3046    /// Otherwise, returns an error.
3047    fn resolve_type(
3048        &self,
3049        expr: Handle<Expression>,
3050    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3051        use crate::proc::TypeResolution as Tr;
3052        use crate::Expression as Ex;
3053        let resolution = match self.expressions[expr] {
3054            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3055            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3056            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3057            Ex::Splat { size, value } => {
3058                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3059                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3060                };
3061                Tr::Value(TypeInner::Vector { scalar, size })
3062            }
3063            _ => {
3064                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3065                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3066            }
3067        };
3068
3069        Ok(resolution)
3070    }
3071
3072    fn select(
3073        &mut self,
3074        reject: Handle<Expression>,
3075        accept: Handle<Expression>,
3076        condition: Handle<Expression>,
3077        span: Span,
3078    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3079        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3080
3081        let reject = arg(reject)?;
3082        let accept = arg(accept)?;
3083        let condition = arg(condition)?;
3084
3085        let select_single_component =
3086            |this: &mut Self, reject_scalar, reject, accept, condition| {
3087                let accept = this.cast(accept, reject_scalar, span)?;
3088                if condition {
3089                    Ok(accept)
3090                } else {
3091                    Ok(reject)
3092                }
3093            };
3094
3095        match (&self.expressions[reject], &self.expressions[accept]) {
3096            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3097                let reject_scalar = reject_lit.scalar();
3098                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3099                else {
3100                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3101                };
3102                select_single_component(self, reject_scalar, reject, accept, condition)
3103            }
3104            (
3105                &Expression::Compose {
3106                    ty: reject_ty,
3107                    components: ref reject_components,
3108                },
3109                &Expression::Compose {
3110                    ty: accept_ty,
3111                    components: ref accept_components,
3112                },
3113            ) => {
3114                let ty_deets = |ty| {
3115                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3116                    (size.unwrap(), scalar)
3117                };
3118
3119                let expected_vec_size = {
3120                    let [(reject_vec_size, _), (accept_vec_size, _)] =
3121                        [reject_ty, accept_ty].map(ty_deets);
3122
3123                    if reject_vec_size != accept_vec_size {
3124                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3125                            reject: reject_vec_size,
3126                            accept: accept_vec_size,
3127                        });
3128                    }
3129                    reject_vec_size
3130                };
3131
3132                let condition_components = match self.expressions[condition] {
3133                    Expression::Literal(Literal::Bool(condition)) => {
3134                        vec![condition; (expected_vec_size as u8).into()]
3135                    }
3136                    Expression::Compose {
3137                        ty: condition_ty,
3138                        components: ref condition_components,
3139                    } => {
3140                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3141                        if condition_scalar.kind != ScalarKind::Bool {
3142                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3143                        }
3144                        if condition_vec_size != expected_vec_size {
3145                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3146                        }
3147                        condition_components
3148                            .iter()
3149                            .copied()
3150                            .map(|component| match &self.expressions[component] {
3151                                &Expression::Literal(Literal::Bool(condition)) => condition,
3152                                _ => unreachable!(),
3153                            })
3154                            .collect()
3155                    }
3156
3157                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3158                };
3159
3160                let evaluated = Expression::Compose {
3161                    ty: reject_ty,
3162                    components: reject_components
3163                        .clone()
3164                        .into_iter()
3165                        .zip(accept_components.clone().into_iter())
3166                        .zip(condition_components.into_iter())
3167                        .map(|((reject, accept), condition)| {
3168                            let reject_scalar = match &self.expressions[reject] {
3169                                &Expression::Literal(lit) => lit.scalar(),
3170                                _ => unreachable!(),
3171                            };
3172                            select_single_component(self, reject_scalar, reject, accept, condition)
3173                        })
3174                        .collect::<Result<_, _>>()?,
3175                };
3176                self.register_evaluated_expr(evaluated, span)
3177            }
3178            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3179        }
3180    }
3181}
3182
3183fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3184    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3185    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3186    // return a right-to-left bit index of 0.
3187    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3188        match e {
3189            idx @ 0..=31 => idx,
3190            32 => u32::MAX,
3191            _ => unreachable!(),
3192        }
3193    };
3194    match concrete_int {
3195        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3196        ConcreteInt::I32([e]) => {
3197            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3198        }
3199    }
3200}
3201
#[test]
fn first_trailing_bit_smoke() {
    // Hand-picked signed inputs, each paired with the expected bit index.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // A value with exactly one bit set reports the position of that bit.
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit]),
        );
    }

    // The same checks for unsigned inputs.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit]),
        );
    }
}
3262
3263fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3264    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3265    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3266    // index of 0.
3267    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3268        match e {
3269            idx @ 0..=31 => 31 - idx,
3270            32 => u32::MAX,
3271            _ => unreachable!(),
3272        }
3273    };
3274    match concrete_int {
3275        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3276            let rtl_bit_index = if e.is_negative() {
3277                e.leading_ones()
3278            } else {
3279                e.leading_zeros()
3280            };
3281            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3282        }]),
3283        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3284    }
3285}
3286
#[test]
fn first_leading_bit_smoke() {
    // Hand-picked signed inputs, each paired with the expected bit index.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // NOTE: The sign bit is skipped here; `i32::MIN` is covered above.
    for bit in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit]),
        );
    }
    // Negated powers of two: the highest clear bit sits just below `bit`.
    for bit in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1]),
        );
    }

    // The same checks for unsigned inputs.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit]),
        );
    }
}
3350
3351/// Trait for conversions of abstract values to concrete types.
3352trait TryFromAbstract<T>: Sized {
3353    /// Convert an abstract literal `value` to `Self`.
3354    ///
3355    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
3356    /// WGSL, we follow WGSL's conversion rules here:
3357    ///
3358    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
3359    ///   from [`AbstractInt`] to an integer type are either lossless or an
3360    ///   error.
3361    ///
3362    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
3363    ///   to floating point in constant expressions and override
3364    ///   expressions are errors if the value is out of range for the
3365    ///   destination type, but rounding is okay.
3366    ///
3367    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
3368    ///   other floating point type, following the scalar floating point to
3369    ///   integral conversion algorithm (§15.7.6). There is no automatic
3370    ///   conversion from AbstractFloat to integer types.
3371    ///
3372    /// [`AbstractInt`]: crate::Literal::AbstractInt
3373    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
3374    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3375}
3376
3377impl TryFromAbstract<i64> for i32 {
3378    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3379        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3380            value: format!("{value:?}"),
3381            to_type: "i32",
3382        })
3383    }
3384}
3385
3386impl TryFromAbstract<i64> for u32 {
3387    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3388        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3389            value: format!("{value:?}"),
3390            to_type: "u32",
3391        })
3392    }
3393}
3394
3395impl TryFromAbstract<i64> for u64 {
3396    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3397        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3398            value: format!("{value:?}"),
3399            to_type: "u64",
3400        })
3401    }
3402}
3403
3404impl TryFromAbstract<i64> for i64 {
3405    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3406        Ok(value)
3407    }
3408}
3409
3410impl TryFromAbstract<i64> for f32 {
3411    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3412        let f = value as f32;
3413        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3414        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3415        // overflow here.
3416        Ok(f)
3417    }
3418}
3419
3420impl TryFromAbstract<f64> for f32 {
3421    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3422        let f = value as f32;
3423        if f.is_infinite() {
3424            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3425                value: format!("{value:?}"),
3426                to_type: "f32",
3427            });
3428        }
3429        Ok(f)
3430    }
3431}
3432
3433impl TryFromAbstract<i64> for f64 {
3434    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3435        let f = value as f64;
3436        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3437        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3438        // overflow here.
3439        Ok(f)
3440    }
3441}
3442
3443impl TryFromAbstract<f64> for f64 {
3444    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3445        Ok(value)
3446    }
3447}
3448
3449impl TryFromAbstract<f64> for i32 {
3450    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3451        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3452        // To convert a floating point scalar value X to an integer scalar type T:
3453        // * If X is a NaN, the result is an indeterminate value in T.
3454        // * If X is exactly representable in the target type T, then the
3455        //   result is that value.
3456        // * Otherwise, the result is the value in T closest to truncate(X) and
3457        //   also exactly representable in the original floating point type.
3458        //
3459        // A rust cast satisfies these requirements apart from "the result
3460        // is... exactly representable in the original floating point type".
3461        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3462        // we're all good.
3463        Ok(value as i32)
3464    }
3465}
3466
3467impl TryFromAbstract<f64> for u32 {
3468    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3469        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3470        // so a simple rust cast is sufficient.
3471        Ok(value as u32)
3472    }
3473}
3474
3475impl TryFromAbstract<f64> for i64 {
3476    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3477        // As above, except we clamp to the minimum and maximum values
3478        // representable by both f64 and i64.
3479        use crate::proc::type_methods::IntFloatLimits;
3480        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3481    }
3482}
3483
3484impl TryFromAbstract<f64> for u64 {
3485    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3486        // As above, this time clamping to the minimum and maximum values
3487        // representable by both f64 and u64.
3488        use crate::proc::type_methods::IntFloatLimits;
3489        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3490    }
3491}
3492
3493impl TryFromAbstract<f64> for f16 {
3494    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3495        let f = f16::from_f64(value);
3496        if f.is_infinite() {
3497            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3498                value: format!("{value:?}"),
3499                to_type: "f16",
3500            });
3501        }
3502        Ok(f)
3503    }
3504}
3505
3506impl TryFromAbstract<i64> for f16 {
3507    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3508        let f = f16::from_i64(value);
3509        if f.is_none() {
3510            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3511                value: format!("{value:?}"),
3512                to_type: "f16",
3513            });
3514        }
3515        Ok(f.unwrap())
3516    }
3517}
3518
/// Compute the cross product `a × b` of two 3-component vectors.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
3531
3532#[cfg(test)]
3533mod tests {
3534    use alloc::{vec, vec::Vec};
3535
3536    use crate::{
3537        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
3538        UniqueArena, VectorSize,
3539    };
3540
3541    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3542
3543    #[test]
3544    fn unary_op() {
3545        let mut types = UniqueArena::new();
3546        let mut constants = Arena::new();
3547        let overrides = Arena::new();
3548        let mut global_expressions = Arena::new();
3549
3550        let scalar_ty = types.insert(
3551            Type {
3552                name: None,
3553                inner: TypeInner::Scalar(crate::Scalar::I32),
3554            },
3555            Default::default(),
3556        );
3557
3558        let vec_ty = types.insert(
3559            Type {
3560                name: None,
3561                inner: TypeInner::Vector {
3562                    size: VectorSize::Bi,
3563                    scalar: crate::Scalar::I32,
3564                },
3565            },
3566            Default::default(),
3567        );
3568
3569        let h = constants.append(
3570            Constant {
3571                name: None,
3572                ty: scalar_ty,
3573                init: global_expressions
3574                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3575            },
3576            Default::default(),
3577        );
3578
3579        let h1 = constants.append(
3580            Constant {
3581                name: None,
3582                ty: scalar_ty,
3583                init: global_expressions
3584                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
3585            },
3586            Default::default(),
3587        );
3588
3589        let vec_h = constants.append(
3590            Constant {
3591                name: None,
3592                ty: vec_ty,
3593                init: global_expressions.append(
3594                    Expression::Compose {
3595                        ty: vec_ty,
3596                        components: vec![constants[h].init, constants[h1].init],
3597                    },
3598                    Default::default(),
3599                ),
3600            },
3601            Default::default(),
3602        );
3603
3604        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3605        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3606
3607        let expr2 = Expression::Unary {
3608            op: UnaryOperator::Negate,
3609            expr,
3610        };
3611
3612        let expr3 = Expression::Unary {
3613            op: UnaryOperator::BitwiseNot,
3614            expr,
3615        };
3616
3617        let expr4 = Expression::Unary {
3618            op: UnaryOperator::BitwiseNot,
3619            expr: expr1,
3620        };
3621
3622        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3623        let mut solver = ConstantEvaluator {
3624            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3625            types: &mut types,
3626            constants: &constants,
3627            overrides: &overrides,
3628            expressions: &mut global_expressions,
3629            expression_kind_tracker,
3630            layouter: &mut crate::proc::Layouter::default(),
3631        };
3632
3633        let res1 = solver
3634            .try_eval_and_append(expr2, Default::default())
3635            .unwrap();
3636        let res2 = solver
3637            .try_eval_and_append(expr3, Default::default())
3638            .unwrap();
3639        let res3 = solver
3640            .try_eval_and_append(expr4, Default::default())
3641            .unwrap();
3642
3643        assert_eq!(
3644            global_expressions[res1],
3645            Expression::Literal(Literal::I32(-4))
3646        );
3647
3648        assert_eq!(
3649            global_expressions[res2],
3650            Expression::Literal(Literal::I32(!4))
3651        );
3652
3653        let res3_inner = &global_expressions[res3];
3654
3655        match *res3_inner {
3656            Expression::Compose {
3657                ref ty,
3658                ref components,
3659            } => {
3660                assert_eq!(*ty, vec_ty);
3661                let mut components_iter = components.iter().copied();
3662                assert_eq!(
3663                    global_expressions[components_iter.next().unwrap()],
3664                    Expression::Literal(Literal::I32(!4))
3665                );
3666                assert_eq!(
3667                    global_expressions[components_iter.next().unwrap()],
3668                    Expression::Literal(Literal::I32(!8))
3669                );
3670                assert!(components_iter.next().is_none());
3671            }
3672            _ => panic!("Expected vector"),
3673        }
3674    }
3675
3676    #[test]
3677    fn cast() {
3678        let mut types = UniqueArena::new();
3679        let mut constants = Arena::new();
3680        let overrides = Arena::new();
3681        let mut global_expressions = Arena::new();
3682
3683        let scalar_ty = types.insert(
3684            Type {
3685                name: None,
3686                inner: TypeInner::Scalar(crate::Scalar::I32),
3687            },
3688            Default::default(),
3689        );
3690
3691        let h = constants.append(
3692            Constant {
3693                name: None,
3694                ty: scalar_ty,
3695                init: global_expressions
3696                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3697            },
3698            Default::default(),
3699        );
3700
3701        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3702
3703        let root = Expression::As {
3704            expr,
3705            kind: ScalarKind::Bool,
3706            convert: Some(crate::BOOL_WIDTH),
3707        };
3708
3709        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3710        let mut solver = ConstantEvaluator {
3711            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3712            types: &mut types,
3713            constants: &constants,
3714            overrides: &overrides,
3715            expressions: &mut global_expressions,
3716            expression_kind_tracker,
3717            layouter: &mut crate::proc::Layouter::default(),
3718        };
3719
3720        let res = solver
3721            .try_eval_and_append(root, Default::default())
3722            .unwrap();
3723
3724        assert_eq!(
3725            global_expressions[res],
3726            Expression::Literal(Literal::Bool(true))
3727        );
3728    }
3729
3730    #[test]
3731    fn access() {
3732        let mut types = UniqueArena::new();
3733        let mut constants = Arena::new();
3734        let overrides = Arena::new();
3735        let mut global_expressions = Arena::new();
3736
3737        let matrix_ty = types.insert(
3738            Type {
3739                name: None,
3740                inner: TypeInner::Matrix {
3741                    columns: VectorSize::Bi,
3742                    rows: VectorSize::Tri,
3743                    scalar: crate::Scalar::F32,
3744                },
3745            },
3746            Default::default(),
3747        );
3748
3749        let vec_ty = types.insert(
3750            Type {
3751                name: None,
3752                inner: TypeInner::Vector {
3753                    size: VectorSize::Tri,
3754                    scalar: crate::Scalar::F32,
3755                },
3756            },
3757            Default::default(),
3758        );
3759
3760        let mut vec1_components = Vec::with_capacity(3);
3761        let mut vec2_components = Vec::with_capacity(3);
3762
3763        for i in 0..3 {
3764            let h = global_expressions.append(
3765                Expression::Literal(Literal::F32(i as f32)),
3766                Default::default(),
3767            );
3768
3769            vec1_components.push(h)
3770        }
3771
3772        for i in 3..6 {
3773            let h = global_expressions.append(
3774                Expression::Literal(Literal::F32(i as f32)),
3775                Default::default(),
3776            );
3777
3778            vec2_components.push(h)
3779        }
3780
3781        let vec1 = constants.append(
3782            Constant {
3783                name: None,
3784                ty: vec_ty,
3785                init: global_expressions.append(
3786                    Expression::Compose {
3787                        ty: vec_ty,
3788                        components: vec1_components,
3789                    },
3790                    Default::default(),
3791                ),
3792            },
3793            Default::default(),
3794        );
3795
3796        let vec2 = constants.append(
3797            Constant {
3798                name: None,
3799                ty: vec_ty,
3800                init: global_expressions.append(
3801                    Expression::Compose {
3802                        ty: vec_ty,
3803                        components: vec2_components,
3804                    },
3805                    Default::default(),
3806                ),
3807            },
3808            Default::default(),
3809        );
3810
3811        let h = constants.append(
3812            Constant {
3813                name: None,
3814                ty: matrix_ty,
3815                init: global_expressions.append(
3816                    Expression::Compose {
3817                        ty: matrix_ty,
3818                        components: vec![constants[vec1].init, constants[vec2].init],
3819                    },
3820                    Default::default(),
3821                ),
3822            },
3823            Default::default(),
3824        );
3825
3826        let base = global_expressions.append(Expression::Constant(h), Default::default());
3827
3828        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3829        let mut solver = ConstantEvaluator {
3830            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3831            types: &mut types,
3832            constants: &constants,
3833            overrides: &overrides,
3834            expressions: &mut global_expressions,
3835            expression_kind_tracker,
3836            layouter: &mut crate::proc::Layouter::default(),
3837        };
3838
3839        let root1 = Expression::AccessIndex { base, index: 1 };
3840
3841        let res1 = solver
3842            .try_eval_and_append(root1, Default::default())
3843            .unwrap();
3844
3845        let root2 = Expression::AccessIndex {
3846            base: res1,
3847            index: 2,
3848        };
3849
3850        let res2 = solver
3851            .try_eval_and_append(root2, Default::default())
3852            .unwrap();
3853
3854        match global_expressions[res1] {
3855            Expression::Compose {
3856                ref ty,
3857                ref components,
3858            } => {
3859                assert_eq!(*ty, vec_ty);
3860                let mut components_iter = components.iter().copied();
3861                assert_eq!(
3862                    global_expressions[components_iter.next().unwrap()],
3863                    Expression::Literal(Literal::F32(3.))
3864                );
3865                assert_eq!(
3866                    global_expressions[components_iter.next().unwrap()],
3867                    Expression::Literal(Literal::F32(4.))
3868                );
3869                assert_eq!(
3870                    global_expressions[components_iter.next().unwrap()],
3871                    Expression::Literal(Literal::F32(5.))
3872                );
3873                assert!(components_iter.next().is_none());
3874            }
3875            _ => panic!("Expected vector"),
3876        }
3877
3878        assert_eq!(
3879            global_expressions[res2],
3880            Expression::Literal(Literal::F32(5.))
3881        );
3882    }
3883
3884    #[test]
3885    fn compose_of_constants() {
3886        let mut types = UniqueArena::new();
3887        let mut constants = Arena::new();
3888        let overrides = Arena::new();
3889        let mut global_expressions = Arena::new();
3890
3891        let i32_ty = types.insert(
3892            Type {
3893                name: None,
3894                inner: TypeInner::Scalar(crate::Scalar::I32),
3895            },
3896            Default::default(),
3897        );
3898
3899        let vec2_i32_ty = types.insert(
3900            Type {
3901                name: None,
3902                inner: TypeInner::Vector {
3903                    size: VectorSize::Bi,
3904                    scalar: crate::Scalar::I32,
3905                },
3906            },
3907            Default::default(),
3908        );
3909
3910        let h = constants.append(
3911            Constant {
3912                name: None,
3913                ty: i32_ty,
3914                init: global_expressions
3915                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3916            },
3917            Default::default(),
3918        );
3919
3920        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3921
3922        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3923        let mut solver = ConstantEvaluator {
3924            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3925            types: &mut types,
3926            constants: &constants,
3927            overrides: &overrides,
3928            expressions: &mut global_expressions,
3929            expression_kind_tracker,
3930            layouter: &mut crate::proc::Layouter::default(),
3931        };
3932
3933        let solved_compose = solver
3934            .try_eval_and_append(
3935                Expression::Compose {
3936                    ty: vec2_i32_ty,
3937                    components: vec![h_expr, h_expr],
3938                },
3939                Default::default(),
3940            )
3941            .unwrap();
3942        let solved_negate = solver
3943            .try_eval_and_append(
3944                Expression::Unary {
3945                    op: UnaryOperator::Negate,
3946                    expr: solved_compose,
3947                },
3948                Default::default(),
3949            )
3950            .unwrap();
3951
3952        let pass = match global_expressions[solved_negate] {
3953            Expression::Compose { ty, ref components } => {
3954                ty == vec2_i32_ty
3955                    && components.iter().all(|&component| {
3956                        let component = &global_expressions[component];
3957                        matches!(*component, Expression::Literal(Literal::I32(-4)))
3958                    })
3959            }
3960            _ => false,
3961        };
3962        if !pass {
3963            panic!("unexpected evaluation result")
3964        }
3965    }
3966
3967    #[test]
3968    fn splat_of_constant() {
3969        let mut types = UniqueArena::new();
3970        let mut constants = Arena::new();
3971        let overrides = Arena::new();
3972        let mut global_expressions = Arena::new();
3973
3974        let i32_ty = types.insert(
3975            Type {
3976                name: None,
3977                inner: TypeInner::Scalar(crate::Scalar::I32),
3978            },
3979            Default::default(),
3980        );
3981
3982        let vec2_i32_ty = types.insert(
3983            Type {
3984                name: None,
3985                inner: TypeInner::Vector {
3986                    size: VectorSize::Bi,
3987                    scalar: crate::Scalar::I32,
3988                },
3989            },
3990            Default::default(),
3991        );
3992
3993        let h = constants.append(
3994            Constant {
3995                name: None,
3996                ty: i32_ty,
3997                init: global_expressions
3998                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3999            },
4000            Default::default(),
4001        );
4002
4003        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4004
4005        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4006        let mut solver = ConstantEvaluator {
4007            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4008            types: &mut types,
4009            constants: &constants,
4010            overrides: &overrides,
4011            expressions: &mut global_expressions,
4012            expression_kind_tracker,
4013            layouter: &mut crate::proc::Layouter::default(),
4014        };
4015
4016        let solved_compose = solver
4017            .try_eval_and_append(
4018                Expression::Splat {
4019                    size: VectorSize::Bi,
4020                    value: h_expr,
4021                },
4022                Default::default(),
4023            )
4024            .unwrap();
4025        let solved_negate = solver
4026            .try_eval_and_append(
4027                Expression::Unary {
4028                    op: UnaryOperator::Negate,
4029                    expr: solved_compose,
4030                },
4031                Default::default(),
4032            )
4033            .unwrap();
4034
4035        let pass = match global_expressions[solved_negate] {
4036            Expression::Compose { ty, ref components } => {
4037                ty == vec2_i32_ty
4038                    && components.iter().all(|&component| {
4039                        let component = &global_expressions[component];
4040                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4041                    })
4042            }
4043            _ => false,
4044        };
4045        if !pass {
4046            panic!("unexpected evaluation result")
4047        }
4048    }
4049
4050    #[test]
4051    fn splat_of_zero_value() {
4052        let mut types = UniqueArena::new();
4053        let constants = Arena::new();
4054        let overrides = Arena::new();
4055        let mut global_expressions = Arena::new();
4056
4057        let f32_ty = types.insert(
4058            Type {
4059                name: None,
4060                inner: TypeInner::Scalar(crate::Scalar::F32),
4061            },
4062            Default::default(),
4063        );
4064
4065        let vec2_f32_ty = types.insert(
4066            Type {
4067                name: None,
4068                inner: TypeInner::Vector {
4069                    size: VectorSize::Bi,
4070                    scalar: crate::Scalar::F32,
4071                },
4072            },
4073            Default::default(),
4074        );
4075
4076        let five =
4077            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4078        let five_splat = global_expressions.append(
4079            Expression::Splat {
4080                size: VectorSize::Bi,
4081                value: five,
4082            },
4083            Default::default(),
4084        );
4085        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4086        let zero_splat = global_expressions.append(
4087            Expression::Splat {
4088                size: VectorSize::Bi,
4089                value: zero,
4090            },
4091            Default::default(),
4092        );
4093
4094        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4095        let mut solver = ConstantEvaluator {
4096            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4097            types: &mut types,
4098            constants: &constants,
4099            overrides: &overrides,
4100            expressions: &mut global_expressions,
4101            expression_kind_tracker,
4102            layouter: &mut crate::proc::Layouter::default(),
4103        };
4104
4105        let solved_add = solver
4106            .try_eval_and_append(
4107                Expression::Binary {
4108                    op: crate::BinaryOperator::Add,
4109                    left: zero_splat,
4110                    right: five_splat,
4111                },
4112                Default::default(),
4113            )
4114            .unwrap();
4115
4116        let pass = match global_expressions[solved_add] {
4117            Expression::Compose { ty, ref components } => {
4118                ty == vec2_f32_ty
4119                    && components.iter().all(|&component| {
4120                        let component = &global_expressions[component];
4121                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4122                    })
4123            }
4124            _ => false,
4125        };
4126        if !pass {
4127            panic!("unexpected evaluation result")
4128        }
4129    }
4130}