1use alloc::{
7 format,
8 string::{String, ToString},
9 vec,
10 vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19 arena::{Arena, Handle, HandleVec, UniqueArena},
20 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21 ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Lets a macro-generating macro emit a nested `macro_rules!` that has its own
/// metavariables.
///
/// The `$body` tokens become the rules of a throwaway macro, which is then
/// invoked with a literal `$` token. A rule written as `($d:tt) => { ... }`
/// can therefore write `$d name` wherever the nested macro needs `$name`.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the throwaway macro from the caller's rules...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and immediately call it, handing over a bare `$`.
        __with_dollar_sign!($);
    }
}
38
/// Generates an enum over a chosen set of literal types, plus a helper that
/// applies a handler to several constant arguments component by component.
///
/// Input shape:
/// - `$ident -> $target`:
///   - `$ident` names the generated helper function, and also a same-named
///     convenience macro.
///   - `$target` names the generated `enum $target<const N: usize>`.
/// - `literals: [Lit => Variant: ty, ...]`: each accepted `Literal::Lit` maps
///   to `$target::Variant([ty; N])`. That variant holds the same component
///   taken from all `N` arguments.
/// - `scalar_kinds: [...]`: vector scalar kinds for which `Compose` arguments
///   are split into components and processed lane by lane.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A one-element result converts straight back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs `eval_zero_value_and_splat` on a handle and borrows the
            // resulting expression. Presumably this rewrites `ZeroValue`/`Splat`
            // into explicit literals/composes; TODO confirm against its definition.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument selects the path (scalar literal type or
            // vector). Every remaining argument must follow that same path.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar path: gather the same literal variant from all N
                    // arguments and hand them to `handler` in one go.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector path: every argument must be a `Compose` whose type
                // matches the first argument's type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each lane, take that lane from every argument and
                            // recurse. This assumes `size as u8` is the lane count;
                            // TODO confirm against `VectorSize`'s discriminants.
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            // Reassemble the per-lane results with the original vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // Also generate a same-named macro:
        //   `$ident!(eval, span, [a, b, ...], |x, y, ...| body)`
        // It expands `body` once per `$target` variant, and maps the result back
        // into that variant with `Result::map`.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// `component_wise_scalar` / `Scalar`: accepts every listed scalar literal type.
