1use alloc::{
7 format,
8 string::{String, ToString},
9 vec,
10 vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19 arena::{Arena, Handle, HandleVec, UniqueArena},
20 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21 ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Defines a throwaway macro from `$body` and immediately invokes it with a
/// literal `$` token.
///
/// `$body` is expected to be a single `macro_rules!` arm of the form
/// `($d:tt) => { ... }`. Inside that arm, `$d` expands to `$`. This lets a
/// macro that generates another `macro_rules!` spell `$d (...)` for the inner
/// macro's own repetitions. A bare `$(...)` there would be consumed by the
/// outer macro's transcriber instead.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define a helper macro whose single arm is `$body`...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and call it with `$`, binding that token to the arm's `$d:tt`.
        __with_dollar_sign!($);
    }
}
38
/// Generates a component-wise evaluator family for a fixed set of literal kinds.
///
/// The macro expands to:
///
/// - an enum `$target<const N: usize>` with one variant `$mapping([$ty; N])`
///   per entry in `literals`;
/// - `impl From<$target<1>> for Expression`, turning a single value back into
///   the matching `Expression::Literal(Literal::$literal(..))`;
/// - a function `$ident(eval, span, exprs, handler)` that evaluates `N`
///   operand expressions together. All operands must be literals of the same
///   variant, or `Compose`s of the same vector type whose scalar kind is listed
///   in `scalar_kinds`. Vectors are processed one component index at a time by
///   recursing into `$ident`;
/// - a convenience macro also named `$ident`, which accepts one closure-like
///   body and repeats it for every `$target` variant.
///
/// Any operand shape not covered above yields
/// `ConstantEvaluatorError::InvalidMathArg`.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // `N` same-typed values; one variant per accepted literal kind.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single extracted value converts back into its literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an operand through `eval_zero_value_and_splat` (presumably
            // rewriting `ZeroValue`/`Splat` into literal/`Compose` form -- confirm)
            // and borrows the resulting expression from the arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first operand decides whether this is the scalar or vector case.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every remaining operand must be the same literal variant.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        // Exactly `N` values were pushed, so the array is full.
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector case: every operand must be a `Compose` of the same vector type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per operand.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // Relies on `size as u8` being the vector's component count.
                            for idx in 0..(size as u8).into() {
                                // Gather the `idx`-th component of every operand...
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                // ...and evaluate them as a scalar group by recursion.
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // `$d` below is a literal `$` (see `with_dollar_sign`). It marks the
        // inner macro's own metavariables and repetitions. The plain `$( ... )+`
        // around the match arms is the *outer* repetition over `literals`,
        // generating one arm per `$target` variant.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// `Scalar` / `component_wise_scalar!`: accepts every numeric literal kind
// listed below (abstract and concrete floats and integers, no `Bool`), and
// vectors whose scalar kind is any float or integer kind.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
239
// `Float` / `component_wise_float!`: floating-point operands only. The
// abstract-float variant is named `Float::Abstract`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
252
// `ConcreteInt` / `component_wise_concrete_int!`: only 32-bit `u32`/`i32`
// literals. Vectors of other `Sint`/`Uint` widths pass the scalar-kind check
// but are rejected when their components fail the literal match.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
264
// `Signed` / `component_wise_signed!`: signed numeric operands (floats,
// abstract ints, and `i32`). Unsigned and 64-bit integer literals are not
// accepted.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// A scalar or vector of literal values that all share one [`Literal`] variant.
///
/// Each variant stores between one and `VectorSize::MAX` components. A
/// single component stands for a scalar (see `LiteralVector::from_literal`).
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U16(ArrayVec<u16, { crate::VectorSize::MAX }>),
    I16(ArrayVec<i16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
298
299impl LiteralVector {
300 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
301 fn len(&self) -> usize {
302 match *self {
303 LiteralVector::F64(ref v) => v.len(),
304 LiteralVector::F32(ref v) => v.len(),
305 LiteralVector::F16(ref v) => v.len(),
306 LiteralVector::U16(ref v) => v.len(),
307 LiteralVector::I16(ref v) => v.len(),
308 LiteralVector::U32(ref v) => v.len(),
309 LiteralVector::I32(ref v) => v.len(),
310 LiteralVector::U64(ref v) => v.len(),
311 LiteralVector::I64(ref v) => v.len(),
312 LiteralVector::Bool(ref v) => v.len(),
313 LiteralVector::AbstractInt(ref v) => v.len(),
314 LiteralVector::AbstractFloat(ref v) => v.len(),
315 }
316 }
317
318 fn from_literal(literal: Literal) -> Self {
320 fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
321 let mut v = ArrayVec::new();
322 v.push(val);
323 v
324 }
325 match literal {
326 Literal::F64(e) => Self::F64(arrayvec_of(e)),
327 Literal::F32(e) => Self::F32(arrayvec_of(e)),
328 Literal::U16(e) => Self::U16(arrayvec_of(e)),
329 Literal::I16(e) => Self::I16(arrayvec_of(e)),
330 Literal::U32(e) => Self::U32(arrayvec_of(e)),
331 Literal::I32(e) => Self::I32(arrayvec_of(e)),
332 Literal::U64(e) => Self::U64(arrayvec_of(e)),
333 Literal::I64(e) => Self::I64(arrayvec_of(e)),
334 Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
335 Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
336 Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
337 Literal::F16(e) => Self::F16(arrayvec_of(e)),
338 }
339 }
340
341 fn from_literal_vec(
346 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
347 ) -> Result<Self, ConstantEvaluatorError> {
348 assert!(!components.is_empty());
349 macro_rules! compose_literals {
351 ($components:expr, $variant:ident, $self_variant:ident) => {{
352 let mut out = ArrayVec::new();
353 for l in &$components {
354 match l {
355 &Literal::$variant(v) => out.push(v),
356 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
357 }
358 }
359 Self::$self_variant(out)
360 }};
361 }
362 Ok(match components[0] {
363 Literal::I16(_) => compose_literals!(components, I16, I16),
364 Literal::U16(_) => compose_literals!(components, U16, U16),
365 Literal::I32(_) => compose_literals!(components, I32, I32),
366 Literal::U32(_) => compose_literals!(components, U32, U32),
367 Literal::I64(_) => compose_literals!(components, I64, I64),
368 Literal::U64(_) => compose_literals!(components, U64, U64),
369 Literal::F32(_) => compose_literals!(components, F32, F32),
370 Literal::F64(_) => compose_literals!(components, F64, F64),
371 Literal::Bool(_) => compose_literals!(components, Bool, Bool),
372 Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
373 Literal::AbstractFloat(_) => {
374 compose_literals!(components, AbstractFloat, AbstractFloat)
375 }
376 Literal::F16(_) => compose_literals!(components, F16, F16),
377 })
378 }
379
380 #[allow(dead_code)]
381 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
383 macro_rules! decompose_literals {
384 ($v:expr, $variant:ident) => {{
385 let mut out = ArrayVec::new();
386 for e in $v {
387 out.push(Literal::$variant(*e));
388 }
389 out
390 }};
391 }
392 match *self {
393 LiteralVector::F64(ref v) => decompose_literals!(v, F64),
394 LiteralVector::F32(ref v) => decompose_literals!(v, F32),
395 LiteralVector::F16(ref v) => decompose_literals!(v, F16),
396 LiteralVector::U16(ref v) => decompose_literals!(v, U16),
397 LiteralVector::I16(ref v) => decompose_literals!(v, I16),
398 LiteralVector::U32(ref v) => decompose_literals!(v, U32),
399 LiteralVector::I32(ref v) => decompose_literals!(v, I32),
400 LiteralVector::U64(ref v) => decompose_literals!(v, U64),
401 LiteralVector::I64(ref v) => decompose_literals!(v, I64),
402 LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
403 LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
404 LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
405 }
406 }
407
408 #[allow(dead_code)]
409 fn register_as_evaluated_expr(
411 &self,
412 eval: &mut ConstantEvaluator<'_>,
413 span: Span,
414 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
415 let lit_vec = self.to_literal_vec();
416 assert!(!lit_vec.is_empty());
417 let expr = if lit_vec.len() == 1 {
418 Expression::Literal(lit_vec[0])
419 } else {
420 Expression::Compose {
421 ty: eval.types.insert(
422 Type {
423 name: None,
424 inner: TypeInner::Vector {
425 size: match lit_vec.len() {
426 2 => crate::VectorSize::Bi,
427 3 => crate::VectorSize::Tri,
428 4 => crate::VectorSize::Quad,
429 _ => unreachable!(),
430 },
431 scalar: lit_vec[0].scalar(),
432 },
433 },
434 Span::UNDEFINED,
435 ),
436 components: lit_vec
437 .iter()
438 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
439 .collect::<Result<_, _>>()?,
440 }
441 };
442 eval.register_evaluated_expr(expr, span)
443 }
444}
445
/// Matches on one or more [`LiteralVector`]s and maps the payload(s) into an
/// `$out` enum.
///
/// Syntax:
///
/// ```rust,ignore
/// match_literal_vector!(match lit_vec => Out {
///     Float => |v| { ... },          // Out::F16 / Out::F32 / Out::F64 / Out::AbstractFloat
///     Integer => |v| { ... },        // Out::U32 / I32 / U64 / I64 / AbstractInt
///     U32 => |v| -> I32 { ... },     // result wrapped in Out::I32 instead of Out::U32
/// })
/// ```
///
/// - Each `$ty` is a `LiteralVector` variant name, or one of the group names
///   `Integer` (expands to `U32, I32, U64, I64, AbstractInt`) or `Float`
///   (expands to `F16, F32, F64, AbstractFloat`). A group's body is repeated
///   for each member type.
/// - With several `$var`s, `$lit_vec` must be a tuple of `LiteralVector`s,
///   matched as `(LiteralVector::$ty(ref a), LiteralVector::$ty(ref b), ..)`,
///   so all operands must share the variant.
/// - The body's value is wrapped as `$out::$ty(body)`, or as `$out::$ret(body)`
///   when `-> $ret` is given.
/// - The whole macro evaluates to `Ok(..)` on a match and to
///   `Err(ConstantEvaluatorError::InvalidMathArg)` otherwise.
macro_rules! match_literal_vector {
    // Entry point: parse the user-facing `match ... => Out { ... }` syntax.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start the accumulator: `[done arms] <> [pending bodies]`.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Dispatch: move the next pending type name into the leading position so
    // the `Integer`/`Float` group arms below get a chance to match it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` group: replace it with its five member types and duplicate the
    // current body once per member.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body } $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` group: replace it with its four member types and duplicate the
    // current body once per member.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body } $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete type with more types pending: pair it with the first pending
    // body, append to the finished list, and continue with the next type.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete type and last body: finish accumulation.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the actual `match`. Single-operand patterns come out as
    // `(LiteralVector::X(ref v))`, hence the `unused_parens` allowance.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // No explicit return type: wrap in the variant named after the input type.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // Explicit `-> $ret`: wrap in that variant instead.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
621
622fn float_length<F>(e: &[F]) -> Option<F>
623where
624 F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
625{
626 if e.len() == 1 {
627 Some(e[0].abs())
629 } else {
630 let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
631 result.is_finite().then_some(result)
632 }
633}
634
/// Which shading language's rules the evaluator follows, together with the
/// context (module scope or function body, const/override/runtime) it runs in.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules.
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules.
    Glsl(GlslRestrictions<'a>),
}
640
641impl Behavior<'_> {
642 const fn has_runtime_restrictions(&self) -> bool {
644 matches!(
645 self,
646 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
647 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
648 )
649 }
650}
651
/// Evaluates expressions at compile time where the current context allows,
/// appending the results to an expression arena.
///
/// Construct with `for_wgsl_module`, `for_glsl_module`, `for_wgsl_function`
/// or `for_glsl_function`.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's rules apply and in which context evaluation happens.
    behavior: Behavior<'a>,

    /// The module's type arena; new vector types may be inserted here.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants, used to resolve `Expression::Constant`.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena new expressions are appended to. This is the module's global
    /// expressions for module-scope evaluators, or the function's own arena
    /// for function-scope evaluators.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the [`ExpressionKind`] of expressions in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    // NOTE(review): not used in the visible part of this file; presumably
    // computes type layouts/sizes (e.g. for `TypeTooLarge`) -- confirm.
    layouter: &'a mut crate::proc::Layouter,
}
696
/// Evaluation context for WGSL.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Const-expression context.
    ///
    /// `None` at module scope (`for_wgsl_module` without an override context);
    /// `Some` inside a function body evaluated as const
    /// (`for_wgsl_function` with `is_const == true`).
    Const(Option<FunctionLocalData<'a>>),
    /// Module-scope override-expression context (`for_wgsl_module` with
    /// `in_override_ctx == true`). Override-kind expressions are appended
    /// unevaluated.
    Override,
    /// Non-const function body (`for_wgsl_function` with `is_const == false`).
    /// Override- and runtime-kind expressions are appended unevaluated.
    Runtime(FunctionLocalData<'a>),
}
709
/// Evaluation context for GLSL.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module-scope constant context (`for_glsl_module`).
    Const,
    /// Function body (`for_glsl_function`). Expressions that are not
    /// const-evaluable are appended unevaluated.
    Runtime(FunctionLocalData<'a>),
}
719
/// State needed when the evaluator runs inside a function body.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are copied
    /// from here into the function's arena (see `check_and_get`).
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): not used in the visible part of this file; presumably
    /// used by `append_expr` to emit runtime expressions -- confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the function block emitted statements go
    /// into -- confirm.
    block: &'a mut crate::Block,
}
727
/// How "constant" an expression is.
///
/// The variant order is significant and must not change. The derived `Ord`
/// ranks `Const < Override < Runtime`, and [`ExpressionKindTracker`] combines
/// operand kinds with `max`, so a compound expression is only as constant as
/// its least-constant operand.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values, constants, and operations on them.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Can only be evaluated at runtime.
    Runtime,
}
734
/// Records an [`ExpressionKind`] for each expression handle in an arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// Kind of each tracked expression, indexed by its handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
739
740impl ExpressionKindTracker {
741 pub const fn new() -> Self {
742 Self {
743 inner: HandleVec::new(),
744 }
745 }
746
747 pub fn force_non_const(&mut self, value: Handle<Expression>) {
749 self.inner[value] = ExpressionKind::Runtime;
750 }
751
752 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
753 self.inner.insert(value, expr_type);
754 }
755
756 pub fn is_const(&self, h: Handle<Expression>) -> bool {
757 matches!(self.type_of(h), ExpressionKind::Const)
758 }
759
760 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
761 matches!(
762 self.type_of(h),
763 ExpressionKind::Const | ExpressionKind::Override
764 )
765 }
766
767 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
768 self.inner[value]
769 }
770
771 pub fn from_arena(arena: &Arena<Expression>) -> Self {
772 let mut tracker = Self {
773 inner: HandleVec::with_capacity(arena.len()),
774 };
775 for (handle, expr) in arena.iter() {
776 tracker
777 .inner
778 .insert(handle, tracker.type_of_with_expr(expr));
779 }
780 tracker
781 }
782
783 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
784 match *expr {
785 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
786 ExpressionKind::Const
787 }
788 Expression::Override(_) => ExpressionKind::Override,
789 Expression::Compose { ref components, .. } => {
790 let mut expr_type = ExpressionKind::Const;
791 for component in components {
792 expr_type = expr_type.max(self.type_of(*component))
793 }
794 expr_type
795 }
796 Expression::Splat { value, .. } => self.type_of(value),
797 Expression::AccessIndex { base, .. } => self.type_of(base),
798 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
799 Expression::Swizzle { vector, .. } => self.type_of(vector),
800 Expression::Unary { expr, .. } => self.type_of(expr),
801 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
802 Expression::Math {
803 arg,
804 arg1,
805 arg2,
806 arg3,
807 ..
808 } => self
809 .type_of(arg)
810 .max(
811 arg1.map(|arg| self.type_of(arg))
812 .unwrap_or(ExpressionKind::Const),
813 )
814 .max(
815 arg2.map(|arg| self.type_of(arg))
816 .unwrap_or(ExpressionKind::Const),
817 )
818 .max(
819 arg3.map(|arg| self.type_of(arg))
820 .unwrap_or(ExpressionKind::Const),
821 ),
822 Expression::As { expr, .. } => self.type_of(expr),
823 Expression::Select {
824 condition,
825 accept,
826 reject,
827 } => self
828 .type_of(condition)
829 .max(self.type_of(accept))
830 .max(self.type_of(reject)),
831 Expression::Relational { argument, .. } => self.type_of(argument),
832 Expression::ArrayLength(expr) => self.type_of(expr),
833 _ => ExpressionKind::Runtime,
834 }
835 }
836}
837
838#[derive(Clone, Debug, thiserror::Error)]
839#[cfg_attr(test, derive(PartialEq))]
840pub enum ConstantEvaluatorError {
841 #[error("Constants cannot access function arguments")]
842 FunctionArg,
843 #[error("Constants cannot access global variables")]
844 GlobalVariable,
845 #[error("Constants cannot access local variables")]
846 LocalVariable,
847 #[error("Cannot get the array length of a non array type")]
848 InvalidArrayLengthArg,
849 #[error("Constants cannot get the array length of a dynamically sized array")]
850 ArrayLengthDynamic,
851 #[error("Cannot call arrayLength on array sized by override-expression")]
852 ArrayLengthOverridden,
853 #[error("Constants cannot call functions")]
854 Call,
855 #[error("Constants don't support workGroupUniformLoad")]
856 WorkGroupUniformLoadResult,
857 #[error("Constants don't support atomic functions")]
858 Atomic,
859 #[error("Constants don't support derivative functions")]
860 Derivative,
861 #[error("Constants don't support load expressions")]
862 Load,
863 #[error("Constants don't support image expressions")]
864 ImageExpression,
865 #[error("Constants don't support ray query expressions")]
866 RayQueryExpression,
867 #[error("Constants don't support subgroup expressions")]
868 SubgroupExpression,
869 #[error("Cannot access the type")]
870 InvalidAccessBase,
871 #[error("Cannot access at the index")]
872 InvalidAccessIndex,
873 #[error("Cannot access with index of type")]
874 InvalidAccessIndexTy,
875 #[error("Constants don't support array length expressions")]
876 ArrayLength,
877 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
878 InvalidCastArg { from: String, to: String },
879 #[error("Cannot apply the unary op to the argument")]
880 InvalidUnaryOpArg,
881 #[error("Cannot apply the binary op to the arguments")]
882 InvalidBinaryOpArgs,
883 #[error("Cannot apply math function to type")]
884 InvalidMathArg,
885 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
886 InvalidMathArgCount(crate::MathFunction, usize, usize),
887 #[error("{0} built-in function argument is out of valid range")]
888 InvalidMathArgValue(String),
889 #[error("Cannot apply relational function to type")]
890 InvalidRelationalArg(RelationalFunction),
891 #[error("value of `low` is greater than `high` for clamp built-in function")]
892 InvalidClamp,
893 #[error("Constructor expects {expected} components, found {actual}")]
894 InvalidVectorComposeLength { expected: usize, actual: usize },
895 #[error("Constructor must only contain vector or scalar arguments")]
896 InvalidVectorComposeComponent,
897 #[error("Splat is defined only on scalar values")]
898 SplatScalarOnly,
899 #[error("Can only swizzle vector constants")]
900 SwizzleVectorOnly,
901 #[error("swizzle component not present in source expression")]
902 SwizzleOutOfBounds,
903 #[error("Type is not constructible")]
904 TypeNotConstructible,
905 #[error("Subexpression(s) are not constant")]
906 SubexpressionsAreNotConstant,
907 #[error("Not implemented as constant expression: {0}")]
908 NotImplemented(String),
909 #[error("{0} operation overflowed")]
910 Overflow(String),
911 #[error(
912 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
913 )]
914 AutomaticConversionLossy {
915 value: String,
916 to_type: &'static str,
917 },
918 #[error("Division by zero")]
919 DivisionByZero,
920 #[error("Remainder by zero")]
921 RemainderByZero,
922 #[error("RHS of shift operation is greater than or equal to 32")]
923 ShiftedMoreThan32Bits,
924 #[error(transparent)]
925 Literal(#[from] crate::valid::LiteralError),
926 #[error("Can't use pipeline-overridable constants in const-expressions")]
927 Override,
928 #[error("Unexpected runtime-expression")]
929 RuntimeExpr,
930 #[error("Unexpected override-expression")]
931 OverrideExpr,
932 #[error("Expected boolean expression for condition argument of `select`, got something else")]
933 SelectScalarConditionNotABool,
934 #[error(
935 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
936 reject,
937 accept
938 )]
939 SelectVecRejectAcceptSizeMismatch {
940 reject: crate::VectorSize,
941 accept: crate::VectorSize,
942 },
943 #[error("Expected boolean vector for condition arg., got something else")]
944 SelectConditionNotAVecBool,
945 #[error(
946 "Expected same number of vector components between condition, accept, and reject args., got something else",
947 )]
948 SelectConditionVecSizeMismatch,
949 #[error(
950 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
951 )]
952 SelectAcceptRejectTypeMismatch,
953 #[error("Cooperative operations can't be constant")]
954 CooperativeOperation,
955 #[error("Type is too large")]
956 TypeTooLarge(Handle<Type>),
957}
958
959impl<'a> ConstantEvaluator<'a> {
960 pub const fn for_wgsl_module(
965 module: &'a mut crate::Module,
966 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
967 layouter: &'a mut crate::proc::Layouter,
968 in_override_ctx: bool,
969 ) -> Self {
970 Self::for_module(
971 Behavior::Wgsl(if in_override_ctx {
972 WgslRestrictions::Override
973 } else {
974 WgslRestrictions::Const(None)
975 }),
976 module,
977 global_expression_kind_tracker,
978 layouter,
979 )
980 }
981
982 pub const fn for_glsl_module(
987 module: &'a mut crate::Module,
988 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
989 layouter: &'a mut crate::proc::Layouter,
990 ) -> Self {
991 Self::for_module(
992 Behavior::Glsl(GlslRestrictions::Const),
993 module,
994 global_expression_kind_tracker,
995 layouter,
996 )
997 }
998
999 const fn for_module(
1000 behavior: Behavior<'a>,
1001 module: &'a mut crate::Module,
1002 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1003 layouter: &'a mut crate::proc::Layouter,
1004 ) -> Self {
1005 Self {
1006 behavior,
1007 types: &mut module.types,
1008 constants: &module.constants,
1009 overrides: &module.overrides,
1010 expressions: &mut module.global_expressions,
1011 expression_kind_tracker: global_expression_kind_tracker,
1012 layouter,
1013 }
1014 }
1015
1016 pub const fn for_wgsl_function(
1021 module: &'a mut crate::Module,
1022 expressions: &'a mut Arena<Expression>,
1023 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1024 layouter: &'a mut crate::proc::Layouter,
1025 emitter: &'a mut super::Emitter,
1026 block: &'a mut crate::Block,
1027 is_const: bool,
1028 ) -> Self {
1029 let local_data = FunctionLocalData {
1030 global_expressions: &module.global_expressions,
1031 emitter,
1032 block,
1033 };
1034 Self {
1035 behavior: Behavior::Wgsl(if is_const {
1036 WgslRestrictions::Const(Some(local_data))
1037 } else {
1038 WgslRestrictions::Runtime(local_data)
1039 }),
1040 types: &mut module.types,
1041 constants: &module.constants,
1042 overrides: &module.overrides,
1043 expressions,
1044 expression_kind_tracker: local_expression_kind_tracker,
1045 layouter,
1046 }
1047 }
1048
1049 pub const fn for_glsl_function(
1054 module: &'a mut crate::Module,
1055 expressions: &'a mut Arena<Expression>,
1056 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1057 layouter: &'a mut crate::proc::Layouter,
1058 emitter: &'a mut super::Emitter,
1059 block: &'a mut crate::Block,
1060 ) -> Self {
1061 Self {
1062 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1063 global_expressions: &module.global_expressions,
1064 emitter,
1065 block,
1066 })),
1067 types: &mut module.types,
1068 constants: &module.constants,
1069 overrides: &module.overrides,
1070 expressions,
1071 expression_kind_tracker: local_expression_kind_tracker,
1072 layouter,
1073 }
1074 }
1075
1076 pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1077 crate::proc::GlobalCtx {
1078 types: self.types,
1079 constants: self.constants,
1080 overrides: self.overrides,
1081 global_expressions: match self.function_local_data() {
1082 Some(data) => data.global_expressions,
1083 None => self.expressions,
1084 },
1085 }
1086 }
1087
1088 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1089 if !self.expression_kind_tracker.is_const(expr) {
1090 log::debug!("check: SubexpressionsAreNotConstant");
1091 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1092 }
1093 Ok(())
1094 }
1095
1096 fn check_and_get(
1097 &mut self,
1098 expr: Handle<Expression>,
1099 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1100 match self.expressions[expr] {
1101 Expression::Constant(c) => {
1102 if let Some(function_local_data) = self.function_local_data() {
1105 self.copy_from(
1107 self.constants[c].init,
1108 function_local_data.global_expressions,
1109 )
1110 } else {
1111 Ok(self.constants[c].init)
1113 }
1114 }
1115 _ => {
1116 self.check(expr)?;
1117 Ok(expr)
1118 }
1119 }
1120 }
1121
1122 pub fn try_eval_and_append(
1146 &mut self,
1147 expr: Expression,
1148 span: Span,
1149 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1150 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1151 ExpressionKind::Const => {
1152 let eval_result = self.try_eval_and_append_impl(&expr, span);
1153 if self.behavior.has_runtime_restrictions()
1158 && matches!(
1159 eval_result,
1160 Err(ConstantEvaluatorError::NotImplemented(_)
1161 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1162 )
1163 {
1164 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1165 } else {
1166 eval_result
1167 }
1168 }
1169 ExpressionKind::Override => match self.behavior {
1170 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1171 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1172 }
1173 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1174 Err(ConstantEvaluatorError::OverrideExpr)
1175 }
1176
1177 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1179 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1180 }
1181 Behavior::Glsl(GlslRestrictions::Const) => {
1182 Err(ConstantEvaluatorError::OverrideExpr)
1183 }
1184 },
1185 ExpressionKind::Runtime => {
1186 if self.behavior.has_runtime_restrictions() {
1187 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1188 } else {
1189 Err(ConstantEvaluatorError::RuntimeExpr)
1190 }
1191 }
1192 }
1193 }
1194
1195 const fn is_global_arena(&self) -> bool {
1197 matches!(
1198 self.behavior,
1199 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1200 | Behavior::Glsl(GlslRestrictions::Const)
1201 )
1202 }
1203
1204 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1205 match self.behavior {
1206 Behavior::Wgsl(
1207 WgslRestrictions::Runtime(ref function_local_data)
1208 | WgslRestrictions::Const(Some(ref function_local_data)),
1209 )
1210 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1211 Some(function_local_data)
1212 }
1213 _ => None,
1214 }
1215 }
1216
    /// Evaluate `expr` as far as constant evaluation allows and append the
    /// result, returning its handle.
    ///
    /// Operand handles are first passed through `check_and_get`, which is
    /// defined elsewhere in this file. Presumably it ensures each operand was
    /// already evaluated and returns the handle to use; confirm. After that:
    ///
    /// - In a global arena (see `is_global_arena`), a `Constant` is replaced by
    ///   the handle of its initializer.
    /// - `Override` always fails with `ConstantEvaluatorError::Override`.
    /// - `Literal`, `ZeroValue`, and the remaining `Constant` cases are
    ///   registered as-is.
    /// - `Compose` and `Splat` are registered with their checked operands and
    ///   are not expanded.
    /// - Indexing, swizzles, unary and binary operators, math built-ins,
    ///   conversions, `select`, and relational functions are folded by their
    ///   dedicated helpers.
    /// - `ArrayLength` is evaluated only under GLSL behavior.
    /// - Bitcasts (`As` with `convert: None`) are not implemented.
    /// - Every other expression kind yields an error.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            // Must precede the general `Constant` arm below: in a global arena
            // the constant's initializer handle is returned directly.
            Expression::Constant(c) if self.is_global_arena() => {
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The dynamic index must be readable as a `u32` constant.
                let index_val: u32 = self
                    .to_ctx()
                    .get_const_val_from(index, self.expressions)
                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
                self.access(base, index_val as usize, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: Some(width)` is a value conversion; `None` is a bitcast.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // The remaining expression kinds are never constant-evaluated.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
                Err(ConstantEvaluatorError::CooperativeOperation)
            }
        }
    }
1358
1359 fn splat(
1372 &mut self,
1373 value: Handle<Expression>,
1374 size: crate::VectorSize,
1375 span: Span,
1376 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1377 match self.expressions[value] {
1378 Expression::Literal(literal) => {
1379 let scalar = literal.scalar();
1380 let ty = self.types.insert(
1381 Type {
1382 name: None,
1383 inner: TypeInner::Vector { size, scalar },
1384 },
1385 span,
1386 );
1387 let expr = Expression::Compose {
1388 ty,
1389 components: vec![value; size as usize],
1390 };
1391 self.register_evaluated_expr(expr, span)
1392 }
1393 Expression::ZeroValue(ty) => {
1394 let inner = match self.types[ty].inner {
1395 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1396 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1397 };
1398 let res_ty = self.types.insert(Type { name: None, inner }, span);
1399 let expr = Expression::ZeroValue(res_ty);
1400 self.register_evaluated_expr(expr, span)
1401 }
1402 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1403 }
1404 }
1405
1406 fn swizzle(
1407 &mut self,
1408 size: crate::VectorSize,
1409 span: Span,
1410 src_constant: Handle<Expression>,
1411 pattern: [crate::SwizzleComponent; 4],
1412 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1413 let mut get_dst_ty = |ty| match self.types[ty].inner {
1414 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1415 Type {
1416 name: None,
1417 inner: TypeInner::Vector { size, scalar },
1418 },
1419 span,
1420 )),
1421 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1422 };
1423
1424 match self.expressions[src_constant] {
1425 Expression::ZeroValue(ty) => {
1426 let dst_ty = get_dst_ty(ty)?;
1427 let expr = Expression::ZeroValue(dst_ty);
1428 self.register_evaluated_expr(expr, span)
1429 }
1430 Expression::Splat { value, .. } => {
1431 let expr = Expression::Splat { size, value };
1432 self.register_evaluated_expr(expr, span)
1433 }
1434 Expression::Compose { ty, ref components } => {
1435 let dst_ty = get_dst_ty(ty)?;
1436
1437 let mut flattened = [src_constant; 4]; let len =
1439 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1440 .zip(flattened.iter_mut())
1441 .map(|(component, elt)| *elt = component)
1442 .count();
1443 let flattened = &flattened[..len];
1444
1445 let swizzled_components = pattern[..size as usize]
1446 .iter()
1447 .map(|&sc| {
1448 let sc = sc as usize;
1449 if let Some(elt) = flattened.get(sc) {
1450 Ok(*elt)
1451 } else {
1452 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1453 }
1454 })
1455 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1456 let expr = Expression::Compose {
1457 ty: dst_ty,
1458 components: swizzled_components,
1459 };
1460 self.register_evaluated_expr(expr, span)
1461 }
1462 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1463 }
1464 }
1465
1466 fn math(
1467 &mut self,
1468 arg: Handle<Expression>,
1469 arg1: Option<Handle<Expression>>,
1470 arg2: Option<Handle<Expression>>,
1471 arg3: Option<Handle<Expression>>,
1472 fun: crate::MathFunction,
1473 span: Span,
1474 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1475 let expected = fun.argument_count();
1476 let given = Some(arg)
1477 .into_iter()
1478 .chain(arg1)
1479 .chain(arg2)
1480 .chain(arg3)
1481 .count();
1482 if expected != given {
1483 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1484 fun, expected, given,
1485 ));
1486 }
1487
1488 match fun {
1490 crate::MathFunction::Abs => {
1492 component_wise_scalar(self, span, [arg], |args| match args {
1493 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1494 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1495 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1496 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1497 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1498 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1500 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1501 })
1502 }
1503 crate::MathFunction::Min => {
1504 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1505 Ok([e1.min(e2)])
1506 })
1507 }
1508 crate::MathFunction::Max => {
1509 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1510 Ok([e1.max(e2)])
1511 })
1512 }
1513 crate::MathFunction::Clamp => {
1514 component_wise_scalar!(
1515 self,
1516 span,
1517 [arg, arg1.unwrap(), arg2.unwrap()],
1518 |e, low, high| {
1519 if low > high {
1520 Err(ConstantEvaluatorError::InvalidClamp)
1521 } else {
1522 Ok([e.clamp(low, high)])
1523 }
1524 }
1525 )
1526 }
1527 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1528 Float::F16([e]) => Ok(Float::F16(
1529 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1530 )),
1531 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1532 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1533 }),
1534
1535 crate::MathFunction::Cos => {
1537 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1538 }
1539 crate::MathFunction::Cosh => {
1540 component_wise_float!(self, span, [arg], |e| {
1541 let result = e.cosh();
1542 if result.is_finite() {
1543 Ok([result])
1544 } else {
1545 Err(ConstantEvaluatorError::Overflow("cosh".into()))
1546 }
1547 })
1548 }
1549 crate::MathFunction::Sin => {
1550 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1551 }
1552 crate::MathFunction::Sinh => {
1553 component_wise_float!(self, span, [arg], |e| {
1554 let result = e.sinh();
1555 if result.is_finite() {
1556 Ok([result])
1557 } else {
1558 Err(ConstantEvaluatorError::Overflow("sinh".into()))
1559 }
1560 })
1561 }
1562 crate::MathFunction::Tan => {
1563 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1564 }
1565 crate::MathFunction::Tanh => {
1566 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1567 }
1568 crate::MathFunction::Acos => {
1569 component_wise_float!(self, span, [arg], |e| {
1570 if e.abs() <= One::one() {
1571 Ok([e.acos()])
1572 } else {
1573 Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1574 }
1575 })
1576 }
1577 crate::MathFunction::Asin => {
1578 component_wise_float!(self, span, [arg], |e| {
1579 if e.abs() <= One::one() {
1580 Ok([e.asin()])
1581 } else {
1582 Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1583 }
1584 })
1585 }
1586 crate::MathFunction::Atan => {
1587 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1588 }
1589 crate::MathFunction::Atan2 => {
1590 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1591 Ok([y.atan2(x)])
1592 })
1593 }
1594 crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1595 Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1596 Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1597 Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1598 }),
1599 crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
1600 Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
1601 Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
1602 Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
1603 _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
1604 }),
1605 crate::MathFunction::Atanh => {
1606 component_wise_float!(self, span, [arg], |e| {
1607 if e.abs() < One::one() {
1608 Ok([e.atanh()])
1609 } else {
1610 Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1611 }
1612 })
1613 }
1614 crate::MathFunction::Radians => {
1615 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1616 }
1617 crate::MathFunction::Degrees => {
1618 component_wise_float!(self, span, [arg], |e| {
1619 let result = e.to_degrees();
1620 if result.is_finite() {
1621 Ok([result])
1622 } else {
1623 Err(ConstantEvaluatorError::Overflow("degrees".into()))
1624 }
1625 })
1626 }
1627
1628 crate::MathFunction::Ceil => {
1630 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1631 }
1632 crate::MathFunction::Floor => {
1633 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1634 }
1635 crate::MathFunction::Round => {
1636 component_wise_float(self, span, [arg], |e| match e {
1637 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1638 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1639 Float::F16([e]) => {
1640 fn round_ties_even(x: f64) -> f64 {
1648 let i = x as i64;
1649 let f = (x - i as f64).abs();
1650 if f == 0.5 {
1651 if i & 1 == 1 {
1652 (x.abs() + 0.5).copysign(x)
1654 } else {
1655 (x.abs() - 0.5).copysign(x)
1656 }
1657 } else {
1658 x.round()
1659 }
1660 }
1661
1662 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1663 }
1664 })
1665 }
1666 crate::MathFunction::Fract => {
1667 component_wise_float!(self, span, [arg], |e| {
1668 Ok([e - e.floor()])
1671 })
1672 }
1673 crate::MathFunction::Trunc => {
1674 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1675 }
1676
1677 crate::MathFunction::Exp => {
1679 component_wise_float!(self, span, [arg], |e| {
1680 let result = e.exp();
1681 if result.is_finite() {
1682 Ok([result])
1683 } else {
1684 Err(ConstantEvaluatorError::Overflow("exp".into()))
1685 }
1686 })
1687 }
1688 crate::MathFunction::Exp2 => {
1689 component_wise_float!(self, span, [arg], |e| {
1690 let result = e.exp2();
1691 if result.is_finite() {
1692 Ok([result])
1693 } else {
1694 Err(ConstantEvaluatorError::Overflow("exp2".into()))
1695 }
1696 })
1697 }
1698 crate::MathFunction::Log => {
1699 component_wise_float!(self, span, [arg], |e| {
1700 if e > Zero::zero() {
1701 Ok([e.ln()])
1702 } else {
1703 Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1704 }
1705 })
1706 }
1707 crate::MathFunction::Log2 => {
1708 component_wise_float!(self, span, [arg], |e| {
1709 if e > Zero::zero() {
1710 Ok([e.log2()])
1711 } else {
1712 Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1713 }
1714 })
1715 }
1716 crate::MathFunction::Pow => {
1717 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1718 if e1 < Zero::zero()
1721 || e1.is_one() && e2.is_infinite()
1722 || e1.is_infinite() && e2.is_zero()
1723 || e1.is_zero() && e2.is_zero()
1724 {
1725 Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
1726 } else {
1727 let result = e1.powf(e2);
1728 if result.is_finite() {
1729 Ok([result])
1730 } else {
1731 Err(ConstantEvaluatorError::Overflow("pow".into()))
1732 }
1733 }
1734 })
1735 }
1736
1737 crate::MathFunction::Sign => {
1739 component_wise_signed!(self, span, [arg], |e| {
1740 Ok([if e.is_zero() {
1741 Zero::zero()
1742 } else {
1743 e.signum()
1744 }])
1745 })
1746 }
1747 crate::MathFunction::Fma => {
1748 component_wise_float!(
1749 self,
1750 span,
1751 [arg, arg1.unwrap(), arg2.unwrap()],
1752 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1753 )
1754 }
1755 crate::MathFunction::Step => {
1756 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1757 Float::Abstract([edge, x]) => {
1758 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1759 }
1760 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1761 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1762 f16::one()
1763 } else {
1764 f16::zero()
1765 }])),
1766 })
1767 }
1768 crate::MathFunction::Sqrt => {
1769 component_wise_float!(self, span, [arg], |e| {
1770 if e >= Zero::zero() {
1771 Ok([e.sqrt()])
1772 } else {
1773 Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1774 }
1775 })
1776 }
1777 crate::MathFunction::InverseSqrt => {
1778 component_wise_float(self, span, [arg], |e| match e {
1779 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1780 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1781 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1782 })
1783 }
1784
1785 crate::MathFunction::CountTrailingZeros => {
1787 component_wise_concrete_int!(self, span, [arg], |e| {
1788 #[allow(clippy::useless_conversion)]
1789 Ok([e
1790 .trailing_zeros()
1791 .try_into()
1792 .expect("bit count overflowed 32 bits, somehow!?")])
1793 })
1794 }
1795 crate::MathFunction::CountLeadingZeros => {
1796 component_wise_concrete_int!(self, span, [arg], |e| {
1797 #[allow(clippy::useless_conversion)]
1798 Ok([e
1799 .leading_zeros()
1800 .try_into()
1801 .expect("bit count overflowed 32 bits, somehow!?")])
1802 })
1803 }
1804 crate::MathFunction::CountOneBits => {
1805 component_wise_concrete_int!(self, span, [arg], |e| {
1806 #[allow(clippy::useless_conversion)]
1807 Ok([e
1808 .count_ones()
1809 .try_into()
1810 .expect("bit count overflowed 32 bits, somehow!?")])
1811 })
1812 }
1813 crate::MathFunction::ReverseBits => {
1814 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1815 }
1816 crate::MathFunction::FirstTrailingBit => {
1817 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1818 }
1819 crate::MathFunction::FirstLeadingBit => {
1820 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1821 }
1822
1823 crate::MathFunction::Dot4I8Packed => {
1825 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1826 }
1827 crate::MathFunction::Dot4U8Packed => {
1828 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1829 }
1830 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1831 crate::MathFunction::Dot => {
1832 let e1 = self.extract_vec(arg, false)?;
1834 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1835 if e1.len() != e2.len() {
1836 return Err(ConstantEvaluatorError::InvalidMathArg);
1837 }
1838
1839 fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1840 where
1841 P: num_traits::Float,
1842 {
1843 let result = a
1844 .iter()
1845 .zip(b.iter())
1846 .map(|(&aa, &bb)| aa * bb)
1847 .fold(P::zero(), |acc, x| acc + x);
1848 if result.is_finite() {
1849 Ok(result)
1850 } else {
1851 Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1852 }
1853 }
1854
1855 fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1856 where
1857 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1858 {
1859 a.iter()
1860 .zip(b.iter())
1861 .map(|(&aa, bb)| aa.checked_mul(bb))
1862 .try_fold(P::zero(), |acc, x| {
1863 if let Some(x) = x {
1864 acc.checked_add(&x)
1865 } else {
1866 None
1867 }
1868 })
1869 .ok_or(ConstantEvaluatorError::Overflow(
1870 "in dot built-in".to_string(),
1871 ))
1872 }
1873
1874 fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1875 where
1876 P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1877 {
1878 a.iter()
1879 .zip(b.iter())
1880 .map(|(&aa, bb)| aa.wrapping_mul(bb))
1881 .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1882 }
1883
1884 let result = match_literal_vector!(match (e1, e2) => Literal {
1885 Float => |e1, e2| { float_dot_checked(e1, e2)? },
1886 AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1887 I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1888 U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1889 })?;
1890 self.register_evaluated_expr(Expression::Literal(result), span)
1891 }
1892 crate::MathFunction::Length => {
1893 let e1 = self.extract_vec(arg, true)?;
1895
1896 let result = match_literal_vector!(match e1 => Literal {
1897 Float => |e1| {
1898 float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
1899 },
1900 })?;
1901 self.register_evaluated_expr(Expression::Literal(result), span)
1902 }
1903 crate::MathFunction::Distance => {
1904 let e1 = self.extract_vec(arg, true)?;
1906 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1907 if e1.len() != e2.len() {
1908 return Err(ConstantEvaluatorError::InvalidMathArg);
1909 }
1910
1911 fn float_distance<F>(a: &[F], b: &[F]) -> F
1912 where
1913 F: core::ops::Mul<F>,
1914 F: num_traits::Float + iter::Sum + core::ops::Sub,
1915 {
1916 if a.len() == 1 {
1917 (a[0] - b[0]).abs()
1919 } else {
1920 a.iter()
1921 .zip(b.iter())
1922 .map(|(&aa, &bb)| aa - bb)
1923 .map(|ei| ei * ei)
1924 .sum::<F>()
1925 .sqrt()
1926 }
1927 }
1928 let result = match_literal_vector!(match (e1, e2) => Literal {
1929 Float => |e1, e2| { float_distance(e1, e2) },
1930 })?;
1931 self.register_evaluated_expr(Expression::Literal(result), span)
1932 }
1933 crate::MathFunction::Normalize => {
1934 let e1 = self.extract_vec(arg, true)?;
1936
1937 fn float_normalize<F>(
1938 e: &[F],
1939 ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
1940 where
1941 F: core::ops::Mul<F>,
1942 F: num_traits::Float + iter::Sum,
1943 {
1944 let len = match float_length(e) {
1945 Some(len) if !len.is_zero() => Ok(len),
1946 Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
1947 "normalize".into(),
1948 )),
1949 None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
1950 }?;
1951
1952 let mut out = ArrayVec::new();
1953 for &ei in e {
1954 out.push(ei / len);
1955 }
1956 Ok(out)
1957 }
1958
1959 let result = match_literal_vector!(match e1 => LiteralVector {
1960 Float => |e1| { float_normalize(e1)? },
1961 })?;
1962 result.register_as_evaluated_expr(self, span)
1963 }
1964
1965 crate::MathFunction::Modf
1967 | crate::MathFunction::Frexp
1968 | crate::MathFunction::Ldexp
1969 | crate::MathFunction::Outer
1970 | crate::MathFunction::FaceForward
1971 | crate::MathFunction::Reflect
1972 | crate::MathFunction::Refract
1973 | crate::MathFunction::Mix
1974 | crate::MathFunction::SmoothStep
1975 | crate::MathFunction::Inverse
1976 | crate::MathFunction::Transpose
1977 | crate::MathFunction::Determinant
1978 | crate::MathFunction::QuantizeToF16
1979 | crate::MathFunction::ExtractBits
1980 | crate::MathFunction::InsertBits
1981 | crate::MathFunction::Pack4x8snorm
1982 | crate::MathFunction::Pack4x8unorm
1983 | crate::MathFunction::Pack2x16snorm
1984 | crate::MathFunction::Pack2x16unorm
1985 | crate::MathFunction::Pack2x16float
1986 | crate::MathFunction::Pack4xI8
1987 | crate::MathFunction::Pack4xU8
1988 | crate::MathFunction::Pack4xI8Clamp
1989 | crate::MathFunction::Pack4xU8Clamp
1990 | crate::MathFunction::Unpack4x8snorm
1991 | crate::MathFunction::Unpack4x8unorm
1992 | crate::MathFunction::Unpack2x16snorm
1993 | crate::MathFunction::Unpack2x16unorm
1994 | crate::MathFunction::Unpack2x16float
1995 | crate::MathFunction::Unpack4xI8
1996 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1997 format!("{fun:?} built-in function"),
1998 )),
1999 }
2000 }
2001
2002 fn packed_dot_product(
2004 &mut self,
2005 a: Handle<Expression>,
2006 b: Handle<Expression>,
2007 span: Span,
2008 signed: bool,
2009 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2010 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
2011 return Err(ConstantEvaluatorError::InvalidMathArg);
2012 };
2013 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2014 return Err(ConstantEvaluatorError::InvalidMathArg);
2015 };
2016
2017 let result = if signed {
2018 Literal::I32(
2019 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2020 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2021 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2022 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2023 )
2024 } else {
2025 Literal::U32(
2026 (a & 0xFF) * (b & 0xFF)
2027 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2028 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2029 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2030 )
2031 };
2032
2033 self.register_evaluated_expr(Expression::Literal(result), span)
2034 }
2035
2036 fn cross_product(
2038 &mut self,
2039 a: Handle<Expression>,
2040 b: Handle<Expression>,
2041 span: Span,
2042 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2043 use Literal as Li;
2044
2045 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2046 let (b, _) = self.extract_vec_with_size::<3>(b)?;
2047
2048 let product = match (a, b) {
2049 (
2050 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2051 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2052 ) => {
2053 let p = cross_product(
2058 [a0 as f64, a1 as f64, a2 as f64],
2059 [b0 as f64, b1 as f64, b2 as f64],
2060 );
2061 [
2062 Li::AbstractFloat(p[0]),
2063 Li::AbstractFloat(p[1]),
2064 Li::AbstractFloat(p[2]),
2065 ]
2066 }
2067 (
2068 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2069 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2070 ) => {
2071 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2072 [
2073 Li::AbstractFloat(p[0]),
2074 Li::AbstractFloat(p[1]),
2075 Li::AbstractFloat(p[2]),
2076 ]
2077 }
2078 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2079 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2080 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2081 }
2082 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2083 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2084 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2085 }
2086 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2087 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2088 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2089 }
2090 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2091 };
2092
2093 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2094 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2095 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2096
2097 self.register_evaluated_expr(
2098 Expression::Compose {
2099 ty,
2100 components: vec![p0, p1, p2],
2101 },
2102 span,
2103 )
2104 }
2105
2106 fn extract_vec_with_size<const N: usize>(
2114 &mut self,
2115 expr: Handle<Expression>,
2116 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2117 let span = self.expressions.get_span(expr);
2118 let expr = self.eval_zero_value_and_splat(expr, span)?;
2119 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2120 return Err(ConstantEvaluatorError::InvalidMathArg);
2121 };
2122
2123 let mut value = [Literal::Bool(false); N];
2124 for (component, elt) in
2125 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2126 .zip(value.iter_mut())
2127 {
2128 let Expression::Literal(literal) = self.expressions[component] else {
2129 return Err(ConstantEvaluatorError::InvalidMathArg);
2130 };
2131 *elt = literal;
2132 }
2133
2134 Ok((value, ty))
2135 }
2136
2137 fn extract_vec(
2145 &mut self,
2146 expr: Handle<Expression>,
2147 allow_single: bool,
2148 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2149 let span = self.expressions.get_span(expr);
2150 let expr = self.eval_zero_value_and_splat(expr, span)?;
2151
2152 match self.expressions[expr] {
2153 Expression::Literal(literal) if allow_single => {
2154 Ok(LiteralVector::from_literal(literal))
2155 }
2156 Expression::Compose { ty, ref components } => {
2157 let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2158 for expr in
2159 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2160 {
2161 match self.expressions[expr] {
2162 Expression::Literal(l) => components_out.push(l),
2163 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2164 }
2165 }
2166 LiteralVector::from_literal_vec(components_out)
2167 }
2168 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2169 }
2170 }
2171
2172 fn array_length(
2173 &mut self,
2174 array: Handle<Expression>,
2175 span: Span,
2176 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2177 match self.expressions[array] {
2178 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2179 match self.types[ty].inner {
2180 TypeInner::Array { size, .. } => match size {
2181 ArraySize::Constant(len) => {
2182 let expr = Expression::Literal(Literal::U32(len.get()));
2183 self.register_evaluated_expr(expr, span)
2184 }
2185 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2186 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2187 },
2188 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2189 }
2190 }
2191 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2192 }
2193 }
2194
2195 fn access(
2196 &mut self,
2197 base: Handle<Expression>,
2198 index: usize,
2199 span: Span,
2200 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2201 match self.expressions[base] {
2202 Expression::ZeroValue(ty) => {
2203 let ty_inner = &self.types[ty].inner;
2204 let components = ty_inner
2205 .components()
2206 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2207
2208 if index >= components as usize {
2209 Err(ConstantEvaluatorError::InvalidAccessBase)
2210 } else {
2211 let ty_res = ty_inner
2212 .component_type(index)
2213 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2214 let ty = match ty_res {
2215 crate::proc::TypeResolution::Handle(ty) => ty,
2216 crate::proc::TypeResolution::Value(inner) => {
2217 self.types.insert(Type { name: None, inner }, span)
2218 }
2219 };
2220 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2221 }
2222 }
2223 Expression::Splat { size, value } => {
2224 if index >= size as usize {
2225 Err(ConstantEvaluatorError::InvalidAccessBase)
2226 } else {
2227 Ok(value)
2228 }
2229 }
2230 Expression::Compose { ty, ref components } => {
2231 let _ = self.types[ty]
2232 .inner
2233 .components()
2234 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2235
2236 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2237 .nth(index)
2238 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2239 }
2240 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2241 }
2242 }
2243
2244 fn eval_zero_value_and_splat(
2251 &mut self,
2252 mut expr: Handle<Expression>,
2253 span: Span,
2254 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2255 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2258 let components = components
2259 .clone()
2260 .iter()
2261 .map(|component| self.eval_zero_value_and_splat(*component, span))
2262 .collect::<Result<_, _>>()?;
2263 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2264 }
2265
2266 if let Expression::Splat { size, value } = self.expressions[expr] {
2270 expr = self.splat(value, size, span)?;
2271 }
2272 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2273 expr = self.eval_zero_value_impl(ty, span)?;
2274 }
2275 Ok(expr)
2276 }
2277
2278 fn eval_zero_value(
2284 &mut self,
2285 expr: Handle<Expression>,
2286 span: Span,
2287 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2288 match self.expressions[expr] {
2289 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2290 _ => Ok(expr),
2291 }
2292 }
2293
2294 fn eval_zero_value_impl(
2300 &mut self,
2301 ty: Handle<Type>,
2302 span: Span,
2303 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2304 match self.types[ty].inner {
2305 TypeInner::Scalar(scalar) => {
2306 let expr = Expression::Literal(
2307 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2308 );
2309 self.register_evaluated_expr(expr, span)
2310 }
2311 TypeInner::Vector { size, scalar } => {
2312 let scalar_ty = self.types.insert(
2313 Type {
2314 name: None,
2315 inner: TypeInner::Scalar(scalar),
2316 },
2317 span,
2318 );
2319 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2320 let expr = Expression::Compose {
2321 ty,
2322 components: vec![el; size as usize],
2323 };
2324 self.register_evaluated_expr(expr, span)
2325 }
2326 TypeInner::Matrix {
2327 columns,
2328 rows,
2329 scalar,
2330 } => {
2331 let vec_ty = self.types.insert(
2332 Type {
2333 name: None,
2334 inner: TypeInner::Vector { size: rows, scalar },
2335 },
2336 span,
2337 );
2338 let el = self.eval_zero_value_impl(vec_ty, span)?;
2339 let expr = Expression::Compose {
2340 ty,
2341 components: vec![el; columns as usize],
2342 };
2343 self.register_evaluated_expr(expr, span)
2344 }
2345 TypeInner::Array {
2346 base,
2347 size: ArraySize::Constant(size),
2348 ..
2349 } => {
2350 let el = self.eval_zero_value_impl(base, span)?;
2351 let expr = Expression::Compose {
2352 ty,
2353 components: vec![el; size.get() as usize],
2354 };
2355 self.register_evaluated_expr(expr, span)
2356 }
2357 TypeInner::Struct { ref members, .. } => {
2358 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2359 let mut components = Vec::with_capacity(members.len());
2360 for ty in types {
2361 components.push(self.eval_zero_value_impl(ty, span)?);
2362 }
2363 let expr = Expression::Compose { ty, components };
2364 self.register_evaluated_expr(expr, span)
2365 }
2366 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2367 }
2368 }
2369
2370 pub fn cast(
2374 &mut self,
2375 expr: Handle<Expression>,
2376 target: crate::Scalar,
2377 span: Span,
2378 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2379 use crate::Scalar as Sc;
2380
2381 let expr = self.eval_zero_value(expr, span)?;
2382
2383 let make_error = || -> Result<_, ConstantEvaluatorError> {
2384 let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2385
2386 #[cfg(feature = "wgsl-in")]
2387 let to = target.to_wgsl_for_diagnostics();
2388
2389 #[cfg(not(feature = "wgsl-in"))]
2390 let to = format!("{target:?}");
2391
2392 Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2393 };
2394
2395 use crate::proc::type_methods::IntFloatLimits;
2396
2397 let expr = match self.expressions[expr] {
2398 Expression::Literal(literal) => {
2399 let literal = match target {
2400 Sc::I16 => Literal::I16(match literal {
2401 Literal::I16(v) => v,
2402 Literal::U16(v) => v as i16,
2403 Literal::I32(v) => v as i16,
2404 Literal::U32(v) => v as i16,
2405 Literal::F32(v) => v.clamp(i16::MIN as f32, i16::MAX as f32) as i16,
2406 Literal::F16(v) => {
2407 f16::to_f32(v).clamp(i16::MIN as f32, i16::MAX as f32) as i16
2408 }
2409 Literal::Bool(v) => v as i16,
2410 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2411 return make_error();
2412 }
2413 Literal::AbstractInt(v) => v as i16,
2414 Literal::AbstractFloat(v) => v as i16,
2415 }),
2416 Sc::U16 => Literal::U16(match literal {
2417 Literal::I16(v) => v as u16,
2418 Literal::U16(v) => v,
2419 Literal::I32(v) => v as u16,
2420 Literal::U32(v) => v as u16,
2421 Literal::F32(v) => v.clamp(u16::MIN as f32, u16::MAX as f32) as u16,
2422 Literal::F16(v) => f16::to_u16(&v.max(f16::ZERO)).unwrap(),
2423 Literal::Bool(v) => v as u16,
2424 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2425 return make_error();
2426 }
2427 Literal::AbstractInt(v) => v as u16,
2428 Literal::AbstractFloat(v) => v as u16,
2429 }),
2430 Sc::I32 => Literal::I32(match literal {
2431 Literal::I16(v) => v as i32,
2432 Literal::U16(v) => v as i32,
2433 Literal::I32(v) => v,
2434 Literal::U32(v) => v as i32,
2435 Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2436 Literal::F16(v) => f16::to_i32(&v).unwrap(), Literal::Bool(v) => v as i32,
2438 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2439 return make_error();
2440 }
2441 Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2442 Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2443 }),
2444 Sc::U32 => Literal::U32(match literal {
2445 Literal::I16(v) => v as u32,
2446 Literal::U16(v) => v as u32,
2447 Literal::I32(v) => v as u32,
2448 Literal::U32(v) => v,
2449 Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2450 Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2452 Literal::Bool(v) => v as u32,
2453 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2454 return make_error();
2455 }
2456 Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2457 Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2458 }),
2459 Sc::I64 => Literal::I64(match literal {
2460 Literal::I16(v) => v as i64,
2461 Literal::U16(v) => v as i64,
2462 Literal::I32(v) => v as i64,
2463 Literal::U32(v) => v as i64,
2464 Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2465 Literal::Bool(v) => v as i64,
2466 Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2467 Literal::I64(v) => v,
2468 Literal::U64(v) => v as i64,
2469 Literal::F16(v) => f16::to_i64(&v).unwrap(), Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2471 Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2472 }),
2473 Sc::U64 => Literal::U64(match literal {
2474 Literal::I16(v) => v as u64,
2475 Literal::U16(v) => v as u64,
2476 Literal::I32(v) => v as u64,
2477 Literal::U32(v) => v as u64,
2478 Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2479 Literal::Bool(v) => v as u64,
2480 Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2481 Literal::I64(v) => v as u64,
2482 Literal::U64(v) => v,
2483 Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2485 Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2486 Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2487 }),
2488 Sc::F16 => Literal::F16(match literal {
2489 Literal::F16(v) => v,
2490 Literal::F32(v) => f16::from_f32(v),
2491 Literal::F64(v) => f16::from_f64(v),
2492 Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2493 Literal::I16(v) => f16::from_f32(v as f32),
2494 Literal::U16(v) => f16::from_f32(v as f32),
2495 Literal::I64(v) => f16::from_i64(v).unwrap(),
2496 Literal::U64(v) => f16::from_u64(v).unwrap(),
2497 Literal::I32(v) => f16::from_i32(v).unwrap(),
2498 Literal::U32(v) => f16::from_u32(v).unwrap(),
2499 Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2500 Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2501 }),
2502 Sc::F32 => Literal::F32(match literal {
2503 Literal::I16(v) => v as f32,
2504 Literal::U16(v) => v as f32,
2505 Literal::I32(v) => v as f32,
2506 Literal::U32(v) => v as f32,
2507 Literal::F32(v) => v,
2508 Literal::Bool(v) => v as u32 as f32,
2509 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2510 return make_error();
2511 }
2512 Literal::F16(v) => f16::to_f32(v),
2513 Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2514 Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2515 }),
2516 Sc::F64 => Literal::F64(match literal {
2517 Literal::I16(v) => v as f64,
2518 Literal::U16(v) => v as f64,
2519 Literal::I32(v) => v as f64,
2520 Literal::U32(v) => v as f64,
2521 Literal::F16(v) => f16::to_f64(v),
2522 Literal::F32(v) => v as f64,
2523 Literal::F64(v) => v,
2524 Literal::Bool(v) => v as u32 as f64,
2525 Literal::I64(_) | Literal::U64(_) => return make_error(),
2526 Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2527 Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2528 }),
2529 Sc::BOOL => Literal::Bool(match literal {
2530 Literal::I16(v) => v != 0,
2531 Literal::U16(v) => v != 0,
2532 Literal::I32(v) => v != 0,
2533 Literal::U32(v) => v != 0,
2534 Literal::F32(v) => v != 0.0,
2535 Literal::F16(v) => v != f16::zero(),
2536 Literal::Bool(v) => v,
2537 Literal::AbstractInt(v) => v != 0,
2538 Literal::AbstractFloat(v) => v != 0.0,
2539 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2540 return make_error();
2541 }
2542 }),
2543 Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2544 Literal::AbstractInt(v) => {
2545 v as f64
2550 }
2551 Literal::AbstractFloat(v) => v,
2552 _ => return make_error(),
2553 }),
2554 Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2555 Literal::AbstractInt(v) => v,
2556 _ => return make_error(),
2557 }),
2558 _ => {
2559 log::debug!("Constant evaluator refused to convert value to {target:?}");
2560 return make_error();
2561 }
2562 };
2563 Expression::Literal(literal)
2564 }
2565 Expression::Compose {
2566 ty,
2567 components: ref src_components,
2568 } => {
2569 let ty_inner = match self.types[ty].inner {
2570 TypeInner::Vector { size, .. } => TypeInner::Vector {
2571 size,
2572 scalar: target,
2573 },
2574 TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2575 columns,
2576 rows,
2577 scalar: target,
2578 },
2579 _ => return make_error(),
2580 };
2581
2582 let mut components = src_components.clone();
2583 for component in &mut components {
2584 *component = self.cast(*component, target, span)?;
2585 }
2586
2587 let ty = self.types.insert(
2588 Type {
2589 name: None,
2590 inner: ty_inner,
2591 },
2592 span,
2593 );
2594
2595 Expression::Compose { ty, components }
2596 }
2597 Expression::Splat { size, value } => {
2598 let value_span = self.expressions.get_span(value);
2599 let cast_value = self.cast(value, target, value_span)?;
2600 Expression::Splat {
2601 size,
2602 value: cast_value,
2603 }
2604 }
2605 _ => return make_error(),
2606 };
2607
2608 self.register_evaluated_expr(expr, span)
2609 }
2610
2611 pub fn cast_array(
2624 &mut self,
2625 expr: Handle<Expression>,
2626 target: crate::Scalar,
2627 span: Span,
2628 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2629 let expr = self.check_and_get(expr)?;
2630
2631 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2632 return self.cast(expr, target, span);
2633 };
2634
2635 let TypeInner::Array {
2636 base: _,
2637 size,
2638 stride: _,
2639 } = self.types[ty].inner
2640 else {
2641 return self.cast(expr, target, span);
2642 };
2643
2644 let mut components = components.clone();
2645 for component in &mut components {
2646 *component = self.cast_array(*component, target, span)?;
2647 }
2648
2649 let first = components.first().unwrap();
2650 let new_base = match self.resolve_type(*first)? {
2651 crate::proc::TypeResolution::Handle(ty) => ty,
2652 crate::proc::TypeResolution::Value(inner) => {
2653 self.types.insert(Type { name: None, inner }, span)
2654 }
2655 };
2656 let mut layouter = core::mem::take(self.layouter);
2657 layouter.update(self.to_ctx()).map_err(|err| {
2658 let crate::proc::LayoutErrorInner::TooLarge = err.inner else {
2661 unreachable!("unexpected layout error: {err:?}");
2662 };
2663 ConstantEvaluatorError::TypeTooLarge(err.ty)
2664 })?;
2665 *self.layouter = layouter;
2666
2667 let new_base_stride = self.layouter[new_base].to_stride();
2668 let new_array_ty = self.types.insert(
2669 Type {
2670 name: None,
2671 inner: TypeInner::Array {
2672 base: new_base,
2673 size,
2674 stride: new_base_stride,
2675 },
2676 },
2677 span,
2678 );
2679
2680 let compose = Expression::Compose {
2681 ty: new_array_ty,
2682 components,
2683 };
2684 self.register_evaluated_expr(compose, span)
2685 }
2686
2687 fn unary_op(
2688 &mut self,
2689 op: UnaryOperator,
2690 expr: Handle<Expression>,
2691 span: Span,
2692 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2693 let expr = self.eval_zero_value_and_splat(expr, span)?;
2694
2695 let expr = match self.expressions[expr] {
2696 Expression::Literal(value) => Expression::Literal(match op {
2697 UnaryOperator::Negate => match value {
2698 Literal::I16(v) => Literal::I16(v.wrapping_neg()),
2699 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2700 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2701 Literal::F32(v) => Literal::F32(-v),
2702 Literal::F16(v) => Literal::F16(-v),
2703 Literal::F64(v) => Literal::F64(-v),
2704 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2705 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2706 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2707 },
2708 UnaryOperator::LogicalNot => match value {
2709 Literal::Bool(v) => Literal::Bool(!v),
2710 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2711 },
2712 UnaryOperator::BitwiseNot => match value {
2713 Literal::I16(v) => Literal::I16(!v),
2714 Literal::I32(v) => Literal::I32(!v),
2715 Literal::I64(v) => Literal::I64(!v),
2716 Literal::U16(v) => Literal::U16(!v),
2717 Literal::U32(v) => Literal::U32(!v),
2718 Literal::U64(v) => Literal::U64(!v),
2719 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2720 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2721 },
2722 }),
2723 Expression::Compose {
2724 ty,
2725 components: ref src_components,
2726 } => {
2727 match self.types[ty].inner {
2728 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2729 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2730 }
2731
2732 let mut components = src_components.clone();
2733 for component in &mut components {
2734 *component = self.unary_op(op, *component, span)?;
2735 }
2736
2737 Expression::Compose { ty, components }
2738 }
2739 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2740 };
2741
2742 self.register_evaluated_expr(expr, span)
2743 }
2744
    /// Evaluate the binary operation `left op right` on constant operands.
    ///
    /// Both operands first go through `eval_zero_value_and_splat`, so
    /// `ZeroValue` and `Splat` are expanded before dispatch:
    ///
    /// - literal `op` literal: evaluated directly by the per-type arms below;
    /// - `Compose` `op` literal (in either order): the literal is combined
    ///   with each component. This is only allowed for non-bool vectors with
    ///   `+ - * / %`, and for matrices with `*` (see
    ///   `is_allowed_compose_literal_op`);
    /// - `Compose` `op` `Compose`: both sides are flattened with
    ///   `crate::proc::flatten_compose` and handed to `binary_op_compose`.
    ///
    /// Any other combination yields `ConstantEvaluatorError::InvalidBinaryOpArgs`.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                // Except for shifts (whose arms below take a `u32` right-hand side),
                // both literals must be the same `Literal` variant.
                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
                {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }

                // Booleans have no ordering comparisons.
                if matches!(
                    (left_value, op),
                    (
                        Literal::Bool(_),
                        BinaryOperator::Less
                            | BinaryOperator::LessEqual
                            | BinaryOperator::Greater
                            | BinaryOperator::GreaterEqual
                    )
                ) {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }

                let literal = match op {
                    // Comparisons use `Literal`'s own equality/ordering impls.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    _ => match (left_value, right_value) {
                        // Concrete integers: add/sub/mul wrap on overflow. Division
                        // and remainder report a zero divisor, and signed `MIN / -1`
                        // is reported as `Overflow`.
                        (Literal::I16(a), Literal::I16(b)) => Literal::I16(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::I16(a), Literal::U32(b)) => Literal::I16(match op {
                            BinaryOperator::ShiftLeft => {
                                // `leading_zeros` of `a` (or of `!a` when negative) counts
                                // the bits that duplicate the sign. Shifting by at least
                                // that many would change the sign bit, so it is rejected.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U16(a), Literal::U16(b)) => Literal::U16(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U16(a), Literal::U32(b)) => Literal::U16(match op {
                            // Multiplying by `1 << b` detects set bits shifted out;
                            // `checked_shl` fails when `b` is at least the bit width.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u16.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // Same sign-preservation check as the `i16` case.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // Multiplying by `1 << b` detects set bits shifted out.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // NOTE(review): unlike the F16/AbstractFloat arms, no finiteness
                        // check happens here; confirm `register_evaluated_expr` rejects
                        // non-finite f32 results.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::AbstractInt(a), Literal::U32(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::ShiftLeft => {
                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                        return Err(ConstantEvaluatorError::Overflow(
                                            "<<".to_string(),
                                        ));
                                    }
                                    a.checked_shl(b).unwrap_or(0)
                                }
                                // NOTE(review): a shift of 64 or more yields 0 even for
                                // negative `a` (an arithmetic shift would give -1); confirm
                                // this is intended.
                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::F16(a), Literal::F16(b)) => {
                            let result = match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            };
                            // Infinite or NaN results are reported as overflow.
                            if !result.is_finite() {
                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                            }
                            Literal::F16(result)
                        }
                        // Abstract integers never wrap: every overflowing operation
                        // is an error.
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            let result = match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            };
                            // Infinite or NaN results are reported as overflow.
                            if !result.is_finite() {
                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                            }
                            Literal::AbstractFloat(result)
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            BinaryOperator::And => a & b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // Vector/matrix combined with a scalar on the right: apply per component.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            // Scalar on the left combined with a vector/matrix: apply per component.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                )
                .collect::<Vec<_>>();
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                )
                .collect::<Vec<_>>();

                self.binary_op_compose(
                    op,
                    &left_flattened,
                    &right_flattened,
                    left_ty,
                    right_ty,
                    span,
                )?
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        return self.register_evaluated_expr(expr, span);

        /// Whether `op` may combine a `Compose` of type `compose_ty` with a
        /// scalar literal: non-bool vectors accept `+ - * / %`, and matrices
        /// accept only `*`.
        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
            let is_numeric_vec = matches!(
                compose_ty, TypeInner::Vector { scalar, .. }
                if scalar.kind != ScalarKind::Bool
            );
            let is_allowed_vec_scalar_op = matches!(
                op,
                BinaryOperator::Add
                    | BinaryOperator::Subtract
                    | BinaryOperator::Multiply
                    | BinaryOperator::Divide
                    | BinaryOperator::Modulo
            );
            let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
        }
    }
3097
3098 fn binary_op_compose(
3099 &mut self,
3100 op: BinaryOperator,
3101 left_components: &[Handle<Expression>],
3102 right_components: &[Handle<Expression>],
3103 left_ty: Handle<Type>,
3104 right_ty: Handle<Type>,
3105 span: Span,
3106 ) -> Result<Expression, ConstantEvaluatorError> {
3107 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
3108 (
3110 &TypeInner::Vector {
3111 size: left_size, ..
3112 },
3113 &TypeInner::Vector {
3114 size: right_size, ..
3115 },
3116 ) if left_size == right_size => self.binary_op_vector(
3117 op,
3118 left_size,
3119 left_components,
3120 right_components,
3121 left_ty,
3122 span,
3123 ),
3124 (
3126 &TypeInner::Vector { size, .. },
3127 &TypeInner::Matrix {
3128 columns,
3129 rows,
3130 scalar,
3131 },
3132 ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3133 left_components,
3134 right_components,
3135 columns,
3136 scalar,
3137 span,
3138 ),
3139 (
3141 &TypeInner::Matrix {
3142 columns,
3143 rows,
3144 scalar,
3145 },
3146 &TypeInner::Vector { size, .. },
3147 ) if op == BinaryOperator::Multiply && size == columns => {
3148 self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3149 }
3150 (
3152 &TypeInner::Matrix {
3153 columns: left_columns,
3154 rows: left_rows,
3155 scalar,
3156 },
3157 &TypeInner::Matrix {
3158 columns: right_columns,
3159 rows: right_rows,
3160 ..
3161 },
3162 ) => match op {
3163 BinaryOperator::Add | BinaryOperator::Subtract
3164 if left_columns == right_columns && left_rows == right_rows =>
3165 {
3166 let components = left_components
3167 .iter()
3168 .zip(right_components)
3169 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3170 .collect::<Result<Vec<_>, _>>()?;
3171 Ok(Expression::Compose {
3172 ty: left_ty,
3173 components,
3174 })
3175 }
3176 BinaryOperator::Multiply if left_columns == right_rows => self
3177 .multiply_matrix_matrix(
3178 left_components,
3179 right_components,
3180 left_rows,
3181 right_columns,
3182 scalar,
3183 span,
3184 ),
3185 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3186 },
3187 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3188 }
3189 }
3190
3191 fn binary_op_vector(
3192 &mut self,
3193 op: BinaryOperator,
3194 size: crate::VectorSize,
3195 left_components: &[Handle<Expression>],
3196 right_components: &[Handle<Expression>],
3197 left_ty: Handle<Type>,
3198 span: Span,
3199 ) -> Result<Expression, ConstantEvaluatorError> {
3200 let ty = match op {
3201 BinaryOperator::Equal
3203 | BinaryOperator::NotEqual
3204 | BinaryOperator::Less
3205 | BinaryOperator::LessEqual
3206 | BinaryOperator::Greater
3207 | BinaryOperator::GreaterEqual => self.types.insert(
3208 Type {
3209 name: None,
3210 inner: TypeInner::Vector {
3211 size,
3212 scalar: crate::Scalar::BOOL,
3213 },
3214 },
3215 span,
3216 ),
3217
3218 BinaryOperator::Add
3221 | BinaryOperator::Subtract
3222 | BinaryOperator::Multiply
3223 | BinaryOperator::Divide
3224 | BinaryOperator::Modulo
3225 | BinaryOperator::And
3226 | BinaryOperator::ExclusiveOr
3227 | BinaryOperator::InclusiveOr
3228 | BinaryOperator::ShiftLeft
3229 | BinaryOperator::ShiftRight => left_ty,
3230
3231 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3232 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3234 }
3235 };
3236
3237 let components = left_components
3238 .iter()
3239 .zip(right_components)
3240 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3241 .collect::<Result<Vec<_>, _>>()?;
3242
3243 Ok(Expression::Compose { ty, components })
3244 }
3245
3246 fn multiply_vector_matrix(
3247 &mut self,
3248 vec_components: &[Handle<Expression>],
3249 mat_components: &[Handle<Expression>],
3250 mat_columns: crate::VectorSize,
3251 scalar: crate::Scalar,
3252 span: Span,
3253 ) -> Result<Expression, ConstantEvaluatorError> {
3254 let ty = self.types.insert(
3255 Type {
3256 name: None,
3257 inner: TypeInner::Vector {
3258 size: mat_columns,
3259 scalar,
3260 },
3261 },
3262 span,
3263 );
3264 let components = mat_components
3265 .iter()
3266 .map(|&column| {
3267 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3268 unreachable!()
3269 };
3270 self.dot_exprs(
3271 vec_components.iter().cloned(),
3272 components.clone().into_iter(),
3273 span,
3274 )
3275 })
3276 .collect::<Result<Vec<_>, _>>()?;
3277 Ok(Expression::Compose { ty, components })
3278 }
3279
3280 fn multiply_matrix_vector(
3281 &mut self,
3282 mat_components: &[Handle<Expression>],
3283 vec_components: &[Handle<Expression>],
3284 mat_rows: crate::VectorSize,
3285 scalar: crate::Scalar,
3286 span: Span,
3287 ) -> Result<Expression, ConstantEvaluatorError> {
3288 let ty = self.types.insert(
3289 Type {
3290 name: None,
3291 inner: TypeInner::Vector {
3292 size: mat_rows,
3293 scalar,
3294 },
3295 },
3296 span,
3297 );
3298
3299 let flatten = self.flatten_matrix(mat_components);
3300 let nr = mat_rows as usize;
3301 let components = (0..nr)
3302 .map(|r| {
3303 let row = flatten.iter().skip(r).step_by(nr).cloned();
3304 self.dot_exprs(row, vec_components.iter().cloned(), span)
3305 })
3306 .collect::<Result<Vec<_>, _>>()?;
3307 Ok(Expression::Compose { ty, components })
3308 }
3309
3310 fn multiply_matrix_matrix(
3311 &mut self,
3312 left_components: &[Handle<Expression>],
3313 right_components: &[Handle<Expression>],
3314 left_rows: crate::VectorSize,
3315 right_columns: crate::VectorSize,
3316 scalar: crate::Scalar,
3317 span: Span,
3318 ) -> Result<Expression, ConstantEvaluatorError> {
3319 let left_nc = left_components.len();
3320 let left_nr = left_rows as usize;
3321 let right_nc = right_columns as usize;
3322 let right_nr = left_nc;
3323
3324 let mut result = Vec::with_capacity(right_nc);
3325 let result_ty = self.types.insert(
3326 Type {
3327 name: None,
3328 inner: TypeInner::Matrix {
3329 columns: right_columns,
3330 rows: left_rows,
3331 scalar,
3332 },
3333 },
3334 span,
3335 );
3336 let result_column_ty = self.types.insert(
3337 Type {
3338 name: None,
3339 inner: TypeInner::Vector {
3340 size: left_rows,
3341 scalar,
3342 },
3343 },
3344 span,
3345 );
3346
3347 let left_flattened = self.flatten_matrix(left_components);
3348 let right_flattened = self.flatten_matrix(right_components);
3349 for c in 0..right_nc {
3350 let result_column = (0..left_nr)
3351 .map(|r| {
3352 let row = left_flattened.iter().skip(r).step_by(left_nr);
3353 let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3354 self.dot_exprs(row.cloned(), column.cloned(), span)
3355 })
3356 .collect::<Result<Vec<_>, _>>()?;
3357 let expr = Expression::Compose {
3358 ty: result_column_ty,
3359 components: result_column,
3360 };
3361 let handle = self.register_evaluated_expr(expr, span)?;
3362 result.push(handle);
3363 }
3364 Ok(Expression::Compose {
3365 ty: result_ty,
3366 components: result,
3367 })
3368 }
3369
3370 fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3371 let mut flattened = ArrayVec::<_, 16>::new();
3372 for &column in columns {
3373 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3374 unreachable!()
3375 };
3376 flattened.extend(components.iter().cloned());
3377 }
3378 flattened
3379 }
3380
3381 fn dot_exprs(
3382 &mut self,
3383 left: impl Iterator<Item = Handle<Expression>>,
3384 right: impl Iterator<Item = Handle<Expression>>,
3385 span: Span,
3386 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3387 let mut acc = None;
3388 for (l, r) in left.zip(right) {
3389 let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3390 match acc.as_mut() {
3391 Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3392 None => acc = Some(result),
3393 }
3394 }
3395 Ok(acc.unwrap())
3396 }
3397
3398 fn relational(
3399 &mut self,
3400 fun: RelationalFunction,
3401 arg: Handle<Expression>,
3402 span: Span,
3403 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3404 let arg = self.eval_zero_value_and_splat(arg, span)?;
3405 match fun {
3406 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3407 Expression::Literal(Literal::Bool(_)) => Ok(arg),
3408 Expression::Compose { ty, ref components }
3409 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3410 {
3411 let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3412 for component in
3413 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3414 {
3415 match self.expressions[component] {
3416 Expression::Literal(Literal::Bool(val)) => {
3417 bool_components.push(val);
3418 }
3419 _ => {
3420 return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3421 }
3422 }
3423 }
3424 let components = bool_components;
3425 let result = match fun {
3426 RelationalFunction::All => components.iter().all(|c| *c),
3427 RelationalFunction::Any => components.iter().any(|c| *c),
3428 _ => unreachable!(),
3429 };
3430 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3431 }
3432 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3433 },
3434 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3435 "{fun:?} built-in function"
3436 ))),
3437 }
3438 }
3439
3440 fn copy_from(
3448 &mut self,
3449 expr: Handle<Expression>,
3450 expressions: &Arena<Expression>,
3451 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3452 let span = expressions.get_span(expr);
3453 match expressions[expr] {
3454 ref expr @ (Expression::Literal(_)
3455 | Expression::Constant(_)
3456 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3457 Expression::Compose { ty, ref components } => {
3458 let mut components = components.clone();
3459 for component in &mut components {
3460 *component = self.copy_from(*component, expressions)?;
3461 }
3462 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3463 }
3464 Expression::Splat { size, value } => {
3465 let value = self.copy_from(value, expressions)?;
3466 self.register_evaluated_expr(Expression::Splat { size, value }, span)
3467 }
3468 _ => {
3469 log::debug!("copy_from: SubexpressionsAreNotConstant");
3470 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3471 }
3472 }
3473 }
3474
3475 fn vector_compose_flattened_size(
3477 &self,
3478 components: &[Handle<Expression>],
3479 ) -> Result<usize, ConstantEvaluatorError> {
3480 components
3481 .iter()
3482 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3483 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3484 TypeInner::Scalar(_) => 1,
3485 TypeInner::Vector { size, .. } => size as usize,
3489 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3490 };
3491 Ok(acc + size)
3492 })
3493 }
3494
3495 fn register_evaluated_expr(
3496 &mut self,
3497 expr: Expression,
3498 span: Span,
3499 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3500 if let Expression::Literal(literal) = expr {
3505 crate::valid::check_literal_value(literal)?;
3506 }
3507
3508 if let Expression::Compose { ty, ref components } = expr {
3512 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3513 let expected = size as usize;
3514 let actual = self.vector_compose_flattened_size(components)?;
3515 if expected != actual {
3516 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3517 expected,
3518 actual,
3519 });
3520 }
3521 }
3522 }
3523
3524 Ok(self.append_expr(expr, span, ExpressionKind::Const))
3525 }
3526
3527 fn append_expr(
3528 &mut self,
3529 expr: Expression,
3530 span: Span,
3531 expr_type: ExpressionKind,
3532 ) -> Handle<Expression> {
3533 let h = match self.behavior {
3534 Behavior::Wgsl(
3535 WgslRestrictions::Runtime(ref mut function_local_data)
3536 | WgslRestrictions::Const(Some(ref mut function_local_data)),
3537 )
3538 | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3539 let is_running = function_local_data.emitter.is_running();
3540 let needs_pre_emit = expr.needs_pre_emit();
3541 if is_running && needs_pre_emit {
3542 function_local_data
3543 .block
3544 .extend(function_local_data.emitter.finish(self.expressions));
3545 let h = self.expressions.append(expr, span);
3546 function_local_data.emitter.start(self.expressions);
3547 h
3548 } else {
3549 self.expressions.append(expr, span)
3550 }
3551 }
3552 _ => self.expressions.append(expr, span),
3553 };
3554 self.expression_kind_tracker.insert(h, expr_type);
3555 h
3556 }
3557
3558 fn resolve_type(
3563 &self,
3564 expr: Handle<Expression>,
3565 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3566 use crate::proc::TypeResolution as Tr;
3567 use crate::Expression as Ex;
3568 let resolution = match self.expressions[expr] {
3569 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3570 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3571 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3572 Ex::Splat { size, value } => {
3573 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3574 return Err(ConstantEvaluatorError::SplatScalarOnly);
3575 };
3576 Tr::Value(TypeInner::Vector { scalar, size })
3577 }
3578 _ => {
3579 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3580 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3581 }
3582 };
3583
3584 Ok(resolution)
3585 }
3586
3587 fn select(
3588 &mut self,
3589 reject: Handle<Expression>,
3590 accept: Handle<Expression>,
3591 condition: Handle<Expression>,
3592 span: Span,
3593 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3594 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3595
3596 let reject = arg(reject)?;
3597 let accept = arg(accept)?;
3598 let condition = arg(condition)?;
3599
3600 let select_single_component =
3601 |this: &mut Self, reject_scalar, reject, accept, condition| {
3602 let accept = this.cast(accept, reject_scalar, span)?;
3603 if condition {
3604 Ok(accept)
3605 } else {
3606 Ok(reject)
3607 }
3608 };
3609
3610 match (&self.expressions[reject], &self.expressions[accept]) {
3611 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3612 let reject_scalar = reject_lit.scalar();
3613 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3614 else {
3615 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3616 };
3617 select_single_component(self, reject_scalar, reject, accept, condition)
3618 }
3619 (
3620 &Expression::Compose {
3621 ty: reject_ty,
3622 components: ref reject_components,
3623 },
3624 &Expression::Compose {
3625 ty: accept_ty,
3626 components: ref accept_components,
3627 },
3628 ) => {
3629 let ty_deets = |ty| {
3630 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3631 (size.unwrap(), scalar)
3632 };
3633
3634 let expected_vec_size = {
3635 let [(reject_vec_size, _), (accept_vec_size, _)] =
3636 [reject_ty, accept_ty].map(ty_deets);
3637
3638 if reject_vec_size != accept_vec_size {
3639 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3640 reject: reject_vec_size,
3641 accept: accept_vec_size,
3642 });
3643 }
3644 reject_vec_size
3645 };
3646
3647 let condition_components = match self.expressions[condition] {
3648 Expression::Literal(Literal::Bool(condition)) => {
3649 vec![condition; (expected_vec_size as u8).into()]
3650 }
3651 Expression::Compose {
3652 ty: condition_ty,
3653 components: ref condition_components,
3654 } => {
3655 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3656 if condition_scalar.kind != ScalarKind::Bool {
3657 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3658 }
3659 if condition_vec_size != expected_vec_size {
3660 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3661 }
3662 condition_components
3663 .iter()
3664 .copied()
3665 .map(|component| match &self.expressions[component] {
3666 &Expression::Literal(Literal::Bool(condition)) => condition,
3667 _ => unreachable!(),
3668 })
3669 .collect()
3670 }
3671
3672 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3673 };
3674
3675 let evaluated = Expression::Compose {
3676 ty: reject_ty,
3677 components: reject_components
3678 .clone()
3679 .into_iter()
3680 .zip(accept_components.clone().into_iter())
3681 .zip(condition_components.into_iter())
3682 .map(|((reject, accept), condition)| {
3683 let reject_scalar = match &self.expressions[reject] {
3684 &Expression::Literal(lit) => lit.scalar(),
3685 _ => unreachable!(),
3686 };
3687 select_single_component(self, reject_scalar, reject, accept, condition)
3688 })
3689 .collect::<Result<_, _>>()?,
3690 };
3691 self.register_evaluated_expr(evaluated, span)
3692 }
3693 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3694 }
3695 }
3696}
3697
3698fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3699 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3703 match e {
3704 idx @ 0..=31 => idx,
3705 32 => u32::MAX,
3706 _ => unreachable!(),
3707 }
3708 };
3709 match concrete_int {
3710 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3711 ConcreteInt::I32([e]) => {
3712 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3713 }
3714 }
3715}
3716
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected index) pairs; a zero input reports all ones.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3777
3778fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3779 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3783 match e {
3784 idx @ 0..=31 => 31 - idx,
3785 32 => u32::MAX,
3786 _ => unreachable!(),
3787 }
3788 };
3789 match concrete_int {
3790 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3791 let rtl_bit_index = if e.is_negative() {
3792 e.leading_ones()
3793 } else {
3794 e.leading_zeros()
3795 };
3796 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3797 }]),
3798 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3799 }
3800}
3801
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index) pairs; `0` and `-1` report all ones.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; `1 << 31` would be negative, so stop at 30.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // `-(1 << idx)` has its most significant zero bit at `idx - 1`.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3865
/// Fallible conversion from an abstract numeric value to a concrete scalar.
///
/// The impls in this file take `T = i64` or `T = f64` as the source
/// (presumably the payloads of abstract-int / abstract-float literals).
trait TryFromAbstract<T>: Sized {
    /// Converts `value` to `Self`, or returns an error (for example
    /// `ConstantEvaluatorError::AutomaticConversionLossy`) when the
    /// implementation rejects the conversion.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3891
3892impl TryFromAbstract<i64> for i32 {
3893 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3894 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3895 value: format!("{value:?}"),
3896 to_type: "i32",
3897 })
3898 }
3899}
3900
3901impl TryFromAbstract<i64> for u32 {
3902 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3903 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3904 value: format!("{value:?}"),
3905 to_type: "u32",
3906 })
3907 }
3908}
3909
3910impl TryFromAbstract<i64> for u64 {
3911 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3912 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3913 value: format!("{value:?}"),
3914 to_type: "u64",
3915 })
3916 }
3917}
3918
/// Identity conversion: every `i64` is already an `i64`, so this never fails.
impl TryFromAbstract<i64> for i64 {
    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
        Ok(value)
    }
}
3924
3925impl TryFromAbstract<i64> for f32 {
3926 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3927 let f = value as f32;
3928 Ok(f)
3932 }
3933}
3934
3935impl TryFromAbstract<f64> for f32 {
3936 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3937 let f = value as f32;
3938 if f.is_infinite() {
3939 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3940 value: format!("{value:?}"),
3941 to_type: "f32",
3942 });
3943 }
3944 Ok(f)
3945 }
3946}
3947
3948impl TryFromAbstract<i64> for f64 {
3949 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3950 let f = value as f64;
3951 Ok(f)
3955 }
3956}
3957
/// Identity conversion: the value is returned unchanged and this never fails.
impl TryFromAbstract<f64> for f64 {
    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
        Ok(value)
    }
}
3963
impl TryFromAbstract<f64> for i32 {
    /// Never fails. Rust's float-to-int `as` cast truncates toward zero,
    /// saturates at `i32::MIN` / `i32::MAX`, and maps NaN to `0`.
    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
        Ok(value as i32)
    }
}
3981
impl TryFromAbstract<f64> for u32 {
    /// Never fails. Rust's float-to-int `as` cast truncates toward zero,
    /// saturates at `0` / `u32::MAX`, and maps NaN to `0`.
    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
        Ok(value as u32)
    }
}
3989
3990impl TryFromAbstract<f64> for i64 {
3991 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3992 use crate::proc::type_methods::IntFloatLimits;
3995 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3996 }
3997}
3998
3999impl TryFromAbstract<f64> for u64 {
4000 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
4001 use crate::proc::type_methods::IntFloatLimits;
4004 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
4005 }
4006}
4007
4008impl TryFromAbstract<f64> for f16 {
4009 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
4010 let f = f16::from_f64(value);
4011 if f.is_infinite() {
4012 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4013 value: format!("{value:?}"),
4014 to_type: "f16",
4015 });
4016 }
4017 Ok(f)
4018 }
4019}
4020
4021impl TryFromAbstract<i64> for f16 {
4022 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
4023 let f = f16::from_i64(value);
4024 if f.is_none() {
4025 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4026 value: format!("{value:?}"),
4027 to_type: "f16",
4028 });
4029 }
4030 Ok(f.unwrap())
4031 }
4032}
4033
/// Cross product `a × b` of two 3-component vectors.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy,
    T: core::ops::Mul<T, Output = T>,
    T: core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
4046
4047#[cfg(test)]
4048mod tests {
4049 use alloc::{vec, vec::Vec};
4050
4051 use crate::{
4052 Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
4053 Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
4054 };
4055
4056 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
4057
4058 #[test]
4059 fn unary_op() {
4060 let mut types = UniqueArena::new();
4061 let mut constants = Arena::new();
4062 let overrides = Arena::new();
4063 let mut global_expressions = Arena::new();
4064
4065 let scalar_ty = types.insert(
4066 Type {
4067 name: None,
4068 inner: TypeInner::Scalar(crate::Scalar::I32),
4069 },
4070 Default::default(),
4071 );
4072
4073 let vec_ty = types.insert(
4074 Type {
4075 name: None,
4076 inner: TypeInner::Vector {
4077 size: VectorSize::Bi,
4078 scalar: crate::Scalar::I32,
4079 },
4080 },
4081 Default::default(),
4082 );
4083
4084 let h = constants.append(
4085 Constant {
4086 name: None,
4087 ty: scalar_ty,
4088 init: global_expressions
4089 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4090 },
4091 Default::default(),
4092 );
4093
4094 let h1 = constants.append(
4095 Constant {
4096 name: None,
4097 ty: scalar_ty,
4098 init: global_expressions
4099 .append(Expression::Literal(Literal::I32(8)), Default::default()),
4100 },
4101 Default::default(),
4102 );
4103
4104 let vec_h = constants.append(
4105 Constant {
4106 name: None,
4107 ty: vec_ty,
4108 init: global_expressions.append(
4109 Expression::Compose {
4110 ty: vec_ty,
4111 components: vec![constants[h].init, constants[h1].init],
4112 },
4113 Default::default(),
4114 ),
4115 },
4116 Default::default(),
4117 );
4118
4119 let expr = global_expressions.append(Expression::Constant(h), Default::default());
4120 let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
4121
4122 let expr2 = Expression::Unary {
4123 op: UnaryOperator::Negate,
4124 expr,
4125 };
4126
4127 let expr3 = Expression::Unary {
4128 op: UnaryOperator::BitwiseNot,
4129 expr,
4130 };
4131
4132 let expr4 = Expression::Unary {
4133 op: UnaryOperator::BitwiseNot,
4134 expr: expr1,
4135 };
4136
4137 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4138 let mut solver = ConstantEvaluator {
4139 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4140 types: &mut types,
4141 constants: &constants,
4142 overrides: &overrides,
4143 expressions: &mut global_expressions,
4144 expression_kind_tracker,
4145 layouter: &mut crate::proc::Layouter::default(),
4146 };
4147
4148 let res1 = solver
4149 .try_eval_and_append(expr2, Default::default())
4150 .unwrap();
4151 let res2 = solver
4152 .try_eval_and_append(expr3, Default::default())
4153 .unwrap();
4154 let res3 = solver
4155 .try_eval_and_append(expr4, Default::default())
4156 .unwrap();
4157
4158 assert_eq!(
4159 global_expressions[res1],
4160 Expression::Literal(Literal::I32(-4))
4161 );
4162
4163 assert_eq!(
4164 global_expressions[res2],
4165 Expression::Literal(Literal::I32(!4))
4166 );
4167
4168 let res3_inner = &global_expressions[res3];
4169
4170 match *res3_inner {
4171 Expression::Compose {
4172 ref ty,
4173 ref components,
4174 } => {
4175 assert_eq!(*ty, vec_ty);
4176 let mut components_iter = components.iter().copied();
4177 assert_eq!(
4178 global_expressions[components_iter.next().unwrap()],
4179 Expression::Literal(Literal::I32(!4))
4180 );
4181 assert_eq!(
4182 global_expressions[components_iter.next().unwrap()],
4183 Expression::Literal(Literal::I32(!8))
4184 );
4185 assert!(components_iter.next().is_none());
4186 }
4187 _ => panic!("Expected vector"),
4188 }
4189 }
4190
4191 #[test]
4192 fn matrix_op() {
4193 let mut helper = MatrixTestHelper::new();
4194
4195 for nc in 2..=4 {
4196 for nr in 2..=4 {
4197 let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4200 let expected = (0..nc)
4201 .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4202 .collect::<Vec<f32>>();
4203 assert_eq!(evaluated, expected);
4204
4205 let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4208 let expected = (0..nr)
4209 .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4210 .collect::<Vec<f32>>();
4211 assert_eq!(evaluated, expected);
4212
4213 for k in 2..=4 {
4214 let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4217 let expected = (0..nc)
4218 .flat_map(|c| {
4219 (0..nr).map(move |r| {
4220 (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4221 })
4222 })
4223 .collect::<Vec<f32>>();
4224 assert_eq!(evaluated, expected);
4225 }
4226 }
4227 }
4228 }
4229
    /// Arenas plus prebuilt `f32` vector and matrix operands used by `matrix_op`.
    struct MatrixTestHelper {
        types: UniqueArena<Type>,
        expressions: Arena<Expression>,
        /// `vecN<f32>` operands keyed by `N`; lanes are the literals `0.0 .. N`.
        vec_exprs: FastHashMap<usize, Handle<Expression>>,
        /// `matCxR<f32>` operands keyed by `(C, R)`; column `cc` holds the
        /// literals `cc * R .. cc * R + R`.
        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
    }
4241
4242 impl MatrixTestHelper {
4243 fn new() -> Self {
4244 let mut types = UniqueArena::new();
4245 let mut expressions = Arena::new();
4246 let span = crate::Span::default();
4247
4248 let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4249 for c in 2..=4 {
4250 let vec_ty = types.insert(
4251 Type {
4252 name: None,
4253 inner: TypeInner::Vector {
4254 size: Self::int_to_vector_size(c),
4255 scalar: crate::Scalar::F32,
4256 },
4257 },
4258 span,
4259 );
4260 vec_tys.insert(c, vec_ty);
4261 for r in 2..=4 {
4262 let mat_ty = types.insert(
4263 Type {
4264 name: None,
4265 inner: TypeInner::Matrix {
4266 columns: Self::int_to_vector_size(c),
4267 rows: Self::int_to_vector_size(r),
4268 scalar: crate::Scalar::F32,
4269 },
4270 },
4271 span,
4272 );
4273 mat_tys.insert((c, r), mat_ty);
4274 }
4275 }
4276
4277 let mut lit_exprs = FastHashMap::default();
4278 for i in 0..16 {
4279 let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4280 lit_exprs.insert(i, expr);
4281 }
4282
4283 let mut vec_exprs = FastHashMap::default();
4284 for c in 2..=4 {
4285 let expr = expressions.append(
4286 Expression::Compose {
4287 ty: *vec_tys.get(&c).unwrap(),
4288 components: (0..c)
4289 .map(|i| *lit_exprs.get(&i).unwrap())
4290 .collect::<Vec<_>>(),
4291 },
4292 span,
4293 );
4294 vec_exprs.insert(c, expr);
4295 }
4296
4297 let mut mat_exprs = FastHashMap::default();
4298 for c in 2..=4 {
4299 for r in 2..=4 {
4300 let mut columns = Vec::with_capacity(c);
4301 for cc in 0..c {
4302 let start = cc * r;
4303 let expr = expressions.append(
4304 Expression::Compose {
4305 ty: *vec_tys.get(&r).unwrap(),
4306 components: (start..start + r)
4307 .map(|i| *lit_exprs.get(&i).unwrap())
4308 .collect::<Vec<_>>(),
4309 },
4310 span,
4311 );
4312 columns.push(expr);
4313 }
4314
4315 let expr = expressions.append(
4316 Expression::Compose {
4317 ty: *mat_tys.get(&(c, r)).unwrap(),
4318 components: columns,
4319 },
4320 span,
4321 );
4322 mat_exprs.insert((c, r), expr);
4323 }
4324 }
4325
4326 Self {
4327 types,
4328 expressions,
4329 vec_exprs,
4330 mat_exprs,
4331 }
4332 }
4333
4334 fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4336 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4337 let mut solver = ConstantEvaluator {
4338 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4339 types: &mut self.types,
4340 constants: &Arena::new(),
4341 overrides: &Arena::new(),
4342 expressions: &mut self.expressions,
4343 expression_kind_tracker,
4344 layouter: &mut crate::proc::Layouter::default(),
4345 };
4346
4347 let result = solver
4348 .try_eval_and_append(
4349 Expression::Binary {
4350 op: BinaryOperator::Multiply,
4351 left: *self.vec_exprs.get(&nr).unwrap(),
4352 right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4353 },
4354 Default::default(),
4355 )
4356 .unwrap();
4357 self.flatten(result)
4358 }
4359
4360 fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4362 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4363 let mut solver = ConstantEvaluator {
4364 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4365 types: &mut self.types,
4366 constants: &Arena::new(),
4367 overrides: &Arena::new(),
4368 expressions: &mut self.expressions,
4369 expression_kind_tracker,
4370 layouter: &mut crate::proc::Layouter::default(),
4371 };
4372
4373 let result = solver
4374 .try_eval_and_append(
4375 Expression::Binary {
4376 op: BinaryOperator::Multiply,
4377 left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4378 right: *self.vec_exprs.get(&nc).unwrap(),
4379 },
4380 Default::default(),
4381 )
4382 .unwrap();
4383 self.flatten(result)
4384 }
4385
4386 fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4389 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4390 let mut solver = ConstantEvaluator {
4391 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4392 types: &mut self.types,
4393 constants: &Arena::new(),
4394 overrides: &Arena::new(),
4395 expressions: &mut self.expressions,
4396 expression_kind_tracker,
4397 layouter: &mut crate::proc::Layouter::default(),
4398 };
4399
4400 let result = solver
4401 .try_eval_and_append(
4402 Expression::Binary {
4403 op: BinaryOperator::Multiply,
4404 left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4405 right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4406 },
4407 Default::default(),
4408 )
4409 .unwrap();
4410 self.flatten(result)
4411 }
4412
4413 fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4414 let Expression::Compose {
4415 ref components,
4416 ref ty,
4417 } = self.expressions[expr]
4418 else {
4419 unreachable!()
4420 };
4421
4422 match self.types[*ty].inner {
4423 TypeInner::Vector { .. } => components
4424 .iter()
4425 .map(|&comp| {
4426 let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4427 unreachable!()
4428 };
4429 v
4430 })
4431 .collect(),
4432 TypeInner::Matrix { .. } => components
4433 .iter()
4434 .flat_map(|&comp| self.flatten(comp))
4435 .collect(),
4436 _ => unreachable!(),
4437 }
4438 }
4439
4440 fn int_to_vector_size(int: usize) -> VectorSize {
4441 match int {
4442 2 => VectorSize::Bi,
4443 3 => VectorSize::Tri,
4444 4 => VectorSize::Quad,
4445 _ => unreachable!(),
4446 }
4447 }
4448 }
4449
4450 #[test]
4451 fn cast() {
4452 let mut types = UniqueArena::new();
4453 let mut constants = Arena::new();
4454 let overrides = Arena::new();
4455 let mut global_expressions = Arena::new();
4456
4457 let scalar_ty = types.insert(
4458 Type {
4459 name: None,
4460 inner: TypeInner::Scalar(crate::Scalar::I32),
4461 },
4462 Default::default(),
4463 );
4464
4465 let h = constants.append(
4466 Constant {
4467 name: None,
4468 ty: scalar_ty,
4469 init: global_expressions
4470 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4471 },
4472 Default::default(),
4473 );
4474
4475 let expr = global_expressions.append(Expression::Constant(h), Default::default());
4476
4477 let root = Expression::As {
4478 expr,
4479 kind: ScalarKind::Bool,
4480 convert: Some(crate::BOOL_WIDTH),
4481 };
4482
4483 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4484 let mut solver = ConstantEvaluator {
4485 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4486 types: &mut types,
4487 constants: &constants,
4488 overrides: &overrides,
4489 expressions: &mut global_expressions,
4490 expression_kind_tracker,
4491 layouter: &mut crate::proc::Layouter::default(),
4492 };
4493
4494 let res = solver
4495 .try_eval_and_append(root, Default::default())
4496 .unwrap();
4497
4498 assert_eq!(
4499 global_expressions[res],
4500 Expression::Literal(Literal::Bool(true))
4501 );
4502 }
4503
4504 #[test]
4505 fn access() {
4506 let mut types = UniqueArena::new();
4507 let mut constants = Arena::new();
4508 let overrides = Arena::new();
4509 let mut global_expressions = Arena::new();
4510
4511 let matrix_ty = types.insert(
4512 Type {
4513 name: None,
4514 inner: TypeInner::Matrix {
4515 columns: VectorSize::Bi,
4516 rows: VectorSize::Tri,
4517 scalar: crate::Scalar::F32,
4518 },
4519 },
4520 Default::default(),
4521 );
4522
4523 let vec_ty = types.insert(
4524 Type {
4525 name: None,
4526 inner: TypeInner::Vector {
4527 size: VectorSize::Tri,
4528 scalar: crate::Scalar::F32,
4529 },
4530 },
4531 Default::default(),
4532 );
4533
4534 let mut vec1_components = Vec::with_capacity(3);
4535 let mut vec2_components = Vec::with_capacity(3);
4536
4537 for i in 0..3 {
4538 let h = global_expressions.append(
4539 Expression::Literal(Literal::F32(i as f32)),
4540 Default::default(),
4541 );
4542
4543 vec1_components.push(h)
4544 }
4545
4546 for i in 3..6 {
4547 let h = global_expressions.append(
4548 Expression::Literal(Literal::F32(i as f32)),
4549 Default::default(),
4550 );
4551
4552 vec2_components.push(h)
4553 }
4554
4555 let vec1 = constants.append(
4556 Constant {
4557 name: None,
4558 ty: vec_ty,
4559 init: global_expressions.append(
4560 Expression::Compose {
4561 ty: vec_ty,
4562 components: vec1_components,
4563 },
4564 Default::default(),
4565 ),
4566 },
4567 Default::default(),
4568 );
4569
4570 let vec2 = constants.append(
4571 Constant {
4572 name: None,
4573 ty: vec_ty,
4574 init: global_expressions.append(
4575 Expression::Compose {
4576 ty: vec_ty,
4577 components: vec2_components,
4578 },
4579 Default::default(),
4580 ),
4581 },
4582 Default::default(),
4583 );
4584
4585 let h = constants.append(
4586 Constant {
4587 name: None,
4588 ty: matrix_ty,
4589 init: global_expressions.append(
4590 Expression::Compose {
4591 ty: matrix_ty,
4592 components: vec![constants[vec1].init, constants[vec2].init],
4593 },
4594 Default::default(),
4595 ),
4596 },
4597 Default::default(),
4598 );
4599
4600 let base = global_expressions.append(Expression::Constant(h), Default::default());
4601
4602 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4603 let mut solver = ConstantEvaluator {
4604 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4605 types: &mut types,
4606 constants: &constants,
4607 overrides: &overrides,
4608 expressions: &mut global_expressions,
4609 expression_kind_tracker,
4610 layouter: &mut crate::proc::Layouter::default(),
4611 };
4612
4613 let root1 = Expression::AccessIndex { base, index: 1 };
4614
4615 let res1 = solver
4616 .try_eval_and_append(root1, Default::default())
4617 .unwrap();
4618
4619 let root2 = Expression::AccessIndex {
4620 base: res1,
4621 index: 2,
4622 };
4623
4624 let res2 = solver
4625 .try_eval_and_append(root2, Default::default())
4626 .unwrap();
4627
4628 match global_expressions[res1] {
4629 Expression::Compose {
4630 ref ty,
4631 ref components,
4632 } => {
4633 assert_eq!(*ty, vec_ty);
4634 let mut components_iter = components.iter().copied();
4635 assert_eq!(
4636 global_expressions[components_iter.next().unwrap()],
4637 Expression::Literal(Literal::F32(3.))
4638 );
4639 assert_eq!(
4640 global_expressions[components_iter.next().unwrap()],
4641 Expression::Literal(Literal::F32(4.))
4642 );
4643 assert_eq!(
4644 global_expressions[components_iter.next().unwrap()],
4645 Expression::Literal(Literal::F32(5.))
4646 );
4647 assert!(components_iter.next().is_none());
4648 }
4649 _ => panic!("Expected vector"),
4650 }
4651
4652 assert_eq!(
4653 global_expressions[res2],
4654 Expression::Literal(Literal::F32(5.))
4655 );
4656 }
4657
/// Composing a `vec2<i32>` from two references to the same `i32` constant and
/// negating it must fold to a `Compose` of two `-4` literals.
#[test]
fn compose_of_constants() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    // An anonymous `i32` constant initialized to `4`.
    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // Evaluate `vec2<i32>(h, h)`, then `-vec2<i32>(h, h)`.
    let solved_compose = solver
        .try_eval_and_append(
            Expression::Compose {
                ty: vec2_i32_ty,
                components: vec![h_expr, h_expr],
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_compose,
            },
            Default::default(),
        )
        .unwrap();

    match global_expressions[solved_negate] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_i32_ty);
            // Check the count explicitly: a per-component check alone would
            // vacuously accept a `Compose` with no components at all.
            assert_eq!(components.len(), 2, "expected one component per lane");
            for &component in components {
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::I32(-4))
                );
            }
        }
        ref other => panic!("unexpected evaluation result: {:?}", other),
    }
}
4740
/// Splatting an `i32` constant into a `vec2` and negating it must fold to a
/// `Compose` of two `-4` literals.
#[test]
fn splat_of_constant() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    // An anonymous `i32` constant initialized to `4`.
    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // Evaluate `vec2(h)` (a splat), then negate the result.
    let solved_compose = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: h_expr,
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_compose,
            },
            Default::default(),
        )
        .unwrap();

    match global_expressions[solved_negate] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_i32_ty);
            // Check the count explicitly: a per-component check alone would
            // vacuously accept a `Compose` with no components at all.
            assert_eq!(components.len(), 2, "expected one component per lane");
            for &component in components {
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::I32(-4))
                );
            }
        }
        ref other => panic!("unexpected evaluation result: {:?}", other),
    }
}
4823
/// Adding a splatted `f32` zero value to a splatted `5.0` must fold to a
/// `Compose` of two `5.0` literals, i.e. no `ZeroValue`/`Splat` may survive.
#[test]
fn splat_of_zero_value() {
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );

    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // `vec2(5.0)` as an unevaluated splat.
    let five =
        global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
    let five_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: five,
        },
        Default::default(),
    );
    // `vec2(f32())` as an unevaluated splat of a scalar zero value.
    let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
    let zero_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: zero,
        },
        Default::default(),
    );

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_add = solver
        .try_eval_and_append(
            Expression::Binary {
                op: BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    match global_expressions[solved_add] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_f32_ty);
            // Check the count explicitly: a per-component check alone would
            // vacuously accept a `Compose` with no components at all.
            assert_eq!(components.len(), 2, "expected one component per lane");
            for &component in components {
                // `0.0 + 5.0` is exactly representable, so exact equality is sound.
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::F32(5.0))
                );
            }
        }
        ref other => panic!("unexpected evaluation result: {:?}", other),
    }
}
4904}