// Vectors of float, signed, unsigned, or abstract scalar kinds are processed
// lane by lane.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
239
// `component_wise_float` / `Float`: floating-point literals only, both
// abstract and concrete (f32, f16).
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
252
// `component_wise_concrete_int` / `ConcreteInt`: 32-bit concrete integers only.
// Abstract and 64-bit integers are rejected.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
264
// `component_wise_signed` / `Signed`: literal types that can hold negative
// values, both abstract and concrete floats, plus signed integers.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// The components of a scalar or vector constant, grouped by literal type.
///
/// Each variant mirrors the [`Literal`] variant of the same name and holds up
/// to `VectorSize::MAX` values. A single element stands for a scalar.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U16(ArrayVec<u16, { crate::VectorSize::MAX }>),
    I16(ArrayVec<i16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
298
299impl LiteralVector {
300 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
301 fn len(&self) -> usize {
302 match *self {
303 LiteralVector::F64(ref v) => v.len(),
304 LiteralVector::F32(ref v) => v.len(),
305 LiteralVector::F16(ref v) => v.len(),
306 LiteralVector::U16(ref v) => v.len(),
307 LiteralVector::I16(ref v) => v.len(),
308 LiteralVector::U32(ref v) => v.len(),
309 LiteralVector::I32(ref v) => v.len(),
310 LiteralVector::U64(ref v) => v.len(),
311 LiteralVector::I64(ref v) => v.len(),
312 LiteralVector::Bool(ref v) => v.len(),
313 LiteralVector::AbstractInt(ref v) => v.len(),
314 LiteralVector::AbstractFloat(ref v) => v.len(),
315 }
316 }
317
318 fn from_literal(literal: Literal) -> Self {
320 fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
321 let mut v = ArrayVec::new();
322 v.push(val);
323 v
324 }
325 match literal {
326 Literal::F64(e) => Self::F64(arrayvec_of(e)),
327 Literal::F32(e) => Self::F32(arrayvec_of(e)),
328 Literal::U16(e) => Self::U16(arrayvec_of(e)),
329 Literal::I16(e) => Self::I16(arrayvec_of(e)),
330 Literal::U32(e) => Self::U32(arrayvec_of(e)),
331 Literal::I32(e) => Self::I32(arrayvec_of(e)),
332 Literal::U64(e) => Self::U64(arrayvec_of(e)),
333 Literal::I64(e) => Self::I64(arrayvec_of(e)),
334 Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
335 Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
336 Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
337 Literal::F16(e) => Self::F16(arrayvec_of(e)),
338 }
339 }
340
341 fn from_literal_vec(
346 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
347 ) -> Result<Self, ConstantEvaluatorError> {
348 assert!(!components.is_empty());
349 macro_rules! compose_literals {
351 ($components:expr, $variant:ident, $self_variant:ident) => {{
352 let mut out = ArrayVec::new();
353 for l in &$components {
354 match l {
355 &Literal::$variant(v) => out.push(v),
356 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
357 }
358 }
359 Self::$self_variant(out)
360 }};
361 }
362 Ok(match components[0] {
363 Literal::I16(_) => compose_literals!(components, I16, I16),
364 Literal::U16(_) => compose_literals!(components, U16, U16),
365 Literal::I32(_) => compose_literals!(components, I32, I32),
366 Literal::U32(_) => compose_literals!(components, U32, U32),
367 Literal::I64(_) => compose_literals!(components, I64, I64),
368 Literal::U64(_) => compose_literals!(components, U64, U64),
369 Literal::F32(_) => compose_literals!(components, F32, F32),
370 Literal::F64(_) => compose_literals!(components, F64, F64),
371 Literal::Bool(_) => compose_literals!(components, Bool, Bool),
372 Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
373 Literal::AbstractFloat(_) => {
374 compose_literals!(components, AbstractFloat, AbstractFloat)
375 }
376 Literal::F16(_) => compose_literals!(components, F16, F16),
377 })
378 }
379
380 #[allow(dead_code)]
381 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
383 macro_rules! decompose_literals {
384 ($v:expr, $variant:ident) => {{
385 let mut out = ArrayVec::new();
386 for e in $v {
387 out.push(Literal::$variant(*e));
388 }
389 out
390 }};
391 }
392 match *self {
393 LiteralVector::F64(ref v) => decompose_literals!(v, F64),
394 LiteralVector::F32(ref v) => decompose_literals!(v, F32),
395 LiteralVector::F16(ref v) => decompose_literals!(v, F16),
396 LiteralVector::U16(ref v) => decompose_literals!(v, U16),
397 LiteralVector::I16(ref v) => decompose_literals!(v, I16),
398 LiteralVector::U32(ref v) => decompose_literals!(v, U32),
399 LiteralVector::I32(ref v) => decompose_literals!(v, I32),
400 LiteralVector::U64(ref v) => decompose_literals!(v, U64),
401 LiteralVector::I64(ref v) => decompose_literals!(v, I64),
402 LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
403 LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
404 LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
405 }
406 }
407
408 #[allow(dead_code)]
409 fn register_as_evaluated_expr(
411 &self,
412 eval: &mut ConstantEvaluator<'_>,
413 span: Span,
414 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
415 let lit_vec = self.to_literal_vec();
416 assert!(!lit_vec.is_empty());
417 let expr = if lit_vec.len() == 1 {
418 Expression::Literal(lit_vec[0])
419 } else {
420 Expression::Compose {
421 ty: eval.types.insert(
422 Type {
423 name: None,
424 inner: TypeInner::Vector {
425 size: match lit_vec.len() {
426 2 => crate::VectorSize::Bi,
427 3 => crate::VectorSize::Tri,
428 4 => crate::VectorSize::Quad,
429 _ => unreachable!(),
430 },
431 scalar: lit_vec[0].scalar(),
432 },
433 },
434 Span::UNDEFINED,
435 ),
436 components: lit_vec
437 .iter()
438 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
439 .collect::<Result<_, _>>()?,
440 }
441 };
442 eval.register_evaluated_expr(expr, span)
443 }
444}
445
/// Dispatches on the concrete variant of one or more [`LiteralVector`]s.
///
/// Syntax: `match_literal_vector!(match $lit_vec => $out { Ty => |a, b, ...| $(-> Ret)? { body }, ... })`.
///
/// Matching:
/// - Each arm binds `LiteralVector::Ty(ref a)`, `LiteralVector::Ty(ref b)`, and
///   so on. With several names the generated pattern is a tuple, so `$lit_vec`
///   must then be a tuple holding that many `LiteralVector`s.
/// - `Ty` may be a `LiteralVector` variant name, or one of two group names whose
///   arm body is duplicated for each member type:
///   - `Integer` expands to `U32, I32, U64, I64, AbstractInt`.
///   - `Float` expands to `F16, F32, F64, AbstractFloat`.
///
/// Results:
/// - A matching arm evaluates to `Ok($out::Ty(body))`, or to
///   `Ok($out::Ret(body))` when `-> Ret` is given.
/// - Any combination not covered yields
///   `Err(ConstantEvaluatorError::InvalidMathArg)`.
macro_rules! match_literal_vector {
    // Public entry point: split arms into a type list and a parallel body list.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start the muncher with an empty accumulator (left of `<>`).
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next type name off the list and hand it to a type-specific rule.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` group: replace it with its member types, repeating the body for each.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body } // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` group: replace it with its member types, repeating the body for each.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body } // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete type with more types still pending: move `{ type; vars; ret; body }`
    // into the accumulator and continue with the next type.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete type: append it and emit the final `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit one arm per accumulated entry, plus a catch-all `InvalidMathArg`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                // Single-binding arms produce a parenthesised (non-tuple) pattern.
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // No `-> Ret`: wrap the body in the output variant named after the input type.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // Explicit `-> Ret`: wrap the body in `$out::Ret` instead.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
621
622fn float_length<F>(e: &[F]) -> Option<F>
623where
624 F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
625{
626 if e.len() == 1 {
627 Some(e[0].abs())
629 } else {
630 let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
631 result.is_finite().then_some(result)
632 }
633}
634
/// Which shading language's rules govern evaluation, together with the
/// language-specific evaluation context.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules; see [`WgslRestrictions`].
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules; see [`GlslRestrictions`].
    Glsl(GlslRestrictions<'a>),
}
640
641impl Behavior<'_> {
642 const fn has_runtime_restrictions(&self) -> bool {
644 matches!(
645 self,
646 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
647 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
648 )
649 }
650}
651
/// Evaluates and folds constant expressions while they are appended to an
/// expression arena.
///
/// Built either over a module's global expressions (`for_*_module`) or over a
/// function's local expressions (`for_*_function`).
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Language rules and context (module, override, const or runtime function).
    behavior: Behavior<'a>,

    /// The module's type arena. New vector types may be inserted while folding.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants. Their `init` expressions replace `Constant` references.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena being appended to: the module's global expressions or a
    /// function's local expressions, depending on the constructor.
    expressions: &'a mut Arena<Expression>,

    /// Recorded [`ExpressionKind`] of every expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Type layout helper. Not used in the code visible here; presumably
    /// consulted for size/alignment queries. TODO confirm.
    layouter: &'a mut crate::proc::Layouter,
}
696
/// Where a WGSL expression is being evaluated, which decides what may be appended.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are allowed.
    ///
    /// `None` means the module's global arena. `Some` means a function body
    /// evaluated in const context.
    Const(Option<FunctionLocalData<'a>>),
    /// Global (module-scope) context in which override-expressions are also
    /// accepted and appended unevaluated.
    Override,
    /// A function body: override and runtime expressions may be appended.
    Runtime(FunctionLocalData<'a>),
}
709
/// Where a GLSL expression is being evaluated.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// The module's global arena; only const-expressions are allowed.
    Const,
    /// A function body: override and runtime expressions may be appended.
    Runtime(FunctionLocalData<'a>),
}
719
/// State needed when evaluating inside a function body rather than at module scope.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are copied
    /// from here into the function's own arena.
    global_expressions: &'a Arena<Expression>,
    /// Presumably used to emit appended runtime expressions into `block`; not
    /// exercised in the code visible here. TODO confirm.
    emitter: &'a mut super::Emitter,
    /// The function body block being built. Usage is outside this view.
    block: &'a mut crate::Block,
}
727
/// Classification of an expression by what its value depends on.
///
/// The derived ordering is `Const < Override < Runtime`. A compound
/// expression's kind is the maximum of its operands' kinds.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values and constants.
    Const,
    /// Depends on at least one pipeline-overridable constant, and nothing runtime.
    Override,
    /// Anything that cannot be classified as `Const` or `Override`.
    Runtime,
}
734
/// Records the [`ExpressionKind`] of each expression in an arena, indexed by handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
739
740impl ExpressionKindTracker {
741 pub const fn new() -> Self {
742 Self {
743 inner: HandleVec::new(),
744 }
745 }
746
747 pub fn force_non_const(&mut self, value: Handle<Expression>) {
749 self.inner[value] = ExpressionKind::Runtime;
750 }
751
752 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
753 self.inner.insert(value, expr_type);
754 }
755
756 pub fn is_const(&self, h: Handle<Expression>) -> bool {
757 matches!(self.type_of(h), ExpressionKind::Const)
758 }
759
760 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
761 matches!(
762 self.type_of(h),
763 ExpressionKind::Const | ExpressionKind::Override
764 )
765 }
766
767 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
768 self.inner[value]
769 }
770
771 pub fn from_arena(arena: &Arena<Expression>) -> Self {
772 let mut tracker = Self {
773 inner: HandleVec::with_capacity(arena.len()),
774 };
775 for (handle, expr) in arena.iter() {
776 tracker
777 .inner
778 .insert(handle, tracker.type_of_with_expr(expr));
779 }
780 tracker
781 }
782
783 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
784 match *expr {
785 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
786 ExpressionKind::Const
787 }
788 Expression::Override(_) => ExpressionKind::Override,
789 Expression::Compose { ref components, .. } => {
790 let mut expr_type = ExpressionKind::Const;
791 for component in components {
792 expr_type = expr_type.max(self.type_of(*component))
793 }
794 expr_type
795 }
796 Expression::Splat { value, .. } => self.type_of(value),
797 Expression::AccessIndex { base, .. } => self.type_of(base),
798 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
799 Expression::Swizzle { vector, .. } => self.type_of(vector),
800 Expression::Unary { expr, .. } => self.type_of(expr),
801 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
802 Expression::Math {
803 arg,
804 arg1,
805 arg2,
806 arg3,
807 ..
808 } => self
809 .type_of(arg)
810 .max(
811 arg1.map(|arg| self.type_of(arg))
812 .unwrap_or(ExpressionKind::Const),
813 )
814 .max(
815 arg2.map(|arg| self.type_of(arg))
816 .unwrap_or(ExpressionKind::Const),
817 )
818 .max(
819 arg3.map(|arg| self.type_of(arg))
820 .unwrap_or(ExpressionKind::Const),
821 ),
822 Expression::As { expr, .. } => self.type_of(expr),
823 Expression::Select {
824 condition,
825 accept,
826 reject,
827 } => self
828 .type_of(condition)
829 .max(self.type_of(accept))
830 .max(self.type_of(reject)),
831 Expression::Relational { argument, .. } => self.type_of(argument),
832 Expression::ArrayLength(expr) => self.type_of(expr),
833 _ => ExpressionKind::Runtime,
834 }
835 }
836}
837
838#[derive(Clone, Debug, thiserror::Error)]
839#[cfg_attr(test, derive(PartialEq))]
840pub enum ConstantEvaluatorError {
841 #[error("Constants cannot access function arguments")]
842 FunctionArg,
843 #[error("Constants cannot access global variables")]
844 GlobalVariable,
845 #[error("Constants cannot access local variables")]
846 LocalVariable,
847 #[error("Cannot get the array length of a non array type")]
848 InvalidArrayLengthArg,
849 #[error("Constants cannot get the array length of a dynamically sized array")]
850 ArrayLengthDynamic,
851 #[error("Cannot call arrayLength on array sized by override-expression")]
852 ArrayLengthOverridden,
853 #[error("Constants cannot call functions")]
854 Call,
855 #[error("Constants don't support workGroupUniformLoad")]
856 WorkGroupUniformLoadResult,
857 #[error("Constants don't support atomic functions")]
858 Atomic,
859 #[error("Constants don't support derivative functions")]
860 Derivative,
861 #[error("Constants don't support load expressions")]
862 Load,
863 #[error("Constants don't support image expressions")]
864 ImageExpression,
865 #[error("Constants don't support ray query expressions")]
866 RayQueryExpression,
867 #[error("Constants don't support subgroup expressions")]
868 SubgroupExpression,
869 #[error("Cannot access the type")]
870 InvalidAccessBase,
871 #[error("Cannot access at the index")]
872 InvalidAccessIndex,
873 #[error("Cannot access with index of type")]
874 InvalidAccessIndexTy,
875 #[error("Constants don't support array length expressions")]
876 ArrayLength,
877 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
878 InvalidCastArg { from: String, to: String },
879 #[error("Cannot apply the unary op to the argument")]
880 InvalidUnaryOpArg,
881 #[error("Cannot apply the binary op to the arguments")]
882 InvalidBinaryOpArgs,
883 #[error("Cannot apply math function to type")]
884 InvalidMathArg,
885 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
886 InvalidMathArgCount(crate::MathFunction, usize, usize),
887 #[error("{0} built-in function argument is out of valid range")]
888 InvalidMathArgValue(String),
889 #[error("Cannot apply relational function to type")]
890 InvalidRelationalArg(RelationalFunction),
891 #[error("value of `low` is greater than `high` for clamp built-in function")]
892 InvalidClamp,
893 #[error("Constructor expects {expected} components, found {actual}")]
894 InvalidVectorComposeLength { expected: usize, actual: usize },
895 #[error("Constructor must only contain vector or scalar arguments")]
896 InvalidVectorComposeComponent,
897 #[error("Splat is defined only on scalar values")]
898 SplatScalarOnly,
899 #[error("Can only swizzle vector constants")]
900 SwizzleVectorOnly,
901 #[error("swizzle component not present in source expression")]
902 SwizzleOutOfBounds,
903 #[error("Type is not constructible")]
904 TypeNotConstructible,
905 #[error("Subexpression(s) are not constant")]
906 SubexpressionsAreNotConstant,
907 #[error("Not implemented as constant expression: {0}")]
908 NotImplemented(String),
909 #[error("{0} operation overflowed")]
910 Overflow(String),
911 #[error(
912 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
913 )]
914 AutomaticConversionLossy {
915 value: String,
916 to_type: &'static str,
917 },
918 #[error("Division by zero")]
919 DivisionByZero,
920 #[error("Remainder by zero")]
921 RemainderByZero,
922 #[error("RHS of shift operation is greater than or equal to 32")]
923 ShiftedMoreThan32Bits,
924 #[error(transparent)]
925 Literal(#[from] crate::valid::LiteralError),
926 #[error("Can't use pipeline-overridable constants in const-expressions")]
927 Override,
928 #[error("Unexpected runtime-expression")]
929 RuntimeExpr,
930 #[error("Unexpected override-expression")]
931 OverrideExpr,
932 #[error("Expected boolean expression for condition argument of `select`, got something else")]
933 SelectScalarConditionNotABool,
934 #[error(
935 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
936 reject,
937 accept
938 )]
939 SelectVecRejectAcceptSizeMismatch {
940 reject: crate::VectorSize,
941 accept: crate::VectorSize,
942 },
943 #[error("Expected boolean vector for condition arg., got something else")]
944 SelectConditionNotAVecBool,
945 #[error(
946 "Expected same number of vector components between condition, accept, and reject args., got something else",
947 )]
948 SelectConditionVecSizeMismatch,
949 #[error(
950 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
951 )]
952 SelectAcceptRejectTypeMismatch,
953 #[error("Cooperative operations can't be constant")]
954 CooperativeOperation,
955}
956
957impl<'a> ConstantEvaluator<'a> {
958 pub const fn for_wgsl_module(
963 module: &'a mut crate::Module,
964 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
965 layouter: &'a mut crate::proc::Layouter,
966 in_override_ctx: bool,
967 ) -> Self {
968 Self::for_module(
969 Behavior::Wgsl(if in_override_ctx {
970 WgslRestrictions::Override
971 } else {
972 WgslRestrictions::Const(None)
973 }),
974 module,
975 global_expression_kind_tracker,
976 layouter,
977 )
978 }
979
980 pub const fn for_glsl_module(
985 module: &'a mut crate::Module,
986 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
987 layouter: &'a mut crate::proc::Layouter,
988 ) -> Self {
989 Self::for_module(
990 Behavior::Glsl(GlslRestrictions::Const),
991 module,
992 global_expression_kind_tracker,
993 layouter,
994 )
995 }
996
997 const fn for_module(
998 behavior: Behavior<'a>,
999 module: &'a mut crate::Module,
1000 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1001 layouter: &'a mut crate::proc::Layouter,
1002 ) -> Self {
1003 Self {
1004 behavior,
1005 types: &mut module.types,
1006 constants: &module.constants,
1007 overrides: &module.overrides,
1008 expressions: &mut module.global_expressions,
1009 expression_kind_tracker: global_expression_kind_tracker,
1010 layouter,
1011 }
1012 }
1013
1014 pub const fn for_wgsl_function(
1019 module: &'a mut crate::Module,
1020 expressions: &'a mut Arena<Expression>,
1021 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1022 layouter: &'a mut crate::proc::Layouter,
1023 emitter: &'a mut super::Emitter,
1024 block: &'a mut crate::Block,
1025 is_const: bool,
1026 ) -> Self {
1027 let local_data = FunctionLocalData {
1028 global_expressions: &module.global_expressions,
1029 emitter,
1030 block,
1031 };
1032 Self {
1033 behavior: Behavior::Wgsl(if is_const {
1034 WgslRestrictions::Const(Some(local_data))
1035 } else {
1036 WgslRestrictions::Runtime(local_data)
1037 }),
1038 types: &mut module.types,
1039 constants: &module.constants,
1040 overrides: &module.overrides,
1041 expressions,
1042 expression_kind_tracker: local_expression_kind_tracker,
1043 layouter,
1044 }
1045 }
1046
1047 pub const fn for_glsl_function(
1052 module: &'a mut crate::Module,
1053 expressions: &'a mut Arena<Expression>,
1054 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1055 layouter: &'a mut crate::proc::Layouter,
1056 emitter: &'a mut super::Emitter,
1057 block: &'a mut crate::Block,
1058 ) -> Self {
1059 Self {
1060 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1061 global_expressions: &module.global_expressions,
1062 emitter,
1063 block,
1064 })),
1065 types: &mut module.types,
1066 constants: &module.constants,
1067 overrides: &module.overrides,
1068 expressions,
1069 expression_kind_tracker: local_expression_kind_tracker,
1070 layouter,
1071 }
1072 }
1073
1074 pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1075 crate::proc::GlobalCtx {
1076 types: self.types,
1077 constants: self.constants,
1078 overrides: self.overrides,
1079 global_expressions: match self.function_local_data() {
1080 Some(data) => data.global_expressions,
1081 None => self.expressions,
1082 },
1083 }
1084 }
1085
1086 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1087 if !self.expression_kind_tracker.is_const(expr) {
1088 log::debug!("check: SubexpressionsAreNotConstant");
1089 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1090 }
1091 Ok(())
1092 }
1093
1094 fn check_and_get(
1095 &mut self,
1096 expr: Handle<Expression>,
1097 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1098 match self.expressions[expr] {
1099 Expression::Constant(c) => {
1100 if let Some(function_local_data) = self.function_local_data() {
1103 self.copy_from(
1105 self.constants[c].init,
1106 function_local_data.global_expressions,
1107 )
1108 } else {
1109 Ok(self.constants[c].init)
1111 }
1112 }
1113 _ => {
1114 self.check(expr)?;
1115 Ok(expr)
1116 }
1117 }
1118 }
1119
1120 pub fn try_eval_and_append(
1144 &mut self,
1145 expr: Expression,
1146 span: Span,
1147 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1148 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1149 ExpressionKind::Const => {
1150 let eval_result = self.try_eval_and_append_impl(&expr, span);
1151 if self.behavior.has_runtime_restrictions()
1156 && matches!(
1157 eval_result,
1158 Err(ConstantEvaluatorError::NotImplemented(_)
1159 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1160 )
1161 {
1162 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1163 } else {
1164 eval_result
1165 }
1166 }
1167 ExpressionKind::Override => match self.behavior {
1168 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1169 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1170 }
1171 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1172 Err(ConstantEvaluatorError::OverrideExpr)
1173 }
1174
1175 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1177 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1178 }
1179 Behavior::Glsl(GlslRestrictions::Const) => {
1180 Err(ConstantEvaluatorError::OverrideExpr)
1181 }
1182 },
1183 ExpressionKind::Runtime => {
1184 if self.behavior.has_runtime_restrictions() {
1185 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1186 } else {
1187 Err(ConstantEvaluatorError::RuntimeExpr)
1188 }
1189 }
1190 }
1191 }
1192
1193 const fn is_global_arena(&self) -> bool {
1195 matches!(
1196 self.behavior,
1197 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1198 | Behavior::Glsl(GlslRestrictions::Const)
1199 )
1200 }
1201
1202 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1203 match self.behavior {
1204 Behavior::Wgsl(
1205 WgslRestrictions::Runtime(ref function_local_data)
1206 | WgslRestrictions::Const(Some(ref function_local_data)),
1207 )
1208 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1209 Some(function_local_data)
1210 }
1211 _ => None,
1212 }
1213 }
1214
1215 fn try_eval_and_append_impl(
1216 &mut self,
1217 expr: &Expression,
1218 span: Span,
1219 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1220 log::trace!("try_eval_and_append: {expr:?}");
1221 match *expr {
1222 Expression::Constant(c) if self.is_global_arena() => {
1223 Ok(self.constants[c].init)
1226 }
1227 Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1228 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1229 self.register_evaluated_expr(expr.clone(), span)
1230 }
1231 Expression::Compose { ty, ref components } => {
1232 let components = components
1233 .iter()
1234 .map(|component| self.check_and_get(*component))
1235 .collect::<Result<Vec<_>, _>>()?;
1236 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1237 }
1238 Expression::Splat { size, value } => {
1239 let value = self.check_and_get(value)?;
1240 self.register_evaluated_expr(Expression::Splat { size, value }, span)
1241 }
1242 Expression::AccessIndex { base, index } => {
1243 let base = self.check_and_get(base)?;
1244
1245 self.access(base, index as usize, span)
1246 }
1247 Expression::Access { base, index } => {
1248 let base = self.check_and_get(base)?;
1249 let index = self.check_and_get(index)?;
1250
1251 let index_val: u32 = self
1252 .to_ctx()
1253 .get_const_val_from(index, self.expressions)
1254 .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1255 self.access(base, index_val as usize, span)
1256 }
1257 Expression::Swizzle {
1258 size,
1259 vector,
1260 pattern,
1261 } => {
1262 let vector = self.check_and_get(vector)?;
1263
1264 self.swizzle(size, span, vector, pattern)
1265 }
1266 Expression::Unary { expr, op } => {
1267 let expr = self.check_and_get(expr)?;
1268
1269 self.unary_op(op, expr, span)
1270 }
1271 Expression::Binary { left, right, op } => {
1272 let left = self.check_and_get(left)?;
1273 let right = self.check_and_get(right)?;
1274
1275 self.binary_op(op, left, right, span)
1276 }
1277 Expression::Math {
1278 fun,
1279 arg,
1280 arg1,
1281 arg2,
1282 arg3,
1283 } => {
1284 let arg = self.check_and_get(arg)?;
1285 let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1286 let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1287 let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1288
1289 self.math(arg, arg1, arg2, arg3, fun, span)
1290 }
1291 Expression::As {
1292 convert,
1293 expr,
1294 kind,
1295 } => {
1296 let expr = self.check_and_get(expr)?;
1297
1298 match convert {
1299 Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1300 None => Err(ConstantEvaluatorError::NotImplemented(
1301 "bitcast built-in function".into(),
1302 )),
1303 }
1304 }
1305 Expression::Select {
1306 reject,
1307 accept,
1308 condition,
1309 } => {
1310 let mut arg = |expr| self.check_and_get(expr);
1311
1312 let reject = arg(reject)?;
1313 let accept = arg(accept)?;
1314 let condition = arg(condition)?;
1315
1316 self.select(reject, accept, condition, span)
1317 }
1318 Expression::Relational { fun, argument } => {
1319 let argument = self.check_and_get(argument)?;
1320 self.relational(fun, argument, span)
1321 }
1322 Expression::ArrayLength(expr) => match self.behavior {
1323 Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1324 Behavior::Glsl(_) => {
1325 let expr = self.check_and_get(expr)?;
1326 self.array_length(expr, span)
1327 }
1328 },
1329 Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1330 Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1331 Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1332 Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1333 Expression::WorkGroupUniformLoadResult { .. } => {
1334 Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1335 }
1336 Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1337 Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1338 Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1339 Expression::ImageSample { .. }
1340 | Expression::ImageLoad { .. }
1341 | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1342 Expression::RayQueryProceedResult
1343 | Expression::RayQueryGetIntersection { .. }
1344 | Expression::RayQueryVertexPositions { .. } => {
1345 Err(ConstantEvaluatorError::RayQueryExpression)
1346 }
1347 Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1348 Expression::SubgroupOperationResult { .. } => {
1349 Err(ConstantEvaluatorError::SubgroupExpression)
1350 }
1351 Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1352 Err(ConstantEvaluatorError::CooperativeOperation)
1353 }
1354 }
1355 }
1356
1357 fn splat(
1370 &mut self,
1371 value: Handle<Expression>,
1372 size: crate::VectorSize,
1373 span: Span,
1374 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1375 match self.expressions[value] {
1376 Expression::Literal(literal) => {
1377 let scalar = literal.scalar();
1378 let ty = self.types.insert(
1379 Type {
1380 name: None,
1381 inner: TypeInner::Vector { size, scalar },
1382 },
1383 span,
1384 );
1385 let expr = Expression::Compose {
1386 ty,
1387 components: vec![value; size as usize],
1388 };
1389 self.register_evaluated_expr(expr, span)
1390 }
1391 Expression::ZeroValue(ty) => {
1392 let inner = match self.types[ty].inner {
1393 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1394 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1395 };
1396 let res_ty = self.types.insert(Type { name: None, inner }, span);
1397 let expr = Expression::ZeroValue(res_ty);
1398 self.register_evaluated_expr(expr, span)
1399 }
1400 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1401 }
1402 }
1403
1404 fn swizzle(
1405 &mut self,
1406 size: crate::VectorSize,
1407 span: Span,
1408 src_constant: Handle<Expression>,
1409 pattern: [crate::SwizzleComponent; 4],
1410 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1411 let mut get_dst_ty = |ty| match self.types[ty].inner {
1412 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1413 Type {
1414 name: None,
1415 inner: TypeInner::Vector { size, scalar },
1416 },
1417 span,
1418 )),
1419 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1420 };
1421
1422 match self.expressions[src_constant] {
1423 Expression::ZeroValue(ty) => {
1424 let dst_ty = get_dst_ty(ty)?;
1425 let expr = Expression::ZeroValue(dst_ty);
1426 self.register_evaluated_expr(expr, span)
1427 }
1428 Expression::Splat { value, .. } => {
1429 let expr = Expression::Splat { size, value };
1430 self.register_evaluated_expr(expr, span)
1431 }
1432 Expression::Compose { ty, ref components } => {
1433 let dst_ty = get_dst_ty(ty)?;
1434
1435 let mut flattened = [src_constant; 4]; let len =
1437 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1438 .zip(flattened.iter_mut())
1439 .map(|(component, elt)| *elt = component)
1440 .count();
1441 let flattened = &flattened[..len];
1442
1443 let swizzled_components = pattern[..size as usize]
1444 .iter()
1445 .map(|&sc| {
1446 let sc = sc as usize;
1447 if let Some(elt) = flattened.get(sc) {
1448 Ok(*elt)
1449 } else {
1450 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1451 }
1452 })
1453 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1454 let expr = Expression::Compose {
1455 ty: dst_ty,
1456 components: swizzled_components,
1457 };
1458 self.register_evaluated_expr(expr, span)
1459 }
1460 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1461 }
1462 }
1463
1464 fn math(
1465 &mut self,
1466 arg: Handle<Expression>,
1467 arg1: Option<Handle<Expression>>,
1468 arg2: Option<Handle<Expression>>,
1469 arg3: Option<Handle<Expression>>,
1470 fun: crate::MathFunction,
1471 span: Span,
1472 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1473 let expected = fun.argument_count();
1474 let given = Some(arg)
1475 .into_iter()
1476 .chain(arg1)
1477 .chain(arg2)
1478 .chain(arg3)
1479 .count();
1480 if expected != given {
1481 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1482 fun, expected, given,
1483 ));
1484 }
1485
1486 match fun {
1488 crate::MathFunction::Abs => {
1490 component_wise_scalar(self, span, [arg], |args| match args {
1491 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1492 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1493 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1494 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1495 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1496 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1498 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1499 })
1500 }
1501 crate::MathFunction::Min => {
1502 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1503 Ok([e1.min(e2)])
1504 })
1505 }
1506 crate::MathFunction::Max => {
1507 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1508 Ok([e1.max(e2)])
1509 })
1510 }
1511 crate::MathFunction::Clamp => {
1512 component_wise_scalar!(
1513 self,
1514 span,
1515 [arg, arg1.unwrap(), arg2.unwrap()],
1516 |e, low, high| {
1517 if low > high {
1518 Err(ConstantEvaluatorError::InvalidClamp)
1519 } else {
1520 Ok([e.clamp(low, high)])
1521 }
1522 }
1523 )
1524 }
1525 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1526 Float::F16([e]) => Ok(Float::F16(
1527 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1528 )),
1529 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1530 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1531 }),
1532
1533 crate::MathFunction::Cos => {
1535 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1536 }
1537 crate::MathFunction::Cosh => {
1538 component_wise_float!(self, span, [arg], |e| {
1539 let result = e.cosh();
1540 if result.is_finite() {
1541 Ok([result])
1542 } else {
1543 Err(ConstantEvaluatorError::Overflow("cosh".into()))
1544 }
1545 })
1546 }
1547 crate::MathFunction::Sin => {
1548 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1549 }
1550 crate::MathFunction::Sinh => {
1551 component_wise_float!(self, span, [arg], |e| {
1552 let result = e.sinh();
1553 if result.is_finite() {
1554 Ok([result])
1555 } else {
1556 Err(ConstantEvaluatorError::Overflow("sinh".into()))
1557 }
1558 })
1559 }
1560 crate::MathFunction::Tan => {
1561 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1562 }
1563 crate::MathFunction::Tanh => {
1564 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1565 }
1566 crate::MathFunction::Acos => {
1567 component_wise_float!(self, span, [arg], |e| {
1568 if e.abs() <= One::one() {
1569 Ok([e.acos()])
1570 } else {
1571 Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1572 }
1573 })
1574 }
1575 crate::MathFunction::Asin => {
1576 component_wise_float!(self, span, [arg], |e| {
1577 if e.abs() <= One::one() {
1578 Ok([e.asin()])
1579 } else {
1580 Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1581 }
1582 })
1583 }
1584 crate::MathFunction::Atan => {
1585 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1586 }
1587 crate::MathFunction::Atan2 => {
1588 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1589 Ok([y.atan2(x)])
1590 })
1591 }
1592 crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1593 Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1594 Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1595 Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1596 }),
1597 crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
1598 Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
1599 Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
1600 Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
1601 _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
1602 }),
1603 crate::MathFunction::Atanh => {
1604 component_wise_float!(self, span, [arg], |e| {
1605 if e.abs() < One::one() {
1606 Ok([e.atanh()])
1607 } else {
1608 Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1609 }
1610 })
1611 }
1612 crate::MathFunction::Radians => {
1613 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1614 }
1615 crate::MathFunction::Degrees => {
1616 component_wise_float!(self, span, [arg], |e| {
1617 let result = e.to_degrees();
1618 if result.is_finite() {
1619 Ok([result])
1620 } else {
1621 Err(ConstantEvaluatorError::Overflow("degrees".into()))
1622 }
1623 })
1624 }
1625
1626 crate::MathFunction::Ceil => {
1628 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1629 }
1630 crate::MathFunction::Floor => {
1631 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1632 }
1633 crate::MathFunction::Round => {
1634 component_wise_float(self, span, [arg], |e| match e {
1635 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1636 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1637 Float::F16([e]) => {
1638 fn round_ties_even(x: f64) -> f64 {
1646 let i = x as i64;
1647 let f = (x - i as f64).abs();
1648 if f == 0.5 {
1649 if i & 1 == 1 {
1650 (x.abs() + 0.5).copysign(x)
1652 } else {
1653 (x.abs() - 0.5).copysign(x)
1654 }
1655 } else {
1656 x.round()
1657 }
1658 }
1659
1660 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1661 }
1662 })
1663 }
1664 crate::MathFunction::Fract => {
1665 component_wise_float!(self, span, [arg], |e| {
1666 Ok([e - e.floor()])
1669 })
1670 }
1671 crate::MathFunction::Trunc => {
1672 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1673 }
1674
1675 crate::MathFunction::Exp => {
1677 component_wise_float!(self, span, [arg], |e| {
1678 let result = e.exp();
1679 if result.is_finite() {
1680 Ok([result])
1681 } else {
1682 Err(ConstantEvaluatorError::Overflow("exp".into()))
1683 }
1684 })
1685 }
1686 crate::MathFunction::Exp2 => {
1687 component_wise_float!(self, span, [arg], |e| {
1688 let result = e.exp2();
1689 if result.is_finite() {
1690 Ok([result])
1691 } else {
1692 Err(ConstantEvaluatorError::Overflow("exp2".into()))
1693 }
1694 })
1695 }
1696 crate::MathFunction::Log => {
1697 component_wise_float!(self, span, [arg], |e| {
1698 if e > Zero::zero() {
1699 Ok([e.ln()])
1700 } else {
1701 Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1702 }
1703 })
1704 }
1705 crate::MathFunction::Log2 => {
1706 component_wise_float!(self, span, [arg], |e| {
1707 if e > Zero::zero() {
1708 Ok([e.log2()])
1709 } else {
1710 Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1711 }
1712 })
1713 }
1714 crate::MathFunction::Pow => {
1715 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1716 if e1 < Zero::zero()
1719 || e1.is_one() && e2.is_infinite()
1720 || e1.is_infinite() && e2.is_zero()
1721 || e1.is_zero() && e2.is_zero()
1722 {
1723 Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
1724 } else {
1725 let result = e1.powf(e2);
1726 if result.is_finite() {
1727 Ok([result])
1728 } else {
1729 Err(ConstantEvaluatorError::Overflow("pow".into()))
1730 }
1731 }
1732 })
1733 }
1734
1735 crate::MathFunction::Sign => {
1737 component_wise_signed!(self, span, [arg], |e| {
1738 Ok([if e.is_zero() {
1739 Zero::zero()
1740 } else {
1741 e.signum()
1742 }])
1743 })
1744 }
1745 crate::MathFunction::Fma => {
1746 component_wise_float!(
1747 self,
1748 span,
1749 [arg, arg1.unwrap(), arg2.unwrap()],
1750 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1751 )
1752 }
1753 crate::MathFunction::Step => {
1754 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1755 Float::Abstract([edge, x]) => {
1756 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1757 }
1758 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1759 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1760 f16::one()
1761 } else {
1762 f16::zero()
1763 }])),
1764 })
1765 }
1766 crate::MathFunction::Sqrt => {
1767 component_wise_float!(self, span, [arg], |e| {
1768 if e >= Zero::zero() {
1769 Ok([e.sqrt()])
1770 } else {
1771 Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1772 }
1773 })
1774 }
1775 crate::MathFunction::InverseSqrt => {
1776 component_wise_float(self, span, [arg], |e| match e {
1777 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1778 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1779 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1780 })
1781 }
1782
1783 crate::MathFunction::CountTrailingZeros => {
1785 component_wise_concrete_int!(self, span, [arg], |e| {
1786 #[allow(clippy::useless_conversion)]
1787 Ok([e
1788 .trailing_zeros()
1789 .try_into()
1790 .expect("bit count overflowed 32 bits, somehow!?")])
1791 })
1792 }
1793 crate::MathFunction::CountLeadingZeros => {
1794 component_wise_concrete_int!(self, span, [arg], |e| {
1795 #[allow(clippy::useless_conversion)]
1796 Ok([e
1797 .leading_zeros()
1798 .try_into()
1799 .expect("bit count overflowed 32 bits, somehow!?")])
1800 })
1801 }
1802 crate::MathFunction::CountOneBits => {
1803 component_wise_concrete_int!(self, span, [arg], |e| {
1804 #[allow(clippy::useless_conversion)]
1805 Ok([e
1806 .count_ones()
1807 .try_into()
1808 .expect("bit count overflowed 32 bits, somehow!?")])
1809 })
1810 }
1811 crate::MathFunction::ReverseBits => {
1812 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1813 }
1814 crate::MathFunction::FirstTrailingBit => {
1815 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1816 }
1817 crate::MathFunction::FirstLeadingBit => {
1818 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1819 }
1820
1821 crate::MathFunction::Dot4I8Packed => {
1823 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1824 }
1825 crate::MathFunction::Dot4U8Packed => {
1826 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1827 }
1828 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1829 crate::MathFunction::Dot => {
1830 let e1 = self.extract_vec(arg, false)?;
1832 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1833 if e1.len() != e2.len() {
1834 return Err(ConstantEvaluatorError::InvalidMathArg);
1835 }
1836
1837 fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1838 where
1839 P: num_traits::Float,
1840 {
1841 let result = a
1842 .iter()
1843 .zip(b.iter())
1844 .map(|(&aa, &bb)| aa * bb)
1845 .fold(P::zero(), |acc, x| acc + x);
1846 if result.is_finite() {
1847 Ok(result)
1848 } else {
1849 Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1850 }
1851 }
1852
1853 fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1854 where
1855 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1856 {
1857 a.iter()
1858 .zip(b.iter())
1859 .map(|(&aa, bb)| aa.checked_mul(bb))
1860 .try_fold(P::zero(), |acc, x| {
1861 if let Some(x) = x {
1862 acc.checked_add(&x)
1863 } else {
1864 None
1865 }
1866 })
1867 .ok_or(ConstantEvaluatorError::Overflow(
1868 "in dot built-in".to_string(),
1869 ))
1870 }
1871
1872 fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1873 where
1874 P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1875 {
1876 a.iter()
1877 .zip(b.iter())
1878 .map(|(&aa, bb)| aa.wrapping_mul(bb))
1879 .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1880 }
1881
1882 let result = match_literal_vector!(match (e1, e2) => Literal {
1883 Float => |e1, e2| { float_dot_checked(e1, e2)? },
1884 AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1885 I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1886 U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1887 })?;
1888 self.register_evaluated_expr(Expression::Literal(result), span)
1889 }
1890 crate::MathFunction::Length => {
1891 let e1 = self.extract_vec(arg, true)?;
1893
1894 let result = match_literal_vector!(match e1 => Literal {
1895 Float => |e1| {
1896 float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
1897 },
1898 })?;
1899 self.register_evaluated_expr(Expression::Literal(result), span)
1900 }
1901 crate::MathFunction::Distance => {
1902 let e1 = self.extract_vec(arg, true)?;
1904 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1905 if e1.len() != e2.len() {
1906 return Err(ConstantEvaluatorError::InvalidMathArg);
1907 }
1908
1909 fn float_distance<F>(a: &[F], b: &[F]) -> F
1910 where
1911 F: core::ops::Mul<F>,
1912 F: num_traits::Float + iter::Sum + core::ops::Sub,
1913 {
1914 if a.len() == 1 {
1915 (a[0] - b[0]).abs()
1917 } else {
1918 a.iter()
1919 .zip(b.iter())
1920 .map(|(&aa, &bb)| aa - bb)
1921 .map(|ei| ei * ei)
1922 .sum::<F>()
1923 .sqrt()
1924 }
1925 }
1926 let result = match_literal_vector!(match (e1, e2) => Literal {
1927 Float => |e1, e2| { float_distance(e1, e2) },
1928 })?;
1929 self.register_evaluated_expr(Expression::Literal(result), span)
1930 }
1931 crate::MathFunction::Normalize => {
1932 let e1 = self.extract_vec(arg, true)?;
1934
1935 fn float_normalize<F>(
1936 e: &[F],
1937 ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
1938 where
1939 F: core::ops::Mul<F>,
1940 F: num_traits::Float + iter::Sum,
1941 {
1942 let len = match float_length(e) {
1943 Some(len) if !len.is_zero() => Ok(len),
1944 Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
1945 "normalize".into(),
1946 )),
1947 None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
1948 }?;
1949
1950 let mut out = ArrayVec::new();
1951 for &ei in e {
1952 out.push(ei / len);
1953 }
1954 Ok(out)
1955 }
1956
1957 let result = match_literal_vector!(match e1 => LiteralVector {
1958 Float => |e1| { float_normalize(e1)? },
1959 })?;
1960 result.register_as_evaluated_expr(self, span)
1961 }
1962
1963 crate::MathFunction::Modf
1965 | crate::MathFunction::Frexp
1966 | crate::MathFunction::Ldexp
1967 | crate::MathFunction::Outer
1968 | crate::MathFunction::FaceForward
1969 | crate::MathFunction::Reflect
1970 | crate::MathFunction::Refract
1971 | crate::MathFunction::Mix
1972 | crate::MathFunction::SmoothStep
1973 | crate::MathFunction::Inverse
1974 | crate::MathFunction::Transpose
1975 | crate::MathFunction::Determinant
1976 | crate::MathFunction::QuantizeToF16
1977 | crate::MathFunction::ExtractBits
1978 | crate::MathFunction::InsertBits
1979 | crate::MathFunction::Pack4x8snorm
1980 | crate::MathFunction::Pack4x8unorm
1981 | crate::MathFunction::Pack2x16snorm
1982 | crate::MathFunction::Pack2x16unorm
1983 | crate::MathFunction::Pack2x16float
1984 | crate::MathFunction::Pack4xI8
1985 | crate::MathFunction::Pack4xU8
1986 | crate::MathFunction::Pack4xI8Clamp
1987 | crate::MathFunction::Pack4xU8Clamp
1988 | crate::MathFunction::Unpack4x8snorm
1989 | crate::MathFunction::Unpack4x8unorm
1990 | crate::MathFunction::Unpack2x16snorm
1991 | crate::MathFunction::Unpack2x16unorm
1992 | crate::MathFunction::Unpack2x16float
1993 | crate::MathFunction::Unpack4xI8
1994 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1995 format!("{fun:?} built-in function"),
1996 )),
1997 }
1998 }
1999
2000 fn packed_dot_product(
2002 &mut self,
2003 a: Handle<Expression>,
2004 b: Handle<Expression>,
2005 span: Span,
2006 signed: bool,
2007 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2008 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
2009 return Err(ConstantEvaluatorError::InvalidMathArg);
2010 };
2011 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2012 return Err(ConstantEvaluatorError::InvalidMathArg);
2013 };
2014
2015 let result = if signed {
2016 Literal::I32(
2017 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2018 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2019 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2020 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2021 )
2022 } else {
2023 Literal::U32(
2024 (a & 0xFF) * (b & 0xFF)
2025 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2026 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2027 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2028 )
2029 };
2030
2031 self.register_evaluated_expr(Expression::Literal(result), span)
2032 }
2033
2034 fn cross_product(
2036 &mut self,
2037 a: Handle<Expression>,
2038 b: Handle<Expression>,
2039 span: Span,
2040 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2041 use Literal as Li;
2042
2043 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2044 let (b, _) = self.extract_vec_with_size::<3>(b)?;
2045
2046 let product = match (a, b) {
2047 (
2048 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2049 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2050 ) => {
2051 let p = cross_product(
2056 [a0 as f64, a1 as f64, a2 as f64],
2057 [b0 as f64, b1 as f64, b2 as f64],
2058 );
2059 [
2060 Li::AbstractFloat(p[0]),
2061 Li::AbstractFloat(p[1]),
2062 Li::AbstractFloat(p[2]),
2063 ]
2064 }
2065 (
2066 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2067 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2068 ) => {
2069 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2070 [
2071 Li::AbstractFloat(p[0]),
2072 Li::AbstractFloat(p[1]),
2073 Li::AbstractFloat(p[2]),
2074 ]
2075 }
2076 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2077 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2078 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2079 }
2080 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2081 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2082 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2083 }
2084 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2085 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2086 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2087 }
2088 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2089 };
2090
2091 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2092 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2093 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2094
2095 self.register_evaluated_expr(
2096 Expression::Compose {
2097 ty,
2098 components: vec![p0, p1, p2],
2099 },
2100 span,
2101 )
2102 }
2103
2104 fn extract_vec_with_size<const N: usize>(
2112 &mut self,
2113 expr: Handle<Expression>,
2114 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2115 let span = self.expressions.get_span(expr);
2116 let expr = self.eval_zero_value_and_splat(expr, span)?;
2117 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2118 return Err(ConstantEvaluatorError::InvalidMathArg);
2119 };
2120
2121 let mut value = [Literal::Bool(false); N];
2122 for (component, elt) in
2123 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2124 .zip(value.iter_mut())
2125 {
2126 let Expression::Literal(literal) = self.expressions[component] else {
2127 return Err(ConstantEvaluatorError::InvalidMathArg);
2128 };
2129 *elt = literal;
2130 }
2131
2132 Ok((value, ty))
2133 }
2134
2135 fn extract_vec(
2143 &mut self,
2144 expr: Handle<Expression>,
2145 allow_single: bool,
2146 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2147 let span = self.expressions.get_span(expr);
2148 let expr = self.eval_zero_value_and_splat(expr, span)?;
2149
2150 match self.expressions[expr] {
2151 Expression::Literal(literal) if allow_single => {
2152 Ok(LiteralVector::from_literal(literal))
2153 }
2154 Expression::Compose { ty, ref components } => {
2155 let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2156 for expr in
2157 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2158 {
2159 match self.expressions[expr] {
2160 Expression::Literal(l) => components_out.push(l),
2161 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2162 }
2163 }
2164 LiteralVector::from_literal_vec(components_out)
2165 }
2166 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2167 }
2168 }
2169
2170 fn array_length(
2171 &mut self,
2172 array: Handle<Expression>,
2173 span: Span,
2174 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2175 match self.expressions[array] {
2176 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2177 match self.types[ty].inner {
2178 TypeInner::Array { size, .. } => match size {
2179 ArraySize::Constant(len) => {
2180 let expr = Expression::Literal(Literal::U32(len.get()));
2181 self.register_evaluated_expr(expr, span)
2182 }
2183 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2184 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2185 },
2186 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2187 }
2188 }
2189 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2190 }
2191 }
2192
2193 fn access(
2194 &mut self,
2195 base: Handle<Expression>,
2196 index: usize,
2197 span: Span,
2198 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2199 match self.expressions[base] {
2200 Expression::ZeroValue(ty) => {
2201 let ty_inner = &self.types[ty].inner;
2202 let components = ty_inner
2203 .components()
2204 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2205
2206 if index >= components as usize {
2207 Err(ConstantEvaluatorError::InvalidAccessBase)
2208 } else {
2209 let ty_res = ty_inner
2210 .component_type(index)
2211 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2212 let ty = match ty_res {
2213 crate::proc::TypeResolution::Handle(ty) => ty,
2214 crate::proc::TypeResolution::Value(inner) => {
2215 self.types.insert(Type { name: None, inner }, span)
2216 }
2217 };
2218 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2219 }
2220 }
2221 Expression::Splat { size, value } => {
2222 if index >= size as usize {
2223 Err(ConstantEvaluatorError::InvalidAccessBase)
2224 } else {
2225 Ok(value)
2226 }
2227 }
2228 Expression::Compose { ty, ref components } => {
2229 let _ = self.types[ty]
2230 .inner
2231 .components()
2232 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2233
2234 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2235 .nth(index)
2236 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2237 }
2238 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2239 }
2240 }
2241
2242 fn eval_zero_value_and_splat(
2249 &mut self,
2250 mut expr: Handle<Expression>,
2251 span: Span,
2252 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2253 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2256 let components = components
2257 .clone()
2258 .iter()
2259 .map(|component| self.eval_zero_value_and_splat(*component, span))
2260 .collect::<Result<_, _>>()?;
2261 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2262 }
2263
2264 if let Expression::Splat { size, value } = self.expressions[expr] {
2268 expr = self.splat(value, size, span)?;
2269 }
2270 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2271 expr = self.eval_zero_value_impl(ty, span)?;
2272 }
2273 Ok(expr)
2274 }
2275
2276 fn eval_zero_value(
2282 &mut self,
2283 expr: Handle<Expression>,
2284 span: Span,
2285 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2286 match self.expressions[expr] {
2287 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2288 _ => Ok(expr),
2289 }
2290 }
2291
2292 fn eval_zero_value_impl(
2298 &mut self,
2299 ty: Handle<Type>,
2300 span: Span,
2301 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2302 match self.types[ty].inner {
2303 TypeInner::Scalar(scalar) => {
2304 let expr = Expression::Literal(
2305 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2306 );
2307 self.register_evaluated_expr(expr, span)
2308 }
2309 TypeInner::Vector { size, scalar } => {
2310 let scalar_ty = self.types.insert(
2311 Type {
2312 name: None,
2313 inner: TypeInner::Scalar(scalar),
2314 },
2315 span,
2316 );
2317 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2318 let expr = Expression::Compose {
2319 ty,
2320 components: vec![el; size as usize],
2321 };
2322 self.register_evaluated_expr(expr, span)
2323 }
2324 TypeInner::Matrix {
2325 columns,
2326 rows,
2327 scalar,
2328 } => {
2329 let vec_ty = self.types.insert(
2330 Type {
2331 name: None,
2332 inner: TypeInner::Vector { size: rows, scalar },
2333 },
2334 span,
2335 );
2336 let el = self.eval_zero_value_impl(vec_ty, span)?;
2337 let expr = Expression::Compose {
2338 ty,
2339 components: vec![el; columns as usize],
2340 };
2341 self.register_evaluated_expr(expr, span)
2342 }
2343 TypeInner::Array {
2344 base,
2345 size: ArraySize::Constant(size),
2346 ..
2347 } => {
2348 let el = self.eval_zero_value_impl(base, span)?;
2349 let expr = Expression::Compose {
2350 ty,
2351 components: vec![el; size.get() as usize],
2352 };
2353 self.register_evaluated_expr(expr, span)
2354 }
2355 TypeInner::Struct { ref members, .. } => {
2356 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2357 let mut components = Vec::with_capacity(members.len());
2358 for ty in types {
2359 components.push(self.eval_zero_value_impl(ty, span)?);
2360 }
2361 let expr = Expression::Compose { ty, components };
2362 self.register_evaluated_expr(expr, span)
2363 }
2364 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2365 }
2366 }
2367
/// Convert `expr` to the scalar type `target`, producing a new constant
/// expression.
///
/// A [`Expression::ZeroValue`] argument is first expanded via
/// `eval_zero_value`. Then:
///
/// - a `Literal` is converted according to the per-target tables below;
/// - a `Compose` of vector or matrix type is converted component by
///   component and given a fresh vector/matrix type whose scalar is `target`;
/// - a `Splat` has its splatted value converted;
/// - anything else, and any unsupported literal/target pair, yields
///   [`ConstantEvaluatorError::InvalidCastArg`].
///
/// Errors returned by `try_from_abstract` (abstract-to-concrete conversions)
/// are propagated unchanged.
pub fn cast(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    use crate::Scalar as Sc;

    let expr = self.eval_zero_value(expr, span)?;

    // Builds the `InvalidCastArg` error describing the source expression and
    // the requested target. The target is rendered in WGSL syntax when the
    // WGSL front end is compiled in, and with `Debug` formatting otherwise.
    let make_error = || -> Result<_, ConstantEvaluatorError> {
        let from = format!("{:?} {:?}", expr, self.expressions[expr]);

        #[cfg(feature = "wgsl-in")]
        let to = target.to_wgsl_for_diagnostics();

        #[cfg(not(feature = "wgsl-in"))]
        let to = format!("{target:?}");

        Err(ConstantEvaluatorError::InvalidCastArg { from, to })
    };

    // Provides `min_float`/`max_float`, used below to clamp float values
    // into the destination integer's range before an `as` cast.
    use crate::proc::type_methods::IntFloatLimits;

    let expr = match self.expressions[expr] {
        Expression::Literal(literal) => {
            let literal = match target {
                // Concrete integer sources convert with `as` (truncation /
                // two's-complement reinterpretation); float sources are
                // clamped to the i16 range first.
                Sc::I16 => Literal::I16(match literal {
                    Literal::I16(v) => v,
                    Literal::U16(v) => v as i16,
                    Literal::I32(v) => v as i16,
                    Literal::U32(v) => v as i16,
                    Literal::F32(v) => v.clamp(i16::MIN as f32, i16::MAX as f32) as i16,
                    Literal::F16(v) => {
                        f16::to_f32(v).clamp(i16::MIN as f32, i16::MAX as f32) as i16
                    }
                    Literal::Bool(v) => v as i16,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    // NOTE(review): this `as` silently wraps out-of-range
                    // abstract ints, whereas the I32/U32/I64/U64 arms use the
                    // fallible `try_from_abstract`. Confirm this is intended.
                    Literal::AbstractInt(v) => v as i16,
                    Literal::AbstractFloat(v) => v as i16,
                }),
                Sc::U16 => Literal::U16(match literal {
                    Literal::I16(v) => v as u16,
                    Literal::U16(v) => v,
                    Literal::I32(v) => v as u16,
                    Literal::U32(v) => v as u16,
                    Literal::F32(v) => v.clamp(u16::MIN as f32, u16::MAX as f32) as u16,
                    // `max(0)` keeps negative inputs from making the unsigned
                    // conversion return `None`.
                    Literal::F16(v) => f16::to_u16(&v.max(f16::ZERO)).unwrap(),
                    Literal::Bool(v) => v as u16,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    // NOTE(review): wrapping `as`, unlike the fallible
                    // `try_from_abstract` used for U32/U64. Confirm intended.
                    Literal::AbstractInt(v) => v as u16,
                    Literal::AbstractFloat(v) => v as u16,
                }),
                Sc::I32 => Literal::I32(match literal {
                    Literal::I16(v) => v as i32,
                    Literal::U16(v) => v as i32,
                    Literal::I32(v) => v,
                    Literal::U32(v) => v as i32,
                    Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                    // NOTE(review): the unwrap presumably relies on the f16
                    // value being finite (num_traits conversions return
                    // `None` for NaN/infinity). Confirm callers guarantee it.
                    Literal::F16(v) => f16::to_i32(&v).unwrap(),
                    Literal::Bool(v) => v as i32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                }),
                Sc::U32 => Literal::U32(match literal {
                    Literal::I16(v) => v as u32,
                    Literal::U16(v) => v as u32,
                    Literal::I32(v) => v as u32,
                    Literal::U32(v) => v,
                    Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                    // `max(0)` keeps negative inputs from producing `None`.
                    Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                    Literal::Bool(v) => v as u32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                }),
                Sc::I64 => Literal::I64(match literal {
                    Literal::I16(v) => v as i64,
                    Literal::U16(v) => v as i64,
                    Literal::I32(v) => v as i64,
                    Literal::U32(v) => v as i64,
                    Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::Bool(v) => v as i64,
                    Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::I64(v) => v,
                    Literal::U64(v) => v as i64,
                    // Same finiteness assumption as the I32 arm above.
                    Literal::F16(v) => f16::to_i64(&v).unwrap(),
                    Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                }),
                Sc::U64 => Literal::U64(match literal {
                    Literal::I16(v) => v as u64,
                    Literal::U16(v) => v as u64,
                    Literal::I32(v) => v as u64,
                    Literal::U32(v) => v as u64,
                    Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::Bool(v) => v as u64,
                    Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::I64(v) => v as u64,
                    Literal::U64(v) => v,
                    // `max(0)` keeps negative inputs from producing `None`.
                    Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                    Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                }),
                Sc::F16 => Literal::F16(match literal {
                    Literal::F16(v) => v,
                    Literal::F32(v) => f16::from_f32(v),
                    Literal::F64(v) => f16::from_f64(v),
                    Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                    Literal::I16(v) => f16::from_f32(v as f32),
                    Literal::U16(v) => f16::from_f32(v as f32),
                    Literal::I64(v) => f16::from_i64(v).unwrap(),
                    Literal::U64(v) => f16::from_u64(v).unwrap(),
                    Literal::I32(v) => f16::from_i32(v).unwrap(),
                    Literal::U32(v) => f16::from_u32(v).unwrap(),
                    Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                    Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                }),
                Sc::F32 => Literal::F32(match literal {
                    Literal::I16(v) => v as f32,
                    Literal::U16(v) => v as f32,
                    Literal::I32(v) => v as f32,
                    Literal::U32(v) => v as f32,
                    Literal::F32(v) => v,
                    Literal::Bool(v) => v as u32 as f32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::F16(v) => f16::to_f32(v),
                    Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                }),
                Sc::F64 => Literal::F64(match literal {
                    Literal::I16(v) => v as f64,
                    Literal::U16(v) => v as f64,
                    Literal::I32(v) => v as f64,
                    Literal::U32(v) => v as f64,
                    Literal::F16(v) => f16::to_f64(v),
                    Literal::F32(v) => v as f64,
                    Literal::F64(v) => v,
                    Literal::Bool(v) => v as u32 as f64,
                    Literal::I64(_) | Literal::U64(_) => return make_error(),
                    Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                }),
                // Any non-zero numeric value is `true`.
                Sc::BOOL => Literal::Bool(match literal {
                    Literal::I16(v) => v != 0,
                    Literal::U16(v) => v != 0,
                    Literal::I32(v) => v != 0,
                    Literal::U32(v) => v != 0,
                    Literal::F32(v) => v != 0.0,
                    Literal::F16(v) => v != f16::zero(),
                    Literal::Bool(v) => v,
                    Literal::AbstractInt(v) => v != 0,
                    Literal::AbstractFloat(v) => v != 0.0,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                }),
                // Only abstract sources may be converted to abstract targets.
                Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                    Literal::AbstractInt(v) => {
                        // The conversion may round, but every i64 lies well
                        // inside f64's range, so no overflow check is needed.
                        v as f64
                    }
                    Literal::AbstractFloat(v) => v,
                    _ => return make_error(),
                }),
                Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                    Literal::AbstractInt(v) => v,
                    _ => return make_error(),
                }),
                _ => {
                    log::debug!("Constant evaluator refused to convert value to {target:?}");
                    return make_error();
                }
            };
            Expression::Literal(literal)
        }
        // Vectors and matrices keep their shape but switch scalar type.
        // Other composite types (e.g. arrays, structs) are rejected here.
        Expression::Compose {
            ty,
            components: ref src_components,
        } => {
            let ty_inner = match self.types[ty].inner {
                TypeInner::Vector { size, .. } => TypeInner::Vector {
                    size,
                    scalar: target,
                },
                TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: target,
                },
                _ => return make_error(),
            };

            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.cast(*component, target, span)?;
            }

            let ty = self.types.insert(
                Type {
                    name: None,
                    inner: ty_inner,
                },
                span,
            );

            Expression::Compose { ty, components }
        }
        // Convert only the splatted value, reusing that value's own span.
        Expression::Splat { size, value } => {
            let value_span = self.expressions.get_span(value);
            let cast_value = self.cast(value, target, value_span)?;
            Expression::Splat {
                size,
                value: cast_value,
            }
        }
        _ => return make_error(),
    };

    self.register_evaluated_expr(expr, span)
}
2608
2609 pub fn cast_array(
2622 &mut self,
2623 expr: Handle<Expression>,
2624 target: crate::Scalar,
2625 span: Span,
2626 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2627 let expr = self.check_and_get(expr)?;
2628
2629 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2630 return self.cast(expr, target, span);
2631 };
2632
2633 let TypeInner::Array {
2634 base: _,
2635 size,
2636 stride: _,
2637 } = self.types[ty].inner
2638 else {
2639 return self.cast(expr, target, span);
2640 };
2641
2642 let mut components = components.clone();
2643 for component in &mut components {
2644 *component = self.cast_array(*component, target, span)?;
2645 }
2646
2647 let first = components.first().unwrap();
2648 let new_base = match self.resolve_type(*first)? {
2649 crate::proc::TypeResolution::Handle(ty) => ty,
2650 crate::proc::TypeResolution::Value(inner) => {
2651 self.types.insert(Type { name: None, inner }, span)
2652 }
2653 };
2654 let mut layouter = core::mem::take(self.layouter);
2655 layouter.update(self.to_ctx()).unwrap();
2656 *self.layouter = layouter;
2657
2658 let new_base_stride = self.layouter[new_base].to_stride();
2659 let new_array_ty = self.types.insert(
2660 Type {
2661 name: None,
2662 inner: TypeInner::Array {
2663 base: new_base,
2664 size,
2665 stride: new_base_stride,
2666 },
2667 },
2668 span,
2669 );
2670
2671 let compose = Expression::Compose {
2672 ty: new_array_ty,
2673 components,
2674 };
2675 self.register_evaluated_expr(compose, span)
2676 }
2677
2678 fn unary_op(
2679 &mut self,
2680 op: UnaryOperator,
2681 expr: Handle<Expression>,
2682 span: Span,
2683 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2684 let expr = self.eval_zero_value_and_splat(expr, span)?;
2685
2686 let expr = match self.expressions[expr] {
2687 Expression::Literal(value) => Expression::Literal(match op {
2688 UnaryOperator::Negate => match value {
2689 Literal::I16(v) => Literal::I16(v.wrapping_neg()),
2690 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2691 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2692 Literal::F32(v) => Literal::F32(-v),
2693 Literal::F16(v) => Literal::F16(-v),
2694 Literal::F64(v) => Literal::F64(-v),
2695 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2696 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2697 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2698 },
2699 UnaryOperator::LogicalNot => match value {
2700 Literal::Bool(v) => Literal::Bool(!v),
2701 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2702 },
2703 UnaryOperator::BitwiseNot => match value {
2704 Literal::I16(v) => Literal::I16(!v),
2705 Literal::I32(v) => Literal::I32(!v),
2706 Literal::I64(v) => Literal::I64(!v),
2707 Literal::U16(v) => Literal::U16(!v),
2708 Literal::U32(v) => Literal::U32(!v),
2709 Literal::U64(v) => Literal::U64(!v),
2710 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2711 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2712 },
2713 }),
2714 Expression::Compose {
2715 ty,
2716 components: ref src_components,
2717 } => {
2718 match self.types[ty].inner {
2719 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2720 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2721 }
2722
2723 let mut components = src_components.clone();
2724 for component in &mut components {
2725 *component = self.unary_op(op, *component, span)?;
2726 }
2727
2728 Expression::Compose { ty, components }
2729 }
2730 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2731 };
2732
2733 self.register_evaluated_expr(expr, span)
2734 }
2735
/// Evaluate the binary operation `left op right` on constant operands.
///
/// Both operands are first expanded with `eval_zero_value_and_splat`, so
/// only `Literal` and `Compose` expressions are handled here:
///
/// - literal `op` literal: computed directly (see the per-type tables);
/// - compose `op` literal / literal `op` compose: the literal is combined
///   with every component, subject to `is_allowed_compose_literal_op`;
/// - compose `op` compose: both sides are flattened and dispatched to
///   `binary_op_compose`.
///
/// `I16`/`U16`/`I32`/`U32` addition, subtraction and multiplication wrap on
/// overflow, while the same `AbstractInt` operations are checked and report
/// [`ConstantEvaluatorError::Overflow`]. Division or remainder by zero is an
/// error for every integer type handled here.
fn binary_op(
    &mut self,
    op: BinaryOperator,
    left: Handle<Expression>,
    right: Handle<Expression>,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    let left = self.eval_zero_value_and_splat(left, span)?;
    let right = self.eval_zero_value_and_splat(right, span)?;

    let expr = match (&self.expressions[left], &self.expressions[right]) {
        (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
            // Both literals must be the same variant. Shifts are exempt
            // because their shift amount is matched as `U32` below,
            // whatever the left operand's type.
            if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
                && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
            {
                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
            }

            // Booleans support equality but not ordering.
            if matches!(
                (left_value, op),
                (
                    Literal::Bool(_),
                    BinaryOperator::Less
                        | BinaryOperator::LessEqual
                        | BinaryOperator::Greater
                        | BinaryOperator::GreaterEqual
                )
            ) {
                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
            }

            let literal = match op {
                // Comparisons use `Literal`'s own `PartialEq`/`PartialOrd`.
                // The discriminant check above ensures both sides are the
                // same variant.
                BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                BinaryOperator::Less => Literal::Bool(left_value < right_value),
                BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                _ => match (left_value, right_value) {
                    (Literal::I16(a), Literal::I16(b)) => Literal::I16(match op {
                        BinaryOperator::Add => a.wrapping_add(b),
                        BinaryOperator::Subtract => a.wrapping_sub(b),
                        BinaryOperator::Multiply => a.wrapping_mul(b),
                        // `checked_div`/`checked_rem` also fail for
                        // `MIN / -1`, which is reported as overflow.
                        BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                            if b == 0 {
                                ConstantEvaluatorError::DivisionByZero
                            } else {
                                ConstantEvaluatorError::Overflow("division".into())
                            }
                        })?,
                        BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                            if b == 0 {
                                ConstantEvaluatorError::RemainderByZero
                            } else {
                                ConstantEvaluatorError::Overflow("remainder".into())
                            }
                        })?,
                        BinaryOperator::And => a & b,
                        BinaryOperator::ExclusiveOr => a ^ b,
                        BinaryOperator::InclusiveOr => a | b,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::I16(a), Literal::U32(b)) => Literal::I16(match op {
                        BinaryOperator::ShiftLeft => {
                            // Reject shifts that would push a significant
                            // bit into (or past) the sign bit. Using `!a`
                            // for negatives counts leading ones instead.
                            if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                            }
                            a.checked_shl(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                        }
                        BinaryOperator::ShiftRight => a
                            .checked_shr(b)
                            .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::U16(a), Literal::U16(b)) => Literal::U16(match op {
                        BinaryOperator::Add => a.wrapping_add(b),
                        BinaryOperator::Subtract => a.wrapping_sub(b),
                        BinaryOperator::Multiply => a.wrapping_mul(b),
                        BinaryOperator::Divide => a
                            .checked_div(b)
                            .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                        BinaryOperator::Modulo => a
                            .checked_rem(b)
                            .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                        BinaryOperator::And => a & b,
                        BinaryOperator::ExclusiveOr => a ^ b,
                        BinaryOperator::InclusiveOr => a | b,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::U16(a), Literal::U32(b)) => Literal::U16(match op {
                        // `a << b` is computed as `a * (1 << b)`, so
                        // `checked_mul` detects bits shifted out the top.
                        BinaryOperator::ShiftLeft => a
                            .checked_mul(
                                1u16.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            )
                            .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                        BinaryOperator::ShiftRight => a
                            .checked_shr(b)
                            .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                        BinaryOperator::Add => a.wrapping_add(b),
                        BinaryOperator::Subtract => a.wrapping_sub(b),
                        BinaryOperator::Multiply => a.wrapping_mul(b),
                        BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                            if b == 0 {
                                ConstantEvaluatorError::DivisionByZero
                            } else {
                                ConstantEvaluatorError::Overflow("division".into())
                            }
                        })?,
                        BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                            if b == 0 {
                                ConstantEvaluatorError::RemainderByZero
                            } else {
                                ConstantEvaluatorError::Overflow("remainder".into())
                            }
                        })?,
                        BinaryOperator::And => a & b,
                        BinaryOperator::ExclusiveOr => a ^ b,
                        BinaryOperator::InclusiveOr => a | b,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                        BinaryOperator::ShiftLeft => {
                            // Same sign-bit overflow test as the I16 case.
                            if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                            }
                            a.checked_shl(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                        }
                        BinaryOperator::ShiftRight => a
                            .checked_shr(b)
                            .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                        BinaryOperator::Add => a.wrapping_add(b),
                        BinaryOperator::Subtract => a.wrapping_sub(b),
                        BinaryOperator::Multiply => a.wrapping_mul(b),
                        BinaryOperator::Divide => a
                            .checked_div(b)
                            .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                        BinaryOperator::Modulo => a
                            .checked_rem(b)
                            .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                        BinaryOperator::And => a & b,
                        BinaryOperator::ExclusiveOr => a ^ b,
                        BinaryOperator::InclusiveOr => a | b,
                        // `a << b` as `a * (1 << b)`: see the U16 case.
                        BinaryOperator::ShiftLeft => a
                            .checked_mul(
                                1u32.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            )
                            .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                        BinaryOperator::ShiftRight => a
                            .checked_shr(b)
                            .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    // NOTE(review): unlike F16/AbstractFloat there is no
                    // `is_finite` check here. Presumably non-finite F32
                    // results are rejected elsewhere (e.g. on registration);
                    // confirm.
                    (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                        BinaryOperator::Add => a + b,
                        BinaryOperator::Subtract => a - b,
                        BinaryOperator::Multiply => a * b,
                        BinaryOperator::Divide => a / b,
                        BinaryOperator::Modulo => a % b,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    (Literal::AbstractInt(a), Literal::U32(b)) => {
                        Literal::AbstractInt(match op {
                            BinaryOperator::ShiftLeft => {
                                // The sign-bit test also rejects `b >= 64`
                                // (leading_zeros is at most 64), so the
                                // `unwrap_or(0)` fallback is only reachable
                                // for `>>`.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow(
                                        "<<".to_string(),
                                    ));
                                }
                                a.checked_shl(b).unwrap_or(0)
                            }
                            BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        })
                    }
                    (Literal::F16(a), Literal::F16(b)) => {
                        let result = match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        };
                        // Infinite or NaN results are reported as overflow.
                        if !result.is_finite() {
                            return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                        }
                        Literal::F16(result)
                    }
                    // Abstract integer arithmetic is checked: overflow is an
                    // error rather than wrapping.
                    (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                        Literal::AbstractInt(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        })
                    }
                    (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                        let result = match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        };
                        // Infinite or NaN results are reported as overflow.
                        if !result.is_finite() {
                            return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                        }
                        Literal::AbstractFloat(result)
                    }
                    (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                        BinaryOperator::LogicalAnd => a && b,
                        BinaryOperator::LogicalOr => a || b,
                        BinaryOperator::And => a & b,
                        BinaryOperator::InclusiveOr => a | b,
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    }),
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                },
            };
            Expression::Literal(literal)
        }
        // Composite on the left, scalar on the right: combine the scalar
        // with every component. The result keeps the composite's type.
        (
            &Expression::Compose {
                components: ref src_components,
                ty,
            },
            &Expression::Literal(_),
        ) => {
            if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
            }
            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.binary_op(op, *component, right, span)?;
            }
            Expression::Compose { ty, components }
        }
        // Scalar on the left, composite on the right: mirror of the above.
        (
            &Expression::Literal(_),
            &Expression::Compose {
                components: ref src_components,
                ty,
            },
        ) => {
            if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
            }
            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.binary_op(op, left, *component, span)?;
            }
            Expression::Compose { ty, components }
        }
        // Two composites: flatten both sides, then let `binary_op_compose`
        // choose the shape-specific implementation.
        (
            &Expression::Compose {
                components: ref left_components,
                ty: left_ty,
            },
            &Expression::Compose {
                components: ref right_components,
                ty: right_ty,
            },
        ) => {
            let left_flattened = crate::proc::flatten_compose(
                left_ty,
                left_components,
                self.expressions,
                self.types,
            )
            .collect::<Vec<_>>();
            let right_flattened = crate::proc::flatten_compose(
                right_ty,
                right_components,
                self.expressions,
                self.types,
            )
            .collect::<Vec<_>>();

            self.binary_op_compose(
                op,
                &left_flattened,
                &right_flattened,
                left_ty,
                right_ty,
                span,
            )?
        }
        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
    };

    return self.register_evaluated_expr(expr, span);

    /// Whether `op` may combine a composite of type `compose_ty` with a
    /// scalar literal. Allowed: `+ - * / %` on non-boolean vectors, and only
    /// `*` on matrices.
    fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
        let is_numeric_vec = matches!(
            compose_ty, TypeInner::Vector { scalar, .. }
            if scalar.kind != ScalarKind::Bool
        );
        let is_allowed_vec_scalar_op = matches!(
            op,
            BinaryOperator::Add
                | BinaryOperator::Subtract
                | BinaryOperator::Multiply
                | BinaryOperator::Divide
                | BinaryOperator::Modulo
        );
        let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
        let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
        is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
    }
}
3088
3089 fn binary_op_compose(
3090 &mut self,
3091 op: BinaryOperator,
3092 left_components: &[Handle<Expression>],
3093 right_components: &[Handle<Expression>],
3094 left_ty: Handle<Type>,
3095 right_ty: Handle<Type>,
3096 span: Span,
3097 ) -> Result<Expression, ConstantEvaluatorError> {
3098 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
3099 (
3101 &TypeInner::Vector {
3102 size: left_size, ..
3103 },
3104 &TypeInner::Vector {
3105 size: right_size, ..
3106 },
3107 ) if left_size == right_size => self.binary_op_vector(
3108 op,
3109 left_size,
3110 left_components,
3111 right_components,
3112 left_ty,
3113 span,
3114 ),
3115 (
3117 &TypeInner::Vector { size, .. },
3118 &TypeInner::Matrix {
3119 columns,
3120 rows,
3121 scalar,
3122 },
3123 ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3124 left_components,
3125 right_components,
3126 columns,
3127 scalar,
3128 span,
3129 ),
3130 (
3132 &TypeInner::Matrix {
3133 columns,
3134 rows,
3135 scalar,
3136 },
3137 &TypeInner::Vector { size, .. },
3138 ) if op == BinaryOperator::Multiply && size == columns => {
3139 self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3140 }
3141 (
3143 &TypeInner::Matrix {
3144 columns: left_columns,
3145 rows: left_rows,
3146 scalar,
3147 },
3148 &TypeInner::Matrix {
3149 columns: right_columns,
3150 rows: right_rows,
3151 ..
3152 },
3153 ) => match op {
3154 BinaryOperator::Add | BinaryOperator::Subtract
3155 if left_columns == right_columns && left_rows == right_rows =>
3156 {
3157 let components = left_components
3158 .iter()
3159 .zip(right_components)
3160 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3161 .collect::<Result<Vec<_>, _>>()?;
3162 Ok(Expression::Compose {
3163 ty: left_ty,
3164 components,
3165 })
3166 }
3167 BinaryOperator::Multiply if left_columns == right_rows => self
3168 .multiply_matrix_matrix(
3169 left_components,
3170 right_components,
3171 left_rows,
3172 right_columns,
3173 scalar,
3174 span,
3175 ),
3176 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3177 },
3178 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3179 }
3180 }
3181
3182 fn binary_op_vector(
3183 &mut self,
3184 op: BinaryOperator,
3185 size: crate::VectorSize,
3186 left_components: &[Handle<Expression>],
3187 right_components: &[Handle<Expression>],
3188 left_ty: Handle<Type>,
3189 span: Span,
3190 ) -> Result<Expression, ConstantEvaluatorError> {
3191 let ty = match op {
3192 BinaryOperator::Equal
3194 | BinaryOperator::NotEqual
3195 | BinaryOperator::Less
3196 | BinaryOperator::LessEqual
3197 | BinaryOperator::Greater
3198 | BinaryOperator::GreaterEqual => self.types.insert(
3199 Type {
3200 name: None,
3201 inner: TypeInner::Vector {
3202 size,
3203 scalar: crate::Scalar::BOOL,
3204 },
3205 },
3206 span,
3207 ),
3208
3209 BinaryOperator::Add
3212 | BinaryOperator::Subtract
3213 | BinaryOperator::Multiply
3214 | BinaryOperator::Divide
3215 | BinaryOperator::Modulo
3216 | BinaryOperator::And
3217 | BinaryOperator::ExclusiveOr
3218 | BinaryOperator::InclusiveOr
3219 | BinaryOperator::ShiftLeft
3220 | BinaryOperator::ShiftRight => left_ty,
3221
3222 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3223 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3225 }
3226 };
3227
3228 let components = left_components
3229 .iter()
3230 .zip(right_components)
3231 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3232 .collect::<Result<Vec<_>, _>>()?;
3233
3234 Ok(Expression::Compose { ty, components })
3235 }
3236
3237 fn multiply_vector_matrix(
3238 &mut self,
3239 vec_components: &[Handle<Expression>],
3240 mat_components: &[Handle<Expression>],
3241 mat_columns: crate::VectorSize,
3242 scalar: crate::Scalar,
3243 span: Span,
3244 ) -> Result<Expression, ConstantEvaluatorError> {
3245 let ty = self.types.insert(
3246 Type {
3247 name: None,
3248 inner: TypeInner::Vector {
3249 size: mat_columns,
3250 scalar,
3251 },
3252 },
3253 span,
3254 );
3255 let components = mat_components
3256 .iter()
3257 .map(|&column| {
3258 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3259 unreachable!()
3260 };
3261 self.dot_exprs(
3262 vec_components.iter().cloned(),
3263 components.clone().into_iter(),
3264 span,
3265 )
3266 })
3267 .collect::<Result<Vec<_>, _>>()?;
3268 Ok(Expression::Compose { ty, components })
3269 }
3270
3271 fn multiply_matrix_vector(
3272 &mut self,
3273 mat_components: &[Handle<Expression>],
3274 vec_components: &[Handle<Expression>],
3275 mat_rows: crate::VectorSize,
3276 scalar: crate::Scalar,
3277 span: Span,
3278 ) -> Result<Expression, ConstantEvaluatorError> {
3279 let ty = self.types.insert(
3280 Type {
3281 name: None,
3282 inner: TypeInner::Vector {
3283 size: mat_rows,
3284 scalar,
3285 },
3286 },
3287 span,
3288 );
3289
3290 let flatten = self.flatten_matrix(mat_components);
3291 let nr = mat_rows as usize;
3292 let components = (0..nr)
3293 .map(|r| {
3294 let row = flatten.iter().skip(r).step_by(nr).cloned();
3295 self.dot_exprs(row, vec_components.iter().cloned(), span)
3296 })
3297 .collect::<Result<Vec<_>, _>>()?;
3298 Ok(Expression::Compose { ty, components })
3299 }
3300
3301 fn multiply_matrix_matrix(
3302 &mut self,
3303 left_components: &[Handle<Expression>],
3304 right_components: &[Handle<Expression>],
3305 left_rows: crate::VectorSize,
3306 right_columns: crate::VectorSize,
3307 scalar: crate::Scalar,
3308 span: Span,
3309 ) -> Result<Expression, ConstantEvaluatorError> {
3310 let left_nc = left_components.len();
3311 let left_nr = left_rows as usize;
3312 let right_nc = right_columns as usize;
3313 let right_nr = left_nc;
3314
3315 let mut result = Vec::with_capacity(right_nc);
3316 let result_ty = self.types.insert(
3317 Type {
3318 name: None,
3319 inner: TypeInner::Matrix {
3320 columns: right_columns,
3321 rows: left_rows,
3322 scalar,
3323 },
3324 },
3325 span,
3326 );
3327 let result_column_ty = self.types.insert(
3328 Type {
3329 name: None,
3330 inner: TypeInner::Vector {
3331 size: left_rows,
3332 scalar,
3333 },
3334 },
3335 span,
3336 );
3337
3338 let left_flattened = self.flatten_matrix(left_components);
3339 let right_flattened = self.flatten_matrix(right_components);
3340 for c in 0..right_nc {
3341 let result_column = (0..left_nr)
3342 .map(|r| {
3343 let row = left_flattened.iter().skip(r).step_by(left_nr);
3344 let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3345 self.dot_exprs(row.cloned(), column.cloned(), span)
3346 })
3347 .collect::<Result<Vec<_>, _>>()?;
3348 let expr = Expression::Compose {
3349 ty: result_column_ty,
3350 components: result_column,
3351 };
3352 let handle = self.register_evaluated_expr(expr, span)?;
3353 result.push(handle);
3354 }
3355 Ok(Expression::Compose {
3356 ty: result_ty,
3357 components: result,
3358 })
3359 }
3360
3361 fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3362 let mut flattened = ArrayVec::<_, 16>::new();
3363 for &column in columns {
3364 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3365 unreachable!()
3366 };
3367 flattened.extend(components.iter().cloned());
3368 }
3369 flattened
3370 }
3371
3372 fn dot_exprs(
3373 &mut self,
3374 left: impl Iterator<Item = Handle<Expression>>,
3375 right: impl Iterator<Item = Handle<Expression>>,
3376 span: Span,
3377 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3378 let mut acc = None;
3379 for (l, r) in left.zip(right) {
3380 let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3381 match acc.as_mut() {
3382 Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3383 None => acc = Some(result),
3384 }
3385 }
3386 Ok(acc.unwrap())
3387 }
3388
3389 fn relational(
3390 &mut self,
3391 fun: RelationalFunction,
3392 arg: Handle<Expression>,
3393 span: Span,
3394 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3395 let arg = self.eval_zero_value_and_splat(arg, span)?;
3396 match fun {
3397 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3398 Expression::Literal(Literal::Bool(_)) => Ok(arg),
3399 Expression::Compose { ty, ref components }
3400 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3401 {
3402 let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3403 for component in
3404 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3405 {
3406 match self.expressions[component] {
3407 Expression::Literal(Literal::Bool(val)) => {
3408 bool_components.push(val);
3409 }
3410 _ => {
3411 return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3412 }
3413 }
3414 }
3415 let components = bool_components;
3416 let result = match fun {
3417 RelationalFunction::All => components.iter().all(|c| *c),
3418 RelationalFunction::Any => components.iter().any(|c| *c),
3419 _ => unreachable!(),
3420 };
3421 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3422 }
3423 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3424 },
3425 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3426 "{fun:?} built-in function"
3427 ))),
3428 }
3429 }
3430
3431 fn copy_from(
3439 &mut self,
3440 expr: Handle<Expression>,
3441 expressions: &Arena<Expression>,
3442 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3443 let span = expressions.get_span(expr);
3444 match expressions[expr] {
3445 ref expr @ (Expression::Literal(_)
3446 | Expression::Constant(_)
3447 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3448 Expression::Compose { ty, ref components } => {
3449 let mut components = components.clone();
3450 for component in &mut components {
3451 *component = self.copy_from(*component, expressions)?;
3452 }
3453 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3454 }
3455 Expression::Splat { size, value } => {
3456 let value = self.copy_from(value, expressions)?;
3457 self.register_evaluated_expr(Expression::Splat { size, value }, span)
3458 }
3459 _ => {
3460 log::debug!("copy_from: SubexpressionsAreNotConstant");
3461 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3462 }
3463 }
3464 }
3465
3466 fn vector_compose_flattened_size(
3468 &self,
3469 components: &[Handle<Expression>],
3470 ) -> Result<usize, ConstantEvaluatorError> {
3471 components
3472 .iter()
3473 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3474 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3475 TypeInner::Scalar(_) => 1,
3476 TypeInner::Vector { size, .. } => size as usize,
3480 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3481 };
3482 Ok(acc + size)
3483 })
3484 }
3485
3486 fn register_evaluated_expr(
3487 &mut self,
3488 expr: Expression,
3489 span: Span,
3490 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3491 if let Expression::Literal(literal) = expr {
3496 crate::valid::check_literal_value(literal)?;
3497 }
3498
3499 if let Expression::Compose { ty, ref components } = expr {
3503 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3504 let expected = size as usize;
3505 let actual = self.vector_compose_flattened_size(components)?;
3506 if expected != actual {
3507 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3508 expected,
3509 actual,
3510 });
3511 }
3512 }
3513 }
3514
3515 Ok(self.append_expr(expr, span, ExpressionKind::Const))
3516 }
3517
3518 fn append_expr(
3519 &mut self,
3520 expr: Expression,
3521 span: Span,
3522 expr_type: ExpressionKind,
3523 ) -> Handle<Expression> {
3524 let h = match self.behavior {
3525 Behavior::Wgsl(
3526 WgslRestrictions::Runtime(ref mut function_local_data)
3527 | WgslRestrictions::Const(Some(ref mut function_local_data)),
3528 )
3529 | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3530 let is_running = function_local_data.emitter.is_running();
3531 let needs_pre_emit = expr.needs_pre_emit();
3532 if is_running && needs_pre_emit {
3533 function_local_data
3534 .block
3535 .extend(function_local_data.emitter.finish(self.expressions));
3536 let h = self.expressions.append(expr, span);
3537 function_local_data.emitter.start(self.expressions);
3538 h
3539 } else {
3540 self.expressions.append(expr, span)
3541 }
3542 }
3543 _ => self.expressions.append(expr, span),
3544 };
3545 self.expression_kind_tracker.insert(h, expr_type);
3546 h
3547 }
3548
3549 fn resolve_type(
3554 &self,
3555 expr: Handle<Expression>,
3556 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3557 use crate::proc::TypeResolution as Tr;
3558 use crate::Expression as Ex;
3559 let resolution = match self.expressions[expr] {
3560 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3561 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3562 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3563 Ex::Splat { size, value } => {
3564 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3565 return Err(ConstantEvaluatorError::SplatScalarOnly);
3566 };
3567 Tr::Value(TypeInner::Vector { scalar, size })
3568 }
3569 _ => {
3570 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3571 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3572 }
3573 };
3574
3575 Ok(resolution)
3576 }
3577
3578 fn select(
3579 &mut self,
3580 reject: Handle<Expression>,
3581 accept: Handle<Expression>,
3582 condition: Handle<Expression>,
3583 span: Span,
3584 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3585 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3586
3587 let reject = arg(reject)?;
3588 let accept = arg(accept)?;
3589 let condition = arg(condition)?;
3590
3591 let select_single_component =
3592 |this: &mut Self, reject_scalar, reject, accept, condition| {
3593 let accept = this.cast(accept, reject_scalar, span)?;
3594 if condition {
3595 Ok(accept)
3596 } else {
3597 Ok(reject)
3598 }
3599 };
3600
3601 match (&self.expressions[reject], &self.expressions[accept]) {
3602 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3603 let reject_scalar = reject_lit.scalar();
3604 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3605 else {
3606 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3607 };
3608 select_single_component(self, reject_scalar, reject, accept, condition)
3609 }
3610 (
3611 &Expression::Compose {
3612 ty: reject_ty,
3613 components: ref reject_components,
3614 },
3615 &Expression::Compose {
3616 ty: accept_ty,
3617 components: ref accept_components,
3618 },
3619 ) => {
3620 let ty_deets = |ty| {
3621 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3622 (size.unwrap(), scalar)
3623 };
3624
3625 let expected_vec_size = {
3626 let [(reject_vec_size, _), (accept_vec_size, _)] =
3627 [reject_ty, accept_ty].map(ty_deets);
3628
3629 if reject_vec_size != accept_vec_size {
3630 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3631 reject: reject_vec_size,
3632 accept: accept_vec_size,
3633 });
3634 }
3635 reject_vec_size
3636 };
3637
3638 let condition_components = match self.expressions[condition] {
3639 Expression::Literal(Literal::Bool(condition)) => {
3640 vec![condition; (expected_vec_size as u8).into()]
3641 }
3642 Expression::Compose {
3643 ty: condition_ty,
3644 components: ref condition_components,
3645 } => {
3646 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3647 if condition_scalar.kind != ScalarKind::Bool {
3648 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3649 }
3650 if condition_vec_size != expected_vec_size {
3651 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3652 }
3653 condition_components
3654 .iter()
3655 .copied()
3656 .map(|component| match &self.expressions[component] {
3657 &Expression::Literal(Literal::Bool(condition)) => condition,
3658 _ => unreachable!(),
3659 })
3660 .collect()
3661 }
3662
3663 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3664 };
3665
3666 let evaluated = Expression::Compose {
3667 ty: reject_ty,
3668 components: reject_components
3669 .clone()
3670 .into_iter()
3671 .zip(accept_components.clone().into_iter())
3672 .zip(condition_components.into_iter())
3673 .map(|((reject, accept), condition)| {
3674 let reject_scalar = match &self.expressions[reject] {
3675 &Expression::Literal(lit) => lit.scalar(),
3676 _ => unreachable!(),
3677 };
3678 select_single_component(self, reject_scalar, reject, accept, condition)
3679 })
3680 .collect::<Result<_, _>>()?,
3681 };
3682 self.register_evaluated_expr(evaluated, span)
3683 }
3684 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3685 }
3686 }
3687}
3688
3689fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3690 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3694 match e {
3695 idx @ 0..=31 => idx,
3696 32 => u32::MAX,
3697 _ => unreachable!(),
3698 }
3699 };
3700 match concrete_int {
3701 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3702 ConcreteInt::I32([e]) => {
3703 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3704 }
3705 }
3706}
3707
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) pairs for the signed variant; zero has no set bit.
    let signed: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    // (input, expected) pairs for the unsigned variant.
    let unsigned: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3768
3769fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3770 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3774 match e {
3775 idx @ 0..=31 => 31 - idx,
3776 32 => u32::MAX,
3777 _ => unreachable!(),
3778 }
3779 };
3780 match concrete_int {
3781 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3782 let rtl_bit_index = if e.is_negative() {
3783 e.leading_ones()
3784 } else {
3785 e.leading_zeros()
3786 };
3787 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3788 }]),
3789 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3790 }
3791}
3792
#[test]
fn first_leading_bit_smoke() {
    // (input, expected) pairs for the signed variant; 0 and -1 have no answer.
    let signed: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the set bit itself (the sign bit is excluded).
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negated powers of two: the most significant zero sits just below `idx`.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3856
3857trait TryFromAbstract<T>: Sized {
3859 fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3881}
3882
3883impl TryFromAbstract<i64> for i32 {
3884 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3885 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3886 value: format!("{value:?}"),
3887 to_type: "i32",
3888 })
3889 }
3890}
3891
3892impl TryFromAbstract<i64> for u32 {
3893 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3894 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3895 value: format!("{value:?}"),
3896 to_type: "u32",
3897 })
3898 }
3899}
3900
3901impl TryFromAbstract<i64> for u64 {
3902 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3903 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3904 value: format!("{value:?}"),
3905 to_type: "u64",
3906 })
3907 }
3908}
3909
3910impl TryFromAbstract<i64> for i64 {
3911 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3912 Ok(value)
3913 }
3914}
3915
3916impl TryFromAbstract<i64> for f32 {
3917 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3918 let f = value as f32;
3919 Ok(f)
3923 }
3924}
3925
3926impl TryFromAbstract<f64> for f32 {
3927 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3928 let f = value as f32;
3929 if f.is_infinite() {
3930 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3931 value: format!("{value:?}"),
3932 to_type: "f32",
3933 });
3934 }
3935 Ok(f)
3936 }
3937}
3938
3939impl TryFromAbstract<i64> for f64 {
3940 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3941 let f = value as f64;
3942 Ok(f)
3946 }
3947}
3948
3949impl TryFromAbstract<f64> for f64 {
3950 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3951 Ok(value)
3952 }
3953}
3954
3955impl TryFromAbstract<f64> for i32 {
3956 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3957 Ok(value as i32)
3970 }
3971}
3972
3973impl TryFromAbstract<f64> for u32 {
3974 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3975 Ok(value as u32)
3978 }
3979}
3980
3981impl TryFromAbstract<f64> for i64 {
3982 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3983 use crate::proc::type_methods::IntFloatLimits;
3986 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3987 }
3988}
3989
3990impl TryFromAbstract<f64> for u64 {
3991 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3992 use crate::proc::type_methods::IntFloatLimits;
3995 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3996 }
3997}
3998
3999impl TryFromAbstract<f64> for f16 {
4000 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
4001 let f = f16::from_f64(value);
4002 if f.is_infinite() {
4003 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4004 value: format!("{value:?}"),
4005 to_type: "f16",
4006 });
4007 }
4008 Ok(f)
4009 }
4010}
4011
4012impl TryFromAbstract<i64> for f16 {
4013 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
4014 let f = f16::from_i64(value);
4015 if f.is_none() {
4016 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
4017 value: format!("{value:?}"),
4018 to_type: "f16",
4019 });
4020 }
4021 Ok(f.unwrap())
4022 }
4023}
4024
/// Computes the 3-D cross product `a × b` component-wise.
///
/// Each output component is computed as `p * q - r * s` in exactly the order
/// shown, so floating-point results are reproducible.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy,
    T: core::ops::Mul<T, Output = T>,
    T: core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
4037
4038#[cfg(test)]
4039mod tests {
4040 use alloc::{vec, vec::Vec};
4041
4042 use crate::{
4043 Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
4044 Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
4045 };
4046
4047 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
4048
4049 #[test]
4050 fn unary_op() {
4051 let mut types = UniqueArena::new();
4052 let mut constants = Arena::new();
4053 let overrides = Arena::new();
4054 let mut global_expressions = Arena::new();
4055
4056 let scalar_ty = types.insert(
4057 Type {
4058 name: None,
4059 inner: TypeInner::Scalar(crate::Scalar::I32),
4060 },
4061 Default::default(),
4062 );
4063
4064 let vec_ty = types.insert(
4065 Type {
4066 name: None,
4067 inner: TypeInner::Vector {
4068 size: VectorSize::Bi,
4069 scalar: crate::Scalar::I32,
4070 },
4071 },
4072 Default::default(),
4073 );
4074
4075 let h = constants.append(
4076 Constant {
4077 name: None,
4078 ty: scalar_ty,
4079 init: global_expressions
4080 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4081 },
4082 Default::default(),
4083 );
4084
4085 let h1 = constants.append(
4086 Constant {
4087 name: None,
4088 ty: scalar_ty,
4089 init: global_expressions
4090 .append(Expression::Literal(Literal::I32(8)), Default::default()),
4091 },
4092 Default::default(),
4093 );
4094
4095 let vec_h = constants.append(
4096 Constant {
4097 name: None,
4098 ty: vec_ty,
4099 init: global_expressions.append(
4100 Expression::Compose {
4101 ty: vec_ty,
4102 components: vec![constants[h].init, constants[h1].init],
4103 },
4104 Default::default(),
4105 ),
4106 },
4107 Default::default(),
4108 );
4109
4110 let expr = global_expressions.append(Expression::Constant(h), Default::default());
4111 let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
4112
4113 let expr2 = Expression::Unary {
4114 op: UnaryOperator::Negate,
4115 expr,
4116 };
4117
4118 let expr3 = Expression::Unary {
4119 op: UnaryOperator::BitwiseNot,
4120 expr,
4121 };
4122
4123 let expr4 = Expression::Unary {
4124 op: UnaryOperator::BitwiseNot,
4125 expr: expr1,
4126 };
4127
4128 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4129 let mut solver = ConstantEvaluator {
4130 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4131 types: &mut types,
4132 constants: &constants,
4133 overrides: &overrides,
4134 expressions: &mut global_expressions,
4135 expression_kind_tracker,
4136 layouter: &mut crate::proc::Layouter::default(),
4137 };
4138
4139 let res1 = solver
4140 .try_eval_and_append(expr2, Default::default())
4141 .unwrap();
4142 let res2 = solver
4143 .try_eval_and_append(expr3, Default::default())
4144 .unwrap();
4145 let res3 = solver
4146 .try_eval_and_append(expr4, Default::default())
4147 .unwrap();
4148
4149 assert_eq!(
4150 global_expressions[res1],
4151 Expression::Literal(Literal::I32(-4))
4152 );
4153
4154 assert_eq!(
4155 global_expressions[res2],
4156 Expression::Literal(Literal::I32(!4))
4157 );
4158
4159 let res3_inner = &global_expressions[res3];
4160
4161 match *res3_inner {
4162 Expression::Compose {
4163 ref ty,
4164 ref components,
4165 } => {
4166 assert_eq!(*ty, vec_ty);
4167 let mut components_iter = components.iter().copied();
4168 assert_eq!(
4169 global_expressions[components_iter.next().unwrap()],
4170 Expression::Literal(Literal::I32(!4))
4171 );
4172 assert_eq!(
4173 global_expressions[components_iter.next().unwrap()],
4174 Expression::Literal(Literal::I32(!8))
4175 );
4176 assert!(components_iter.next().is_none());
4177 }
4178 _ => panic!("Expected vector"),
4179 }
4180 }
4181
4182 #[test]
4183 fn matrix_op() {
4184 let mut helper = MatrixTestHelper::new();
4185
4186 for nc in 2..=4 {
4187 for nr in 2..=4 {
4188 let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4191 let expected = (0..nc)
4192 .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4193 .collect::<Vec<f32>>();
4194 assert_eq!(evaluated, expected);
4195
4196 let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4199 let expected = (0..nr)
4200 .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4201 .collect::<Vec<f32>>();
4202 assert_eq!(evaluated, expected);
4203
4204 for k in 2..=4 {
4205 let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4208 let expected = (0..nc)
4209 .flat_map(|c| {
4210 (0..nr).map(move |r| {
4211 (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4212 })
4213 })
4214 .collect::<Vec<f32>>();
4215 assert_eq!(evaluated, expected);
4216 }
4217 }
4218 }
4219 }
4220
4221 struct MatrixTestHelper {
4225 types: UniqueArena<Type>,
4226 expressions: Arena<Expression>,
4227 vec_exprs: FastHashMap<usize, Handle<Expression>>,
4229 mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
4231 }
4232
4233 impl MatrixTestHelper {
4234 fn new() -> Self {
4235 let mut types = UniqueArena::new();
4236 let mut expressions = Arena::new();
4237 let span = crate::Span::default();
4238
4239 let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4240 for c in 2..=4 {
4241 let vec_ty = types.insert(
4242 Type {
4243 name: None,
4244 inner: TypeInner::Vector {
4245 size: Self::int_to_vector_size(c),
4246 scalar: crate::Scalar::F32,
4247 },
4248 },
4249 span,
4250 );
4251 vec_tys.insert(c, vec_ty);
4252 for r in 2..=4 {
4253 let mat_ty = types.insert(
4254 Type {
4255 name: None,
4256 inner: TypeInner::Matrix {
4257 columns: Self::int_to_vector_size(c),
4258 rows: Self::int_to_vector_size(r),
4259 scalar: crate::Scalar::F32,
4260 },
4261 },
4262 span,
4263 );
4264 mat_tys.insert((c, r), mat_ty);
4265 }
4266 }
4267
4268 let mut lit_exprs = FastHashMap::default();
4269 for i in 0..16 {
4270 let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4271 lit_exprs.insert(i, expr);
4272 }
4273
4274 let mut vec_exprs = FastHashMap::default();
4275 for c in 2..=4 {
4276 let expr = expressions.append(
4277 Expression::Compose {
4278 ty: *vec_tys.get(&c).unwrap(),
4279 components: (0..c)
4280 .map(|i| *lit_exprs.get(&i).unwrap())
4281 .collect::<Vec<_>>(),
4282 },
4283 span,
4284 );
4285 vec_exprs.insert(c, expr);
4286 }
4287
4288 let mut mat_exprs = FastHashMap::default();
4289 for c in 2..=4 {
4290 for r in 2..=4 {
4291 let mut columns = Vec::with_capacity(c);
4292 for cc in 0..c {
4293 let start = cc * r;
4294 let expr = expressions.append(
4295 Expression::Compose {
4296 ty: *vec_tys.get(&r).unwrap(),
4297 components: (start..start + r)
4298 .map(|i| *lit_exprs.get(&i).unwrap())
4299 .collect::<Vec<_>>(),
4300 },
4301 span,
4302 );
4303 columns.push(expr);
4304 }
4305
4306 let expr = expressions.append(
4307 Expression::Compose {
4308 ty: *mat_tys.get(&(c, r)).unwrap(),
4309 components: columns,
4310 },
4311 span,
4312 );
4313 mat_exprs.insert((c, r), expr);
4314 }
4315 }
4316
4317 Self {
4318 types,
4319 expressions,
4320 vec_exprs,
4321 mat_exprs,
4322 }
4323 }
4324
4325 fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4327 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4328 let mut solver = ConstantEvaluator {
4329 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4330 types: &mut self.types,
4331 constants: &Arena::new(),
4332 overrides: &Arena::new(),
4333 expressions: &mut self.expressions,
4334 expression_kind_tracker,
4335 layouter: &mut crate::proc::Layouter::default(),
4336 };
4337
4338 let result = solver
4339 .try_eval_and_append(
4340 Expression::Binary {
4341 op: BinaryOperator::Multiply,
4342 left: *self.vec_exprs.get(&nr).unwrap(),
4343 right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4344 },
4345 Default::default(),
4346 )
4347 .unwrap();
4348 self.flatten(result)
4349 }
4350
4351 fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4353 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4354 let mut solver = ConstantEvaluator {
4355 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4356 types: &mut self.types,
4357 constants: &Arena::new(),
4358 overrides: &Arena::new(),
4359 expressions: &mut self.expressions,
4360 expression_kind_tracker,
4361 layouter: &mut crate::proc::Layouter::default(),
4362 };
4363
4364 let result = solver
4365 .try_eval_and_append(
4366 Expression::Binary {
4367 op: BinaryOperator::Multiply,
4368 left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4369 right: *self.vec_exprs.get(&nc).unwrap(),
4370 },
4371 Default::default(),
4372 )
4373 .unwrap();
4374 self.flatten(result)
4375 }
4376
4377 fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4380 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4381 let mut solver = ConstantEvaluator {
4382 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4383 types: &mut self.types,
4384 constants: &Arena::new(),
4385 overrides: &Arena::new(),
4386 expressions: &mut self.expressions,
4387 expression_kind_tracker,
4388 layouter: &mut crate::proc::Layouter::default(),
4389 };
4390
4391 let result = solver
4392 .try_eval_and_append(
4393 Expression::Binary {
4394 op: BinaryOperator::Multiply,
4395 left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4396 right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4397 },
4398 Default::default(),
4399 )
4400 .unwrap();
4401 self.flatten(result)
4402 }
4403
4404 fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4405 let Expression::Compose {
4406 ref components,
4407 ref ty,
4408 } = self.expressions[expr]
4409 else {
4410 unreachable!()
4411 };
4412
4413 match self.types[*ty].inner {
4414 TypeInner::Vector { .. } => components
4415 .iter()
4416 .map(|&comp| {
4417 let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4418 unreachable!()
4419 };
4420 v
4421 })
4422 .collect(),
4423 TypeInner::Matrix { .. } => components
4424 .iter()
4425 .flat_map(|&comp| self.flatten(comp))
4426 .collect(),
4427 _ => unreachable!(),
4428 }
4429 }
4430
4431 fn int_to_vector_size(int: usize) -> VectorSize {
4432 match int {
4433 2 => VectorSize::Bi,
4434 3 => VectorSize::Tri,
4435 4 => VectorSize::Quad,
4436 _ => unreachable!(),
4437 }
4438 }
4439 }
4440
4441 #[test]
4442 fn cast() {
4443 let mut types = UniqueArena::new();
4444 let mut constants = Arena::new();
4445 let overrides = Arena::new();
4446 let mut global_expressions = Arena::new();
4447
4448 let scalar_ty = types.insert(
4449 Type {
4450 name: None,
4451 inner: TypeInner::Scalar(crate::Scalar::I32),
4452 },
4453 Default::default(),
4454 );
4455
4456 let h = constants.append(
4457 Constant {
4458 name: None,
4459 ty: scalar_ty,
4460 init: global_expressions
4461 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4462 },
4463 Default::default(),
4464 );
4465
4466 let expr = global_expressions.append(Expression::Constant(h), Default::default());
4467
4468 let root = Expression::As {
4469 expr,
4470 kind: ScalarKind::Bool,
4471 convert: Some(crate::BOOL_WIDTH),
4472 };
4473
4474 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4475 let mut solver = ConstantEvaluator {
4476 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4477 types: &mut types,
4478 constants: &constants,
4479 overrides: &overrides,
4480 expressions: &mut global_expressions,
4481 expression_kind_tracker,
4482 layouter: &mut crate::proc::Layouter::default(),
4483 };
4484
4485 let res = solver
4486 .try_eval_and_append(root, Default::default())
4487 .unwrap();
4488
4489 assert_eq!(
4490 global_expressions[res],
4491 Expression::Literal(Literal::Bool(true))
4492 );
4493 }
4494
4495 #[test]
4496 fn access() {
4497 let mut types = UniqueArena::new();
4498 let mut constants = Arena::new();
4499 let overrides = Arena::new();
4500 let mut global_expressions = Arena::new();
4501
4502 let matrix_ty = types.insert(
4503 Type {
4504 name: None,
4505 inner: TypeInner::Matrix {
4506 columns: VectorSize::Bi,
4507 rows: VectorSize::Tri,
4508 scalar: crate::Scalar::F32,
4509 },
4510 },
4511 Default::default(),
4512 );
4513
4514 let vec_ty = types.insert(
4515 Type {
4516 name: None,
4517 inner: TypeInner::Vector {
4518 size: VectorSize::Tri,
4519 scalar: crate::Scalar::F32,
4520 },
4521 },
4522 Default::default(),
4523 );
4524
4525 let mut vec1_components = Vec::with_capacity(3);
4526 let mut vec2_components = Vec::with_capacity(3);
4527
4528 for i in 0..3 {
4529 let h = global_expressions.append(
4530 Expression::Literal(Literal::F32(i as f32)),
4531 Default::default(),
4532 );
4533
4534 vec1_components.push(h)
4535 }
4536
4537 for i in 3..6 {
4538 let h = global_expressions.append(
4539 Expression::Literal(Literal::F32(i as f32)),
4540 Default::default(),
4541 );
4542
4543 vec2_components.push(h)
4544 }
4545
4546 let vec1 = constants.append(
4547 Constant {
4548 name: None,
4549 ty: vec_ty,
4550 init: global_expressions.append(
4551 Expression::Compose {
4552 ty: vec_ty,
4553 components: vec1_components,
4554 },
4555 Default::default(),
4556 ),
4557 },
4558 Default::default(),
4559 );
4560
4561 let vec2 = constants.append(
4562 Constant {
4563 name: None,
4564 ty: vec_ty,
4565 init: global_expressions.append(
4566 Expression::Compose {
4567 ty: vec_ty,
4568 components: vec2_components,
4569 },
4570 Default::default(),
4571 ),
4572 },
4573 Default::default(),
4574 );
4575
4576 let h = constants.append(
4577 Constant {
4578 name: None,
4579 ty: matrix_ty,
4580 init: global_expressions.append(
4581 Expression::Compose {
4582 ty: matrix_ty,
4583 components: vec![constants[vec1].init, constants[vec2].init],
4584 },
4585 Default::default(),
4586 ),
4587 },
4588 Default::default(),
4589 );
4590
4591 let base = global_expressions.append(Expression::Constant(h), Default::default());
4592
4593 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4594 let mut solver = ConstantEvaluator {
4595 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4596 types: &mut types,
4597 constants: &constants,
4598 overrides: &overrides,
4599 expressions: &mut global_expressions,
4600 expression_kind_tracker,
4601 layouter: &mut crate::proc::Layouter::default(),
4602 };
4603
4604 let root1 = Expression::AccessIndex { base, index: 1 };
4605
4606 let res1 = solver
4607 .try_eval_and_append(root1, Default::default())
4608 .unwrap();
4609
4610 let root2 = Expression::AccessIndex {
4611 base: res1,
4612 index: 2,
4613 };
4614
4615 let res2 = solver
4616 .try_eval_and_append(root2, Default::default())
4617 .unwrap();
4618
4619 match global_expressions[res1] {
4620 Expression::Compose {
4621 ref ty,
4622 ref components,
4623 } => {
4624 assert_eq!(*ty, vec_ty);
4625 let mut components_iter = components.iter().copied();
4626 assert_eq!(
4627 global_expressions[components_iter.next().unwrap()],
4628 Expression::Literal(Literal::F32(3.))
4629 );
4630 assert_eq!(
4631 global_expressions[components_iter.next().unwrap()],
4632 Expression::Literal(Literal::F32(4.))
4633 );
4634 assert_eq!(
4635 global_expressions[components_iter.next().unwrap()],
4636 Expression::Literal(Literal::F32(5.))
4637 );
4638 assert!(components_iter.next().is_none());
4639 }
4640 _ => panic!("Expected vector"),
4641 }
4642
4643 assert_eq!(
4644 global_expressions[res2],
4645 Expression::Literal(Literal::F32(5.))
4646 );
4647 }
4648
/// Negating a `vec2<i32>` composed from two references to an `i32` constant
/// initialized to `4` must fold to a `Compose` of exactly two `-4` literals.
#[test]
fn compose_of_constants() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_compose = solver
        .try_eval_and_append(
            Expression::Compose {
                ty: vec2_i32_ty,
                components: vec![h_expr, h_expr],
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_compose,
            },
            Default::default(),
        )
        .unwrap();

    // Assert the component count explicitly: `Iterator::all` is vacuously
    // true on an empty list, so a `Compose` with no components would
    // otherwise be accepted.
    match global_expressions[solved_negate] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_i32_ty);
            assert_eq!(components.len(), 2);
            for &component in components {
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::I32(-4))
                );
            }
        }
        ref other => panic!("expected `Expression::Compose`, got {other:?}"),
    }
}
4731
/// Negating a two-wide `Splat` of an `i32` constant initialized to `4` must
/// fold to a `vec2<i32>` `Compose` of exactly two `-4` literals.
#[test]
fn splat_of_constant() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_compose = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: h_expr,
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_compose,
            },
            Default::default(),
        )
        .unwrap();

    // Assert the component count explicitly: `Iterator::all` is vacuously
    // true on an empty list, so a `Compose` with no components would
    // otherwise be accepted.
    match global_expressions[solved_negate] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_i32_ty);
            assert_eq!(components.len(), 2);
            for &component in components {
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::I32(-4))
                );
            }
        }
        ref other => panic!("expected `Expression::Compose`, got {other:?}"),
    }
}
4814
/// Adding a two-wide `Splat` of a zero-valued `f32` to a two-wide `Splat` of
/// `5.0` must fold to a `vec2<f32>` `Compose` of exactly two `5.0` literals.
#[test]
fn splat_of_zero_value() {
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );

    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    let five =
        global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
    let five_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: five,
        },
        Default::default(),
    );
    let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
    let zero_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: zero,
        },
        Default::default(),
    );

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_add = solver
        .try_eval_and_append(
            Expression::Binary {
                op: BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    // Assert the component count explicitly: `Iterator::all` is vacuously
    // true on an empty list. Components are compared with `assert_eq!`
    // (via `PartialEq`) rather than a floating-point literal pattern.
    match global_expressions[solved_add] {
        Expression::Compose { ty, ref components } => {
            assert_eq!(ty, vec2_f32_ty);
            assert_eq!(components.len(), 2);
            for &component in components {
                assert_eq!(
                    global_expressions[component],
                    Expression::Literal(Literal::F32(5.0))
                );
            }
        }
        ref other => panic!("expected `Expression::Compose`, got {other:?}"),
    }
}
4895}