1use alloc::{
7 format,
8 string::{String, ToString},
9 vec,
10 vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19 arena::{Arena, Handle, HandleVec, UniqueArena},
20 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21 ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Passes a literal `$` token to `$body`.
///
/// `$body` becomes the arms of a throw-away macro `__with_dollar_sign`, which
/// is then invoked with `$` as its single argument. The arm can bind it (e.g.
/// `($d:tt) => { ... }`) and use `$d` wherever a nested `macro_rules!`
/// definition needs its own `$` metavariables. That is otherwise impossible
/// from inside another macro's transcriber.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
38
/// Generates helpers for evaluating math built-ins component-wise.
///
/// An invocation `gen_component_wise_extractor! { f -> Target, literals: [...],
/// scalar_kinds: [...], }` produces:
///
/// - `enum Target<const N: usize>` with one variant per `Lit => Variant: ty`
///   entry, each holding `[ty; N]` (one value per argument).
/// - `impl From<Target<1>> for Expression`, turning a single result back into
///   an `Expression::Literal`.
/// - `fn f(eval, span, exprs, handler)`. It evaluates `N` argument
///   expressions together and feeds them to `handler`. Scalar literals of one
///   listed variant are passed directly. Vectors (`Compose` of a vector type
///   whose scalar kind is in `scalar_kinds`) are split per component index,
///   and `f` recurses on each index. Any other shape, or arguments of
///   differing variants or vector types, yields
///   `ConstantEvaluatorError::InvalidMathArg`.
/// - A convenience macro `f!(eval, span, [exprs...], |args...| body)` that
///   applies the same `body` to every `Target` variant and wraps the result
///   back in that variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalizes one argument through `eval_zero_value_and_splat`
            // (defined elsewhere; presumably expands `ZeroValue`/`Splat` into
            // literal/compose form) and borrows the resulting expression.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides the shape: a scalar literal or a
            // vector `Compose`. Every remaining argument must match it.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    &Expression::Literal(Literal::$literal(x)) => {
                        // Scalar case: gather all N literals of this variant.
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            // Vector case: collect each argument's components
                            // (via `crate::proc::flatten_compose`). Every
                            // argument must have the same vector type as the
                            // first one.
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            // For each component index, gather that component
                            // from every argument and recurse on the scalars.
                            // A missing component reports `InvalidMathArg`.
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        // The vector's scalar kind is not in `scalar_kinds`.
                        _ => return Err(err),
                    },
                    // `Compose` of a non-vector type.
                    _ => return Err(err),
                },
                // Neither a listed literal nor a `Compose`.
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        // `$d` is a literal `$` supplied by `with_dollar_sign!`, so the nested
        // macro below can declare its own metavariables.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// Generates `enum Scalar<const N: usize>`, `fn component_wise_scalar` and the
// `component_wise_scalar!` macro. Every numeric literal variant is accepted.
// Vectors are accepted when their scalar kind is numeric. `Bool` is not
// listed, so boolean arguments are rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
239
// Generates `enum Float<const N: usize>`, `fn component_wise_float` and the
// `component_wise_float!` macro, accepting only floating-point arguments.
// Note that `Literal::AbstractFloat` maps to the variant `Float::Abstract`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
252
// Generates `enum ConcreteInt<const N: usize>`, `fn component_wise_concrete_int`
// and the `component_wise_concrete_int!` macro. Only 32-bit concrete integers
// are accepted. A `Sint`/`Uint` vector with 64-bit components passes the
// scalar-kind check, but its `I64`/`U64` components match no listed literal,
// so the per-component recursion reports `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
264
// Generates `enum Signed<const N: usize>`, `fn component_wise_signed` and the
// `component_wise_signed!` macro for signed integers and floats. `U32`, `U64`
// and `I64` are not listed, so such components are rejected with
// `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// A vector of literal components that all share one [`Literal`] variant.
///
/// A scalar is represented as a one-component vector (see
/// `LiteralVector::from_literal`). Each variant holds at most
/// `crate::VectorSize::MAX` components. Variant names mirror the `Literal`
/// variants they were built from.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    /// Components of `Literal::AbstractInt`, stored as `i64`.
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    /// Components of `Literal::AbstractFloat`, stored as `f64`.
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
296
297impl LiteralVector {
298 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
299 fn len(&self) -> usize {
300 match *self {
301 LiteralVector::F64(ref v) => v.len(),
302 LiteralVector::F32(ref v) => v.len(),
303 LiteralVector::F16(ref v) => v.len(),
304 LiteralVector::U32(ref v) => v.len(),
305 LiteralVector::I32(ref v) => v.len(),
306 LiteralVector::U64(ref v) => v.len(),
307 LiteralVector::I64(ref v) => v.len(),
308 LiteralVector::Bool(ref v) => v.len(),
309 LiteralVector::AbstractInt(ref v) => v.len(),
310 LiteralVector::AbstractFloat(ref v) => v.len(),
311 }
312 }
313
314 fn from_literal(literal: Literal) -> Self {
316 fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
317 let mut v = ArrayVec::new();
318 v.push(val);
319 v
320 }
321 match literal {
322 Literal::F64(e) => Self::F64(arrayvec_of(e)),
323 Literal::F32(e) => Self::F32(arrayvec_of(e)),
324 Literal::U32(e) => Self::U32(arrayvec_of(e)),
325 Literal::I32(e) => Self::I32(arrayvec_of(e)),
326 Literal::U64(e) => Self::U64(arrayvec_of(e)),
327 Literal::I64(e) => Self::I64(arrayvec_of(e)),
328 Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
329 Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
330 Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
331 Literal::F16(e) => Self::F16(arrayvec_of(e)),
332 }
333 }
334
335 fn from_literal_vec(
340 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
341 ) -> Result<Self, ConstantEvaluatorError> {
342 assert!(!components.is_empty());
343 macro_rules! compose_literals {
345 ($components:expr, $variant:ident, $self_variant:ident) => {{
346 let mut out = ArrayVec::new();
347 for l in &$components {
348 match l {
349 &Literal::$variant(v) => out.push(v),
350 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
351 }
352 }
353 Self::$self_variant(out)
354 }};
355 }
356 Ok(match components[0] {
357 Literal::I32(_) => compose_literals!(components, I32, I32),
358 Literal::U32(_) => compose_literals!(components, U32, U32),
359 Literal::I64(_) => compose_literals!(components, I64, I64),
360 Literal::U64(_) => compose_literals!(components, U64, U64),
361 Literal::F32(_) => compose_literals!(components, F32, F32),
362 Literal::F64(_) => compose_literals!(components, F64, F64),
363 Literal::Bool(_) => compose_literals!(components, Bool, Bool),
364 Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
365 Literal::AbstractFloat(_) => {
366 compose_literals!(components, AbstractFloat, AbstractFloat)
367 }
368 Literal::F16(_) => compose_literals!(components, F16, F16),
369 })
370 }
371
372 #[allow(dead_code)]
373 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
375 macro_rules! decompose_literals {
376 ($v:expr, $variant:ident) => {{
377 let mut out = ArrayVec::new();
378 for e in $v {
379 out.push(Literal::$variant(*e));
380 }
381 out
382 }};
383 }
384 match *self {
385 LiteralVector::F64(ref v) => decompose_literals!(v, F64),
386 LiteralVector::F32(ref v) => decompose_literals!(v, F32),
387 LiteralVector::F16(ref v) => decompose_literals!(v, F16),
388 LiteralVector::U32(ref v) => decompose_literals!(v, U32),
389 LiteralVector::I32(ref v) => decompose_literals!(v, I32),
390 LiteralVector::U64(ref v) => decompose_literals!(v, U64),
391 LiteralVector::I64(ref v) => decompose_literals!(v, I64),
392 LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
393 LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
394 LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
395 }
396 }
397
398 #[allow(dead_code)]
399 fn register_as_evaluated_expr(
401 &self,
402 eval: &mut ConstantEvaluator<'_>,
403 span: Span,
404 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
405 let lit_vec = self.to_literal_vec();
406 assert!(!lit_vec.is_empty());
407 let expr = if lit_vec.len() == 1 {
408 Expression::Literal(lit_vec[0])
409 } else {
410 Expression::Compose {
411 ty: eval.types.insert(
412 Type {
413 name: None,
414 inner: TypeInner::Vector {
415 size: match lit_vec.len() {
416 2 => crate::VectorSize::Bi,
417 3 => crate::VectorSize::Tri,
418 4 => crate::VectorSize::Quad,
419 _ => unreachable!(),
420 },
421 scalar: lit_vec[0].scalar(),
422 },
423 },
424 Span::UNDEFINED,
425 ),
426 components: lit_vec
427 .iter()
428 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
429 .collect::<Result<_, _>>()?,
430 }
431 };
432 eval.register_evaluated_expr(expr, span)
433 }
434}
435
/// Matches one or more [`LiteralVector`]s by variant and maps the result.
///
/// Syntax:
///
/// ```ignore
/// match_literal_vector!(match lit_vec => Out {
///     Float => |v| { body },
///     Integer => |v| -> Ret { body },
///     Bool => |v| { body },
/// })
/// ```
///
/// Each arm names a `LiteralVector` variant, or one of two shorthands:
/// - `Integer` expands to `U32, I32, U64, I64, AbstractInt`.
/// - `Float` expands to `F16, F32, F64, AbstractFloat`.
///
/// The same body is reused for every variant a shorthand expands to. With one
/// variable, `$lit_vec` is matched against `LiteralVector::Ty(ref v)`. With
/// several variables, `$lit_vec` is expected to be a tuple whose elements are
/// all that same variant.
///
/// A matching arm evaluates to `Ok(Out::Ty(body))`, or to `Ok(Out::Ret(body))`
/// when `-> Ret` is given. Any combination not covered evaluates to
/// `Err(ConstantEvaluatorError::InvalidMathArg)`.
macro_rules! match_literal_vector {
    // Entry point: split the arms into a list of variant names and a parallel
    // list of `{ vars ; ret ; body }` groups.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Starts the muncher with an empty accumulator on the left of `<>`.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pops the next variant name and moves it to the front so the
    // `Integer`/`Float`/plain-variant arms below can dispatch on it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` shorthand: push the five integer variants back onto the list
    // and duplicate the current body five times to match them.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` shorthand: same as `Integer`, for the four float variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Plain variant with more variants still pending: move the variant and its
    // body into the accumulator and continue with the next variant.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last plain variant: accumulate it and emit the final `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emits one match arm per accumulated `{ Ty ; vars ; ret ; body }`, plus
    // an `InvalidMathArg` fallback for every other combination.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Without `-> Ret`, the output variant has the same name as the input one.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // With `-> Ret`, the output variant is `Ret`.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
611
612fn float_length<F>(e: &[F]) -> Option<F>
613where
614 F: core::ops::Mul<F> + num_traits::Float + iter::Sum,
615{
616 if e.len() == 1 {
617 Some(e[0].abs())
619 } else {
620 let result = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
621 result.is_finite().then_some(result)
622 }
623}
624
/// Which front end's constant-evaluation rules apply, and under which
/// restrictions.
///
/// Built by the `ConstantEvaluator::for_wgsl_*` / `for_glsl_*` constructors.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules.
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules.
    Glsl(GlslRestrictions<'a>),
}
630
631impl Behavior<'_> {
632 const fn has_runtime_restrictions(&self) -> bool {
634 matches!(
635 self,
636 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
637 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
638 )
639 }
640}
641
/// Evaluates expressions at compile time where possible and appends the
/// results to an expression arena.
///
/// Constructed with `for_wgsl_module`, `for_glsl_module`, `for_wgsl_function`
/// or `for_glsl_function`.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which front end's rules apply, and which expression kinds may be
    /// appended without being evaluated.
    behavior: Behavior<'a>,

    /// The module's type arena. Mutable because evaluation can insert new
    /// types (e.g. `LiteralVector::register_as_evaluated_expr` inserts vector
    /// types).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants. `Expression::Constant` is resolved through
    /// each constant's `init` expression.
    constants: &'a Arena<Constant>,

    /// The module's overrides, exposed to callers through `to_ctx`.
    overrides: &'a Arena<Override>,

    /// The arena results are appended to. This is the module's global
    /// expressions for the module-scope constructors, and a function's own
    /// expressions for the function constructors.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the `ExpressionKind` of each expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Layouter supplied by the caller. It is not used in this part of the
    /// file. NOTE(review): presumably consulted by size/layout-dependent
    /// evaluation elsewhere; confirm.
    layouter: &'a mut crate::proc::Layouter,
}
686
/// What a WGSL evaluator may append without evaluating it.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are accepted.
    ///
    /// Holds function-local data when built by `for_wgsl_function` with
    /// `is_const == true`. Holds `None` when built by `for_wgsl_module`
    /// outside an override context.
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope in an override context (`for_wgsl_module` with
    /// `in_override_ctx == true`). Override-expressions are appended instead
    /// of being rejected.
    Override,
    /// Inside a function body (`for_wgsl_function` with `is_const == false`).
    /// Override- and runtime-expressions are appended instead of being
    /// rejected.
    Runtime(FunctionLocalData<'a>),
}
699
/// What a GLSL evaluator may append without evaluating it.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope (`for_glsl_module`): only const-expressions are accepted.
    Const,
    /// Inside a function body (`for_glsl_function`). Override- and
    /// runtime-expressions are appended instead of being rejected.
    Runtime(FunctionLocalData<'a>),
}
709
/// Extra state carried by evaluators that work on a function's expressions.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers live here
    /// and are copied into the function's arena by
    /// `ConstantEvaluator::check_and_get`.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): not used in this excerpt. Presumably used to emit
    /// appended runtime expressions into `block`; confirm.
    emitter: &'a mut super::Emitter,
    /// The function block under construction (used together with `emitter`).
    block: &'a mut crate::Block,
}
717
/// How "constant" an expression is.
///
/// The variant order is significant. The derived `Ord` gives
/// `Const < Override < Runtime`, and `ExpressionKindTracker` combines operand
/// kinds with `max`. A compound expression therefore takes the least-constant
/// kind among its operands.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Only known at runtime.
    Runtime,
}
724
/// Records the [`ExpressionKind`] of each expression in an arena, indexed by
/// the expression's handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
729
730impl ExpressionKindTracker {
731 pub const fn new() -> Self {
732 Self {
733 inner: HandleVec::new(),
734 }
735 }
736
737 pub fn force_non_const(&mut self, value: Handle<Expression>) {
739 self.inner[value] = ExpressionKind::Runtime;
740 }
741
742 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
743 self.inner.insert(value, expr_type);
744 }
745
746 pub fn is_const(&self, h: Handle<Expression>) -> bool {
747 matches!(self.type_of(h), ExpressionKind::Const)
748 }
749
750 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
751 matches!(
752 self.type_of(h),
753 ExpressionKind::Const | ExpressionKind::Override
754 )
755 }
756
757 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
758 self.inner[value]
759 }
760
761 pub fn from_arena(arena: &Arena<Expression>) -> Self {
762 let mut tracker = Self {
763 inner: HandleVec::with_capacity(arena.len()),
764 };
765 for (handle, expr) in arena.iter() {
766 tracker
767 .inner
768 .insert(handle, tracker.type_of_with_expr(expr));
769 }
770 tracker
771 }
772
773 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
774 match *expr {
775 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
776 ExpressionKind::Const
777 }
778 Expression::Override(_) => ExpressionKind::Override,
779 Expression::Compose { ref components, .. } => {
780 let mut expr_type = ExpressionKind::Const;
781 for component in components {
782 expr_type = expr_type.max(self.type_of(*component))
783 }
784 expr_type
785 }
786 Expression::Splat { value, .. } => self.type_of(value),
787 Expression::AccessIndex { base, .. } => self.type_of(base),
788 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
789 Expression::Swizzle { vector, .. } => self.type_of(vector),
790 Expression::Unary { expr, .. } => self.type_of(expr),
791 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
792 Expression::Math {
793 arg,
794 arg1,
795 arg2,
796 arg3,
797 ..
798 } => self
799 .type_of(arg)
800 .max(
801 arg1.map(|arg| self.type_of(arg))
802 .unwrap_or(ExpressionKind::Const),
803 )
804 .max(
805 arg2.map(|arg| self.type_of(arg))
806 .unwrap_or(ExpressionKind::Const),
807 )
808 .max(
809 arg3.map(|arg| self.type_of(arg))
810 .unwrap_or(ExpressionKind::Const),
811 ),
812 Expression::As { expr, .. } => self.type_of(expr),
813 Expression::Select {
814 condition,
815 accept,
816 reject,
817 } => self
818 .type_of(condition)
819 .max(self.type_of(accept))
820 .max(self.type_of(reject)),
821 Expression::Relational { argument, .. } => self.type_of(argument),
822 Expression::ArrayLength(expr) => self.type_of(expr),
823 _ => ExpressionKind::Runtime,
824 }
825 }
826}
827
828#[derive(Clone, Debug, thiserror::Error)]
829#[cfg_attr(test, derive(PartialEq))]
830pub enum ConstantEvaluatorError {
831 #[error("Constants cannot access function arguments")]
832 FunctionArg,
833 #[error("Constants cannot access global variables")]
834 GlobalVariable,
835 #[error("Constants cannot access local variables")]
836 LocalVariable,
837 #[error("Cannot get the array length of a non array type")]
838 InvalidArrayLengthArg,
839 #[error("Constants cannot get the array length of a dynamically sized array")]
840 ArrayLengthDynamic,
841 #[error("Cannot call arrayLength on array sized by override-expression")]
842 ArrayLengthOverridden,
843 #[error("Constants cannot call functions")]
844 Call,
845 #[error("Constants don't support workGroupUniformLoad")]
846 WorkGroupUniformLoadResult,
847 #[error("Constants don't support atomic functions")]
848 Atomic,
849 #[error("Constants don't support derivative functions")]
850 Derivative,
851 #[error("Constants don't support load expressions")]
852 Load,
853 #[error("Constants don't support image expressions")]
854 ImageExpression,
855 #[error("Constants don't support ray query expressions")]
856 RayQueryExpression,
857 #[error("Constants don't support subgroup expressions")]
858 SubgroupExpression,
859 #[error("Cannot access the type")]
860 InvalidAccessBase,
861 #[error("Cannot access at the index")]
862 InvalidAccessIndex,
863 #[error("Cannot access with index of type")]
864 InvalidAccessIndexTy,
865 #[error("Constants don't support array length expressions")]
866 ArrayLength,
867 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
868 InvalidCastArg { from: String, to: String },
869 #[error("Cannot apply the unary op to the argument")]
870 InvalidUnaryOpArg,
871 #[error("Cannot apply the binary op to the arguments")]
872 InvalidBinaryOpArgs,
873 #[error("Cannot apply math function to type")]
874 InvalidMathArg,
875 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
876 InvalidMathArgCount(crate::MathFunction, usize, usize),
877 #[error("{0} built-in function argument is out of valid range")]
878 InvalidMathArgValue(String),
879 #[error("Cannot apply relational function to type")]
880 InvalidRelationalArg(RelationalFunction),
881 #[error("value of `low` is greater than `high` for clamp built-in function")]
882 InvalidClamp,
883 #[error("Constructor expects {expected} components, found {actual}")]
884 InvalidVectorComposeLength { expected: usize, actual: usize },
885 #[error("Constructor must only contain vector or scalar arguments")]
886 InvalidVectorComposeComponent,
887 #[error("Splat is defined only on scalar values")]
888 SplatScalarOnly,
889 #[error("Can only swizzle vector constants")]
890 SwizzleVectorOnly,
891 #[error("swizzle component not present in source expression")]
892 SwizzleOutOfBounds,
893 #[error("Type is not constructible")]
894 TypeNotConstructible,
895 #[error("Subexpression(s) are not constant")]
896 SubexpressionsAreNotConstant,
897 #[error("Not implemented as constant expression: {0}")]
898 NotImplemented(String),
899 #[error("{0} operation overflowed")]
900 Overflow(String),
901 #[error(
902 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
903 )]
904 AutomaticConversionLossy {
905 value: String,
906 to_type: &'static str,
907 },
908 #[error("Division by zero")]
909 DivisionByZero,
910 #[error("Remainder by zero")]
911 RemainderByZero,
912 #[error("RHS of shift operation is greater than or equal to 32")]
913 ShiftedMoreThan32Bits,
914 #[error(transparent)]
915 Literal(#[from] crate::valid::LiteralError),
916 #[error("Can't use pipeline-overridable constants in const-expressions")]
917 Override,
918 #[error("Unexpected runtime-expression")]
919 RuntimeExpr,
920 #[error("Unexpected override-expression")]
921 OverrideExpr,
922 #[error("Expected boolean expression for condition argument of `select`, got something else")]
923 SelectScalarConditionNotABool,
924 #[error(
925 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
926 reject,
927 accept
928 )]
929 SelectVecRejectAcceptSizeMismatch {
930 reject: crate::VectorSize,
931 accept: crate::VectorSize,
932 },
933 #[error("Expected boolean vector for condition arg., got something else")]
934 SelectConditionNotAVecBool,
935 #[error(
936 "Expected same number of vector components between condition, accept, and reject args., got something else",
937 )]
938 SelectConditionVecSizeMismatch,
939 #[error(
940 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
941 )]
942 SelectAcceptRejectTypeMismatch,
943 #[error("Cooperative operations can't be constant")]
944 CooperativeOperation,
945}
946
947impl<'a> ConstantEvaluator<'a> {
948 pub const fn for_wgsl_module(
953 module: &'a mut crate::Module,
954 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
955 layouter: &'a mut crate::proc::Layouter,
956 in_override_ctx: bool,
957 ) -> Self {
958 Self::for_module(
959 Behavior::Wgsl(if in_override_ctx {
960 WgslRestrictions::Override
961 } else {
962 WgslRestrictions::Const(None)
963 }),
964 module,
965 global_expression_kind_tracker,
966 layouter,
967 )
968 }
969
970 pub const fn for_glsl_module(
975 module: &'a mut crate::Module,
976 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
977 layouter: &'a mut crate::proc::Layouter,
978 ) -> Self {
979 Self::for_module(
980 Behavior::Glsl(GlslRestrictions::Const),
981 module,
982 global_expression_kind_tracker,
983 layouter,
984 )
985 }
986
987 const fn for_module(
988 behavior: Behavior<'a>,
989 module: &'a mut crate::Module,
990 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
991 layouter: &'a mut crate::proc::Layouter,
992 ) -> Self {
993 Self {
994 behavior,
995 types: &mut module.types,
996 constants: &module.constants,
997 overrides: &module.overrides,
998 expressions: &mut module.global_expressions,
999 expression_kind_tracker: global_expression_kind_tracker,
1000 layouter,
1001 }
1002 }
1003
1004 pub const fn for_wgsl_function(
1009 module: &'a mut crate::Module,
1010 expressions: &'a mut Arena<Expression>,
1011 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1012 layouter: &'a mut crate::proc::Layouter,
1013 emitter: &'a mut super::Emitter,
1014 block: &'a mut crate::Block,
1015 is_const: bool,
1016 ) -> Self {
1017 let local_data = FunctionLocalData {
1018 global_expressions: &module.global_expressions,
1019 emitter,
1020 block,
1021 };
1022 Self {
1023 behavior: Behavior::Wgsl(if is_const {
1024 WgslRestrictions::Const(Some(local_data))
1025 } else {
1026 WgslRestrictions::Runtime(local_data)
1027 }),
1028 types: &mut module.types,
1029 constants: &module.constants,
1030 overrides: &module.overrides,
1031 expressions,
1032 expression_kind_tracker: local_expression_kind_tracker,
1033 layouter,
1034 }
1035 }
1036
1037 pub const fn for_glsl_function(
1042 module: &'a mut crate::Module,
1043 expressions: &'a mut Arena<Expression>,
1044 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1045 layouter: &'a mut crate::proc::Layouter,
1046 emitter: &'a mut super::Emitter,
1047 block: &'a mut crate::Block,
1048 ) -> Self {
1049 Self {
1050 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1051 global_expressions: &module.global_expressions,
1052 emitter,
1053 block,
1054 })),
1055 types: &mut module.types,
1056 constants: &module.constants,
1057 overrides: &module.overrides,
1058 expressions,
1059 expression_kind_tracker: local_expression_kind_tracker,
1060 layouter,
1061 }
1062 }
1063
1064 pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1065 crate::proc::GlobalCtx {
1066 types: self.types,
1067 constants: self.constants,
1068 overrides: self.overrides,
1069 global_expressions: match self.function_local_data() {
1070 Some(data) => data.global_expressions,
1071 None => self.expressions,
1072 },
1073 }
1074 }
1075
1076 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1077 if !self.expression_kind_tracker.is_const(expr) {
1078 log::debug!("check: SubexpressionsAreNotConstant");
1079 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1080 }
1081 Ok(())
1082 }
1083
1084 fn check_and_get(
1085 &mut self,
1086 expr: Handle<Expression>,
1087 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1088 match self.expressions[expr] {
1089 Expression::Constant(c) => {
1090 if let Some(function_local_data) = self.function_local_data() {
1093 self.copy_from(
1095 self.constants[c].init,
1096 function_local_data.global_expressions,
1097 )
1098 } else {
1099 Ok(self.constants[c].init)
1101 }
1102 }
1103 _ => {
1104 self.check(expr)?;
1105 Ok(expr)
1106 }
1107 }
1108 }
1109
1110 pub fn try_eval_and_append(
1134 &mut self,
1135 expr: Expression,
1136 span: Span,
1137 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1138 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1139 ExpressionKind::Const => {
1140 let eval_result = self.try_eval_and_append_impl(&expr, span);
1141 if self.behavior.has_runtime_restrictions()
1146 && matches!(
1147 eval_result,
1148 Err(ConstantEvaluatorError::NotImplemented(_)
1149 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1150 )
1151 {
1152 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1153 } else {
1154 eval_result
1155 }
1156 }
1157 ExpressionKind::Override => match self.behavior {
1158 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1159 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1160 }
1161 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1162 Err(ConstantEvaluatorError::OverrideExpr)
1163 }
1164
1165 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1167 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1168 }
1169 Behavior::Glsl(GlslRestrictions::Const) => {
1170 Err(ConstantEvaluatorError::OverrideExpr)
1171 }
1172 },
1173 ExpressionKind::Runtime => {
1174 if self.behavior.has_runtime_restrictions() {
1175 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1176 } else {
1177 Err(ConstantEvaluatorError::RuntimeExpr)
1178 }
1179 }
1180 }
1181 }
1182
1183 const fn is_global_arena(&self) -> bool {
1185 matches!(
1186 self.behavior,
1187 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1188 | Behavior::Glsl(GlslRestrictions::Const)
1189 )
1190 }
1191
1192 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1193 match self.behavior {
1194 Behavior::Wgsl(
1195 WgslRestrictions::Runtime(ref function_local_data)
1196 | WgslRestrictions::Const(Some(ref function_local_data)),
1197 )
1198 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1199 Some(function_local_data)
1200 }
1201 _ => None,
1202 }
1203 }
1204
    /// Constant-folds the single expression `expr` and returns the handle of the result.
    ///
    /// Every operand handle is first passed through `check_and_get` (defined elsewhere
    /// in this file). The folding is then delegated to the operation-specific helper
    /// (`access`, `swizzle`, `unary_op`, `binary_op`, `math`, `cast`, `select`,
    /// `relational`, `array_length`). Results are recorded through
    /// `register_evaluated_expr`.
    ///
    /// Expression kinds that can never be constant return a dedicated
    /// [`ConstantEvaluatorError`] variant. These include loads, locals, derivatives,
    /// calls, atomics, function arguments, globals, image/ray-query/subgroup operations
    /// and cooperative operations.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            // In the global arena, refer directly to the constant's initializer
            // expression instead of registering a new `Constant` expression.
            Expression::Constant(c) if self.is_global_arena() => {
                Ok(self.constants[c].init)
            }
            // Overrides have no value until pipeline creation.
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // These are already fully evaluated; just register them as-is.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // A dynamic index is only foldable when it reduces to a constant `u32`.
                let index_val: u32 = self
                    .to_ctx()
                    .get_const_val_from(index, self.expressions)
                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
                self.access(base, index_val as usize, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: None` means a bitcast, which is not folded yet.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            // WGSL rejects `arrayLength` outright here; GLSL folds it when the array
            // size is a known constant.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below depends on runtime state and can never be constant.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
                Err(ConstantEvaluatorError::CooperativeOperation)
            }
        }
    }
1346
1347 fn splat(
1360 &mut self,
1361 value: Handle<Expression>,
1362 size: crate::VectorSize,
1363 span: Span,
1364 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1365 match self.expressions[value] {
1366 Expression::Literal(literal) => {
1367 let scalar = literal.scalar();
1368 let ty = self.types.insert(
1369 Type {
1370 name: None,
1371 inner: TypeInner::Vector { size, scalar },
1372 },
1373 span,
1374 );
1375 let expr = Expression::Compose {
1376 ty,
1377 components: vec![value; size as usize],
1378 };
1379 self.register_evaluated_expr(expr, span)
1380 }
1381 Expression::ZeroValue(ty) => {
1382 let inner = match self.types[ty].inner {
1383 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1384 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1385 };
1386 let res_ty = self.types.insert(Type { name: None, inner }, span);
1387 let expr = Expression::ZeroValue(res_ty);
1388 self.register_evaluated_expr(expr, span)
1389 }
1390 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1391 }
1392 }
1393
1394 fn swizzle(
1395 &mut self,
1396 size: crate::VectorSize,
1397 span: Span,
1398 src_constant: Handle<Expression>,
1399 pattern: [crate::SwizzleComponent; 4],
1400 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1401 let mut get_dst_ty = |ty| match self.types[ty].inner {
1402 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1403 Type {
1404 name: None,
1405 inner: TypeInner::Vector { size, scalar },
1406 },
1407 span,
1408 )),
1409 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1410 };
1411
1412 match self.expressions[src_constant] {
1413 Expression::ZeroValue(ty) => {
1414 let dst_ty = get_dst_ty(ty)?;
1415 let expr = Expression::ZeroValue(dst_ty);
1416 self.register_evaluated_expr(expr, span)
1417 }
1418 Expression::Splat { value, .. } => {
1419 let expr = Expression::Splat { size, value };
1420 self.register_evaluated_expr(expr, span)
1421 }
1422 Expression::Compose { ty, ref components } => {
1423 let dst_ty = get_dst_ty(ty)?;
1424
1425 let mut flattened = [src_constant; 4]; let len =
1427 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1428 .zip(flattened.iter_mut())
1429 .map(|(component, elt)| *elt = component)
1430 .count();
1431 let flattened = &flattened[..len];
1432
1433 let swizzled_components = pattern[..size as usize]
1434 .iter()
1435 .map(|&sc| {
1436 let sc = sc as usize;
1437 if let Some(elt) = flattened.get(sc) {
1438 Ok(*elt)
1439 } else {
1440 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1441 }
1442 })
1443 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1444 let expr = Expression::Compose {
1445 ty: dst_ty,
1446 components: swizzled_components,
1447 };
1448 self.register_evaluated_expr(expr, span)
1449 }
1450 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1451 }
1452 }
1453
    /// Constant-folds a call to the math built-in `fun`.
    ///
    /// `arg` is always present; `arg1`..`arg3` are the optional extra operands. The
    /// number of supplied operands is checked against `fun.argument_count()` before
    /// anything else. That check is what makes the later `arg1.unwrap()` /
    /// `arg2.unwrap()` calls safe.
    ///
    /// On success, returns the handle of the registered result expression. It returns
    /// an error when:
    ///
    /// - an operand is outside the function's domain (`InvalidMathArgValue`);
    /// - the result overflows (`Overflow`);
    /// - the operands have unsupported types (`InvalidMathArg`);
    /// - the built-in is not yet folded (`NotImplemented`).
    fn math(
        &mut self,
        arg: Handle<Expression>,
        arg1: Option<Handle<Expression>>,
        arg2: Option<Handle<Expression>>,
        arg3: Option<Handle<Expression>>,
        fun: crate::MathFunction,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // Count the operands that are actually present and compare with the arity.
        let expected = fun.argument_count();
        let given = Some(arg)
            .into_iter()
            .chain(arg1)
            .chain(arg2)
            .chain(arg3)
            .count();
        if expected != given {
            return Err(ConstantEvaluatorError::InvalidMathArgCount(
                fun, expected, given,
            ));
        }

        match fun {
            // comparison
            crate::MathFunction::Abs => {
                // Signed integers use `wrapping_abs`, so `abs(MIN)` yields `MIN` instead
                // of panicking; unsigned values are returned unchanged.
                component_wise_scalar(self, span, [arg], |args| match args {
                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
                    Scalar::U32([e]) => Ok(Scalar::U32([e])),
                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
                })
            }
            crate::MathFunction::Min => {
                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    Ok([e1.min(e2)])
                })
            }
            crate::MathFunction::Max => {
                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    Ok([e1.max(e2)])
                })
            }
            crate::MathFunction::Clamp => {
                // An inverted range (`low > high`) is reported as an error rather than
                // folded.
                component_wise_scalar!(
                    self,
                    span,
                    [arg, arg1.unwrap(), arg2.unwrap()],
                    |e, low, high| {
                        if low > high {
                            Err(ConstantEvaluatorError::InvalidClamp)
                        } else {
                            Ok([e.clamp(low, high)])
                        }
                    }
                )
            }
            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
                Float::F16([e]) => Ok(Float::F16(
                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
                )),
                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
            }),

            // trigonometry
            crate::MathFunction::Cos => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
            }
            crate::MathFunction::Cosh => {
                // Non-finite results are reported as overflow.
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.cosh();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
                    }
                })
            }
            crate::MathFunction::Sin => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
            }
            crate::MathFunction::Sinh => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.sinh();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
                    }
                })
            }
            crate::MathFunction::Tan => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
            }
            crate::MathFunction::Tanh => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
            }
            crate::MathFunction::Acos => {
                // Domain is [-1, 1].
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() <= One::one() {
                        Ok([e.acos()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
                    }
                })
            }
            crate::MathFunction::Asin => {
                // Domain is [-1, 1].
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() <= One::one() {
                        Ok([e.asin()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
                    }
                })
            }
            crate::MathFunction::Atan => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
            }
            crate::MathFunction::Atan2 => {
                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
                    Ok([y.atan2(x)])
                })
            }
            // Abstract floats go through `libm`; `f32` is computed in `f64` and
            // narrowed back.
            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
            }),
            // Domain is [1, +inf); anything else (including NaN) falls to the error arm.
            crate::MathFunction::Acosh => component_wise_float(self, span, [arg], |e| match e {
                Float::Abstract([e]) if e >= One::one() => Ok(Float::Abstract([libm::acosh(e)])),
                Float::F32([e]) if e >= One::one() => Ok(Float::F32([(e as f64).acosh() as f32])),
                Float::F16([e]) if e >= One::one() => Ok(Float::F16([e.acosh()])),
                _ => Err(ConstantEvaluatorError::InvalidMathArgValue("acosh".into())),
            }),
            crate::MathFunction::Atanh => {
                // Domain is the open interval (-1, 1).
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() < One::one() {
                        Ok([e.atanh()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
                    }
                })
            }
            crate::MathFunction::Radians => {
                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
            }
            crate::MathFunction::Degrees => {
                // Scaling up by 180/pi can overflow for large inputs.
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.to_degrees();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
                    }
                })
            }

            // decomposition
            crate::MathFunction::Ceil => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
            }
            crate::MathFunction::Floor => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
            }
            crate::MathFunction::Round => {
                // `libm::rint` / `rintf` round to the nearest integer, breaking ties
                // to even (in the default rounding mode).
                component_wise_float(self, span, [arg], |e| match e {
                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
                    Float::F16([e]) => {
                        // Rounds `x` to the nearest integer, resolving exact `.5` ties
                        // toward the even neighbour; the f16 value is widened to f64
                        // and rounded there.
                        fn round_ties_even(x: f64) -> f64 {
                            // `i` is `x` truncated toward zero, `f` the magnitude of
                            // its fractional part.
                            let i = x as i64;
                            let f = (x - i as f64).abs();
                            if f == 0.5 {
                                // Odd truncation: the even neighbour is one step
                                // further from zero.
                                if i & 1 == 1 {
                                    (x.abs() + 0.5).copysign(x)
                                } else {
                                    (x.abs() - 0.5).copysign(x)
                                }
                            } else {
                                x.round()
                            }
                        }

                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
                    }
                })
            }
            crate::MathFunction::Fract => {
                // Deliberately `e - floor(e)`, not Rust's `fract()` (which is
                // `e - trunc(e)` and differs for negative inputs).
                component_wise_float!(self, span, [arg], |e| {
                    Ok([e - e.floor()])
                })
            }
            crate::MathFunction::Trunc => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
            }

            // exponent
            crate::MathFunction::Exp => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.exp();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("exp".into()))
                    }
                })
            }
            crate::MathFunction::Exp2 => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.exp2();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
                    }
                })
            }
            crate::MathFunction::Log => {
                // Logarithms are only defined for strictly positive inputs.
                component_wise_float!(self, span, [arg], |e| {
                    if e > Zero::zero() {
                        Ok([e.ln()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
                    }
                })
            }
            crate::MathFunction::Log2 => {
                component_wise_float!(self, span, [arg], |e| {
                    if e > Zero::zero() {
                        Ok([e.log2()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
                    }
                })
            }
            crate::MathFunction::Pow => {
                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    // Reject inputs outside pow's domain: a negative base, or the
                    // indeterminate forms 1^inf, inf^0 and 0^0.
                    if e1 < Zero::zero()
                        || e1.is_one() && e2.is_infinite()
                        || e1.is_infinite() && e2.is_zero()
                        || e1.is_zero() && e2.is_zero()
                    {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("pow".into()))
                    } else {
                        let result = e1.powf(e2);
                        if result.is_finite() {
                            Ok([result])
                        } else {
                            Err(ConstantEvaluatorError::Overflow("pow".into()))
                        }
                    }
                })
            }

            // computational
            crate::MathFunction::Sign => {
                // Zero is special-cased because float `signum` returns 1.0 for +0.0
                // and -1.0 for -0.0.
                component_wise_signed!(self, span, [arg], |e| {
                    Ok([if e.is_zero() {
                        Zero::zero()
                    } else {
                        e.signum()
                    }])
                })
            }
            crate::MathFunction::Fma => {
                component_wise_float!(
                    self,
                    span,
                    [arg, arg1.unwrap(), arg2.unwrap()],
                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
                )
            }
            crate::MathFunction::Step => {
                // 1 when `edge <= x`, else 0.
                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
                    Float::Abstract([edge, x]) => {
                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
                    }
                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
                        f16::one()
                    } else {
                        f16::zero()
                    }])),
                })
            }
            crate::MathFunction::Sqrt => {
                component_wise_float!(self, span, [arg], |e| {
                    if e >= Zero::zero() {
                        Ok([e.sqrt()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
                    }
                })
            }
            crate::MathFunction::InverseSqrt => {
                // The f16 case is computed in f32 and converted back.
                component_wise_float(self, span, [arg], |e| match e {
                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
                })
            }

            // bits
            crate::MathFunction::CountTrailingZeros => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    // `trailing_zeros()` returns `u32`; the conversion into the
                    // closure's result type is a no-op for some component types, hence
                    // the clippy allowance.
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .trailing_zeros()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::CountLeadingZeros => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .leading_zeros()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::CountOneBits => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .count_ones()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::ReverseBits => {
                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
            }
            crate::MathFunction::FirstTrailingBit => {
                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
            }
            crate::MathFunction::FirstLeadingBit => {
                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
            }

            // vector
            crate::MathFunction::Dot4I8Packed => {
                self.packed_dot_product(arg, arg1.unwrap(), span, true)
            }
            crate::MathFunction::Dot4U8Packed => {
                self.packed_dot_product(arg, arg1.unwrap(), span, false)
            }
            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
            crate::MathFunction::Dot => {
                // Scalars are not accepted here (`allow_single == false`).
                let e1 = self.extract_vec(arg, false)?;
                let e2 = self.extract_vec(arg1.unwrap(), false)?;
                if e1.len() != e2.len() {
                    return Err(ConstantEvaluatorError::InvalidMathArg);
                }

                // Float dot product; a non-finite sum is reported as overflow.
                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
                where
                    P: num_traits::Float,
                {
                    let result = a
                        .iter()
                        .zip(b.iter())
                        .map(|(&aa, &bb)| aa * bb)
                        .fold(P::zero(), |acc, x| acc + x);
                    if result.is_finite() {
                        Ok(result)
                    } else {
                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
                    }
                }

                // Integer dot product where any overflow in a product or partial sum is
                // an error (used for AbstractInt).
                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
                where
                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
                {
                    a.iter()
                        .zip(b.iter())
                        .map(|(&aa, bb)| aa.checked_mul(bb))
                        .try_fold(P::zero(), |acc, x| {
                            if let Some(x) = x {
                                acc.checked_add(&x)
                            } else {
                                None
                            }
                        })
                        .ok_or(ConstantEvaluatorError::Overflow(
                            "in dot built-in".to_string(),
                        ))
                }

                // Integer dot product with two's-complement wrapping (used for the
                // concrete 32-bit integer types).
                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
                where
                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
                {
                    a.iter()
                        .zip(b.iter())
                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
                }

                let result = match_literal_vector!(match (e1, e2) => Literal {
                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
                    AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Length => {
                // Scalars are accepted here (`allow_single == true`).
                let e1 = self.extract_vec(arg, true)?;

                let result = match_literal_vector!(match e1 => Literal {
                    Float => |e1| {
                        float_length(e1).ok_or_else(|| ConstantEvaluatorError::Overflow("length".into()))?
                    },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Distance => {
                let e1 = self.extract_vec(arg, true)?;
                let e2 = self.extract_vec(arg1.unwrap(), true)?;
                if e1.len() != e2.len() {
                    return Err(ConstantEvaluatorError::InvalidMathArg);
                }

                // Euclidean distance; the one-component case is just `|a - b|`.
                fn float_distance<F>(a: &[F], b: &[F]) -> F
                where
                    F: core::ops::Mul<F>,
                    F: num_traits::Float + iter::Sum + core::ops::Sub,
                {
                    if a.len() == 1 {
                        (a[0] - b[0]).abs()
                    } else {
                        a.iter()
                            .zip(b.iter())
                            .map(|(&aa, &bb)| aa - bb)
                            .map(|ei| ei * ei)
                            .sum::<F>()
                            .sqrt()
                    }
                }
                let result = match_literal_vector!(match (e1, e2) => Literal {
                    Float => |e1, e2| { float_distance(e1, e2) },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Normalize => {
                let e1 = self.extract_vec(arg, true)?;

                // Divides every component by the vector's length. A zero length is an
                // invalid argument; a length that could not be computed (`None`) is
                // reported as overflow.
                fn float_normalize<F>(
                    e: &[F],
                ) -> Result<ArrayVec<F, { crate::VectorSize::MAX }>, ConstantEvaluatorError>
                where
                    F: core::ops::Mul<F>,
                    F: num_traits::Float + iter::Sum,
                {
                    let len = match float_length(e) {
                        Some(len) if !len.is_zero() => Ok(len),
                        Some(_) => Err(ConstantEvaluatorError::InvalidMathArgValue(
                            "normalize".into(),
                        )),
                        None => Err(ConstantEvaluatorError::Overflow("normalize".into())),
                    }?;

                    let mut out = ArrayVec::new();
                    for &ei in e {
                        out.push(ei / len);
                    }
                    Ok(out)
                }

                let result = match_literal_vector!(match e1 => LiteralVector {
                    Float => |e1| { float_normalize(e1)? },
                })?;
                result.register_as_evaluated_expr(self, span)
            }

            // Not yet supported by the constant evaluator.
            crate::MathFunction::Modf
            | crate::MathFunction::Frexp
            | crate::MathFunction::Ldexp
            | crate::MathFunction::Outer
            | crate::MathFunction::FaceForward
            | crate::MathFunction::Reflect
            | crate::MathFunction::Refract
            | crate::MathFunction::Mix
            | crate::MathFunction::SmoothStep
            | crate::MathFunction::Inverse
            | crate::MathFunction::Transpose
            | crate::MathFunction::Determinant
            | crate::MathFunction::QuantizeToF16
            | crate::MathFunction::ExtractBits
            | crate::MathFunction::InsertBits
            | crate::MathFunction::Pack4x8snorm
            | crate::MathFunction::Pack4x8unorm
            | crate::MathFunction::Pack2x16snorm
            | crate::MathFunction::Pack2x16unorm
            | crate::MathFunction::Pack2x16float
            | crate::MathFunction::Pack4xI8
            | crate::MathFunction::Pack4xU8
            | crate::MathFunction::Pack4xI8Clamp
            | crate::MathFunction::Pack4xU8Clamp
            | crate::MathFunction::Unpack4x8snorm
            | crate::MathFunction::Unpack4x8unorm
            | crate::MathFunction::Unpack2x16snorm
            | crate::MathFunction::Unpack2x16unorm
            | crate::MathFunction::Unpack2x16float
            | crate::MathFunction::Unpack4xI8
            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
                format!("{fun:?} built-in function"),
            )),
        }
    }
1989
1990 fn packed_dot_product(
1992 &mut self,
1993 a: Handle<Expression>,
1994 b: Handle<Expression>,
1995 span: Span,
1996 signed: bool,
1997 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1998 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1999 return Err(ConstantEvaluatorError::InvalidMathArg);
2000 };
2001 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
2002 return Err(ConstantEvaluatorError::InvalidMathArg);
2003 };
2004
2005 let result = if signed {
2006 Literal::I32(
2007 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
2008 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
2009 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
2010 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
2011 )
2012 } else {
2013 Literal::U32(
2014 (a & 0xFF) * (b & 0xFF)
2015 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
2016 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
2017 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
2018 )
2019 };
2020
2021 self.register_evaluated_expr(Expression::Literal(result), span)
2022 }
2023
2024 fn cross_product(
2026 &mut self,
2027 a: Handle<Expression>,
2028 b: Handle<Expression>,
2029 span: Span,
2030 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2031 use Literal as Li;
2032
2033 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2034 let (b, _) = self.extract_vec_with_size::<3>(b)?;
2035
2036 let product = match (a, b) {
2037 (
2038 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2039 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2040 ) => {
2041 let p = cross_product(
2046 [a0 as f64, a1 as f64, a2 as f64],
2047 [b0 as f64, b1 as f64, b2 as f64],
2048 );
2049 [
2050 Li::AbstractFloat(p[0]),
2051 Li::AbstractFloat(p[1]),
2052 Li::AbstractFloat(p[2]),
2053 ]
2054 }
2055 (
2056 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2057 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2058 ) => {
2059 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2060 [
2061 Li::AbstractFloat(p[0]),
2062 Li::AbstractFloat(p[1]),
2063 Li::AbstractFloat(p[2]),
2064 ]
2065 }
2066 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2067 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2068 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2069 }
2070 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2071 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2072 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2073 }
2074 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2075 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2076 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2077 }
2078 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2079 };
2080
2081 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2082 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2083 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2084
2085 self.register_evaluated_expr(
2086 Expression::Compose {
2087 ty,
2088 components: vec![p0, p1, p2],
2089 },
2090 span,
2091 )
2092 }
2093
2094 fn extract_vec_with_size<const N: usize>(
2102 &mut self,
2103 expr: Handle<Expression>,
2104 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2105 let span = self.expressions.get_span(expr);
2106 let expr = self.eval_zero_value_and_splat(expr, span)?;
2107 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2108 return Err(ConstantEvaluatorError::InvalidMathArg);
2109 };
2110
2111 let mut value = [Literal::Bool(false); N];
2112 for (component, elt) in
2113 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2114 .zip(value.iter_mut())
2115 {
2116 let Expression::Literal(literal) = self.expressions[component] else {
2117 return Err(ConstantEvaluatorError::InvalidMathArg);
2118 };
2119 *elt = literal;
2120 }
2121
2122 Ok((value, ty))
2123 }
2124
2125 fn extract_vec(
2133 &mut self,
2134 expr: Handle<Expression>,
2135 allow_single: bool,
2136 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2137 let span = self.expressions.get_span(expr);
2138 let expr = self.eval_zero_value_and_splat(expr, span)?;
2139
2140 match self.expressions[expr] {
2141 Expression::Literal(literal) if allow_single => {
2142 Ok(LiteralVector::from_literal(literal))
2143 }
2144 Expression::Compose { ty, ref components } => {
2145 let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2146 for expr in
2147 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2148 {
2149 match self.expressions[expr] {
2150 Expression::Literal(l) => components_out.push(l),
2151 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2152 }
2153 }
2154 LiteralVector::from_literal_vec(components_out)
2155 }
2156 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2157 }
2158 }
2159
2160 fn array_length(
2161 &mut self,
2162 array: Handle<Expression>,
2163 span: Span,
2164 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2165 match self.expressions[array] {
2166 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2167 match self.types[ty].inner {
2168 TypeInner::Array { size, .. } => match size {
2169 ArraySize::Constant(len) => {
2170 let expr = Expression::Literal(Literal::U32(len.get()));
2171 self.register_evaluated_expr(expr, span)
2172 }
2173 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2174 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2175 },
2176 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2177 }
2178 }
2179 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2180 }
2181 }
2182
2183 fn access(
2184 &mut self,
2185 base: Handle<Expression>,
2186 index: usize,
2187 span: Span,
2188 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2189 match self.expressions[base] {
2190 Expression::ZeroValue(ty) => {
2191 let ty_inner = &self.types[ty].inner;
2192 let components = ty_inner
2193 .components()
2194 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2195
2196 if index >= components as usize {
2197 Err(ConstantEvaluatorError::InvalidAccessBase)
2198 } else {
2199 let ty_res = ty_inner
2200 .component_type(index)
2201 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2202 let ty = match ty_res {
2203 crate::proc::TypeResolution::Handle(ty) => ty,
2204 crate::proc::TypeResolution::Value(inner) => {
2205 self.types.insert(Type { name: None, inner }, span)
2206 }
2207 };
2208 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2209 }
2210 }
2211 Expression::Splat { size, value } => {
2212 if index >= size as usize {
2213 Err(ConstantEvaluatorError::InvalidAccessBase)
2214 } else {
2215 Ok(value)
2216 }
2217 }
2218 Expression::Compose { ty, ref components } => {
2219 let _ = self.types[ty]
2220 .inner
2221 .components()
2222 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2223
2224 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2225 .nth(index)
2226 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2227 }
2228 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2229 }
2230 }
2231
2232 fn eval_zero_value_and_splat(
2239 &mut self,
2240 mut expr: Handle<Expression>,
2241 span: Span,
2242 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2243 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2246 let components = components
2247 .clone()
2248 .iter()
2249 .map(|component| self.eval_zero_value_and_splat(*component, span))
2250 .collect::<Result<_, _>>()?;
2251 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2252 }
2253
2254 if let Expression::Splat { size, value } = self.expressions[expr] {
2258 expr = self.splat(value, size, span)?;
2259 }
2260 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2261 expr = self.eval_zero_value_impl(ty, span)?;
2262 }
2263 Ok(expr)
2264 }
2265
2266 fn eval_zero_value(
2272 &mut self,
2273 expr: Handle<Expression>,
2274 span: Span,
2275 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2276 match self.expressions[expr] {
2277 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2278 _ => Ok(expr),
2279 }
2280 }
2281
2282 fn eval_zero_value_impl(
2288 &mut self,
2289 ty: Handle<Type>,
2290 span: Span,
2291 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2292 match self.types[ty].inner {
2293 TypeInner::Scalar(scalar) => {
2294 let expr = Expression::Literal(
2295 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2296 );
2297 self.register_evaluated_expr(expr, span)
2298 }
2299 TypeInner::Vector { size, scalar } => {
2300 let scalar_ty = self.types.insert(
2301 Type {
2302 name: None,
2303 inner: TypeInner::Scalar(scalar),
2304 },
2305 span,
2306 );
2307 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2308 let expr = Expression::Compose {
2309 ty,
2310 components: vec![el; size as usize],
2311 };
2312 self.register_evaluated_expr(expr, span)
2313 }
2314 TypeInner::Matrix {
2315 columns,
2316 rows,
2317 scalar,
2318 } => {
2319 let vec_ty = self.types.insert(
2320 Type {
2321 name: None,
2322 inner: TypeInner::Vector { size: rows, scalar },
2323 },
2324 span,
2325 );
2326 let el = self.eval_zero_value_impl(vec_ty, span)?;
2327 let expr = Expression::Compose {
2328 ty,
2329 components: vec![el; columns as usize],
2330 };
2331 self.register_evaluated_expr(expr, span)
2332 }
2333 TypeInner::Array {
2334 base,
2335 size: ArraySize::Constant(size),
2336 ..
2337 } => {
2338 let el = self.eval_zero_value_impl(base, span)?;
2339 let expr = Expression::Compose {
2340 ty,
2341 components: vec![el; size.get() as usize],
2342 };
2343 self.register_evaluated_expr(expr, span)
2344 }
2345 TypeInner::Struct { ref members, .. } => {
2346 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2347 let mut components = Vec::with_capacity(members.len());
2348 for ty in types {
2349 components.push(self.eval_zero_value_impl(ty, span)?);
2350 }
2351 let expr = Expression::Compose { ty, components };
2352 self.register_evaluated_expr(expr, span)
2353 }
2354 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2355 }
2356 }
2357
/// Convert the constant value `expr` to the scalar type `target`.
///
/// A `ZeroValue` operand is first lowered via `eval_zero_value`. Then:
///
/// - a literal is converted directly (see the per-target tables below);
/// - a vector or matrix `Compose` is converted component by component, and
///   the result gets a new vector/matrix type of the same shape whose scalar
///   is `target`;
/// - a `Splat` has its scalar operand converted.
///
/// Arrays are not handled here (a `Compose` of array type is rejected); use
/// `cast_array` for those.
///
/// `f32`/`f64` to integer conversions clamp to the target range before the
/// `as` cast. `u32`/`i32` (and `u64`/`i64`) conversions use plain `as` casts.
/// Conversions from `AbstractInt`/`AbstractFloat` go through
/// `try_from_abstract`, whose errors are propagated.
///
/// # Errors
///
/// [`ConstantEvaluatorError::InvalidCastArg`] for an unsupported
/// source/target combination or an expression that is not a literal,
/// vector/matrix `Compose`, or `Splat`; errors from `try_from_abstract` and
/// from registering the result are propagated.
pub fn cast(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    use crate::Scalar as Sc;

    let expr = self.eval_zero_value(expr, span)?;

    // Builds the `InvalidCastArg` error describing the source expression and
    // the target type (WGSL spelling when the WGSL frontend is enabled).
    let make_error = || -> Result<_, ConstantEvaluatorError> {
        let from = format!("{:?} {:?}", expr, self.expressions[expr]);

        #[cfg(feature = "wgsl-in")]
        let to = target.to_wgsl_for_diagnostics();

        #[cfg(not(feature = "wgsl-in"))]
        let to = format!("{target:?}");

        Err(ConstantEvaluatorError::InvalidCastArg { from, to })
    };

    use crate::proc::type_methods::IntFloatLimits;

    let expr = match self.expressions[expr] {
        Expression::Literal(literal) => {
            let literal = match target {
                Sc::I32 => Literal::I32(match literal {
                    Literal::I32(v) => v,
                    Literal::U32(v) => v as i32,
                    Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                    // Finite f16 values (at most ±65504) always fit in i32, so no
                    // clamping is needed; `to_i32` only returns `None` for NaN/inf.
                    // NOTE(review): relies on literals being finite, which
                    // `register_evaluated_expr`'s `check_literal_value` presumably ensures.
                    Literal::F16(v) => f16::to_i32(&v).unwrap(),
                    Literal::Bool(v) => v as i32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                }),
                Sc::U32 => Literal::U32(match literal {
                    Literal::I32(v) => v as u32,
                    Literal::U32(v) => v,
                    Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                    // `max(ZERO)` maps negative values to 0 so that `to_u32` does
                    // not return `None` for them; finite f16 values fit in u32.
                    Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                    Literal::Bool(v) => v as u32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                }),
                Sc::I64 => Literal::I64(match literal {
                    Literal::I32(v) => v as i64,
                    Literal::U32(v) => v as i64,
                    Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::Bool(v) => v as i64,
                    Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::I64(v) => v,
                    Literal::U64(v) => v as i64,
                    // As for i32: finite f16 values always fit, so no clamping.
                    Literal::F16(v) => f16::to_i64(&v).unwrap(),
                    Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                }),
                Sc::U64 => Literal::U64(match literal {
                    Literal::I32(v) => v as u64,
                    Literal::U32(v) => v as u64,
                    Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::Bool(v) => v as u64,
                    Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::I64(v) => v as u64,
                    Literal::U64(v) => v,
                    // As for u32: `max(ZERO)` keeps negative values from yielding `None`.
                    Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                    Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                }),
                // NOTE(review): the `from_*` conversions below are unwrapped on the
                // assumption that `half`'s `FromPrimitive` for `f16` never returns
                // `None` (out-of-range values become infinite, which registering the
                // literal presumably rejects) — confirm against the `half` crate.
                Sc::F16 => Literal::F16(match literal {
                    Literal::F16(v) => v,
                    Literal::F32(v) => f16::from_f32(v),
                    Literal::F64(v) => f16::from_f64(v),
                    Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                    Literal::I64(v) => f16::from_i64(v).unwrap(),
                    Literal::U64(v) => f16::from_u64(v).unwrap(),
                    Literal::I32(v) => f16::from_i32(v).unwrap(),
                    Literal::U32(v) => f16::from_u32(v).unwrap(),
                    Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                    Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                }),
                Sc::F32 => Literal::F32(match literal {
                    Literal::I32(v) => v as f32,
                    Literal::U32(v) => v as f32,
                    Literal::F32(v) => v,
                    Literal::Bool(v) => v as u32 as f32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::F16(v) => f16::to_f32(v),
                    Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                }),
                Sc::F64 => Literal::F64(match literal {
                    Literal::I32(v) => v as f64,
                    Literal::U32(v) => v as f64,
                    Literal::F16(v) => f16::to_f64(v),
                    Literal::F32(v) => v as f64,
                    Literal::F64(v) => v,
                    Literal::Bool(v) => v as u32 as f64,
                    Literal::I64(_) | Literal::U64(_) => return make_error(),
                    Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                }),
                Sc::BOOL => Literal::Bool(match literal {
                    Literal::I32(v) => v != 0,
                    Literal::U32(v) => v != 0,
                    Literal::F32(v) => v != 0.0,
                    Literal::F16(v) => v != f16::zero(),
                    Literal::Bool(v) => v,
                    Literal::AbstractInt(v) => v != 0,
                    Literal::AbstractFloat(v) => v != 0.0,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                }),
                Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                    Literal::AbstractInt(v) => {
                        // Overflow is impossible here (f64's range far exceeds
                        // i64's); inexact conversions of large values are accepted.
                        v as f64
                    }
                    Literal::AbstractFloat(v) => v,
                    _ => return make_error(),
                }),
                Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                    Literal::AbstractInt(v) => v,
                    _ => return make_error(),
                }),
                _ => {
                    log::debug!("Constant evaluator refused to convert value to {target:?}");
                    return make_error();
                }
            };
            Expression::Literal(literal)
        }
        Expression::Compose {
            ty,
            components: ref src_components,
        } => {
            // Same shape as the source, with the scalar replaced by `target`.
            let ty_inner = match self.types[ty].inner {
                TypeInner::Vector { size, .. } => TypeInner::Vector {
                    size,
                    scalar: target,
                },
                TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: target,
                },
                _ => return make_error(),
            };

            // Clone the handle list so `self.cast` can borrow `self` mutably.
            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.cast(*component, target, span)?;
            }

            let ty = self.types.insert(
                Type {
                    name: None,
                    inner: ty_inner,
                },
                span,
            );

            Expression::Compose { ty, components }
        }
        Expression::Splat { size, value } => {
            let value_span = self.expressions.get_span(value);
            let cast_value = self.cast(value, target, value_span)?;
            Expression::Splat {
                size,
                value: cast_value,
            }
        }
        _ => return make_error(),
    };

    self.register_evaluated_expr(expr, span)
}
2552
2553 pub fn cast_array(
2566 &mut self,
2567 expr: Handle<Expression>,
2568 target: crate::Scalar,
2569 span: Span,
2570 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2571 let expr = self.check_and_get(expr)?;
2572
2573 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2574 return self.cast(expr, target, span);
2575 };
2576
2577 let TypeInner::Array {
2578 base: _,
2579 size,
2580 stride: _,
2581 } = self.types[ty].inner
2582 else {
2583 return self.cast(expr, target, span);
2584 };
2585
2586 let mut components = components.clone();
2587 for component in &mut components {
2588 *component = self.cast_array(*component, target, span)?;
2589 }
2590
2591 let first = components.first().unwrap();
2592 let new_base = match self.resolve_type(*first)? {
2593 crate::proc::TypeResolution::Handle(ty) => ty,
2594 crate::proc::TypeResolution::Value(inner) => {
2595 self.types.insert(Type { name: None, inner }, span)
2596 }
2597 };
2598 let mut layouter = core::mem::take(self.layouter);
2599 layouter.update(self.to_ctx()).unwrap();
2600 *self.layouter = layouter;
2601
2602 let new_base_stride = self.layouter[new_base].to_stride();
2603 let new_array_ty = self.types.insert(
2604 Type {
2605 name: None,
2606 inner: TypeInner::Array {
2607 base: new_base,
2608 size,
2609 stride: new_base_stride,
2610 },
2611 },
2612 span,
2613 );
2614
2615 let compose = Expression::Compose {
2616 ty: new_array_ty,
2617 components,
2618 };
2619 self.register_evaluated_expr(compose, span)
2620 }
2621
2622 fn unary_op(
2623 &mut self,
2624 op: UnaryOperator,
2625 expr: Handle<Expression>,
2626 span: Span,
2627 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2628 let expr = self.eval_zero_value_and_splat(expr, span)?;
2629
2630 let expr = match self.expressions[expr] {
2631 Expression::Literal(value) => Expression::Literal(match op {
2632 UnaryOperator::Negate => match value {
2633 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2634 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2635 Literal::F32(v) => Literal::F32(-v),
2636 Literal::F16(v) => Literal::F16(-v),
2637 Literal::F64(v) => Literal::F64(-v),
2638 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2639 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2640 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2641 },
2642 UnaryOperator::LogicalNot => match value {
2643 Literal::Bool(v) => Literal::Bool(!v),
2644 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2645 },
2646 UnaryOperator::BitwiseNot => match value {
2647 Literal::I32(v) => Literal::I32(!v),
2648 Literal::I64(v) => Literal::I64(!v),
2649 Literal::U32(v) => Literal::U32(!v),
2650 Literal::U64(v) => Literal::U64(!v),
2651 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2652 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2653 },
2654 }),
2655 Expression::Compose {
2656 ty,
2657 components: ref src_components,
2658 } => {
2659 match self.types[ty].inner {
2660 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2661 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2662 }
2663
2664 let mut components = src_components.clone();
2665 for component in &mut components {
2666 *component = self.unary_op(op, *component, span)?;
2667 }
2668
2669 Expression::Compose { ty, components }
2670 }
2671 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2672 };
2673
2674 self.register_evaluated_expr(expr, span)
2675 }
2676
2677 fn binary_op(
2678 &mut self,
2679 op: BinaryOperator,
2680 left: Handle<Expression>,
2681 right: Handle<Expression>,
2682 span: Span,
2683 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2684 let left = self.eval_zero_value_and_splat(left, span)?;
2685 let right = self.eval_zero_value_and_splat(right, span)?;
2686
2687 let expr = match (&self.expressions[left], &self.expressions[right]) {
2692 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2693 if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2694 && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2695 {
2696 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2697 }
2698
2699 if matches!(
2700 (left_value, op),
2701 (
2702 Literal::Bool(_),
2703 BinaryOperator::Less
2704 | BinaryOperator::LessEqual
2705 | BinaryOperator::Greater
2706 | BinaryOperator::GreaterEqual
2707 )
2708 ) {
2709 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2710 }
2711
2712 let literal = match op {
2713 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2714 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2715 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2716 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2717 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2718 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2719
2720 _ => match (left_value, right_value) {
2721 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2722 BinaryOperator::Add => a.wrapping_add(b),
2723 BinaryOperator::Subtract => a.wrapping_sub(b),
2724 BinaryOperator::Multiply => a.wrapping_mul(b),
2725 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2726 if b == 0 {
2727 ConstantEvaluatorError::DivisionByZero
2728 } else {
2729 ConstantEvaluatorError::Overflow("division".into())
2730 }
2731 })?,
2732 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2733 if b == 0 {
2734 ConstantEvaluatorError::RemainderByZero
2735 } else {
2736 ConstantEvaluatorError::Overflow("remainder".into())
2737 }
2738 })?,
2739 BinaryOperator::And => a & b,
2740 BinaryOperator::ExclusiveOr => a ^ b,
2741 BinaryOperator::InclusiveOr => a | b,
2742 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2743 }),
2744 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2745 BinaryOperator::ShiftLeft => {
2746 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2747 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2748 }
2749 a.checked_shl(b)
2750 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2751 }
2752 BinaryOperator::ShiftRight => a
2753 .checked_shr(b)
2754 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2755 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2756 }),
2757 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2758 BinaryOperator::Add => a.wrapping_add(b),
2759 BinaryOperator::Subtract => a.wrapping_sub(b),
2760 BinaryOperator::Multiply => a.wrapping_mul(b),
2761 BinaryOperator::Divide => a
2762 .checked_div(b)
2763 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2764 BinaryOperator::Modulo => a
2765 .checked_rem(b)
2766 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2767 BinaryOperator::And => a & b,
2768 BinaryOperator::ExclusiveOr => a ^ b,
2769 BinaryOperator::InclusiveOr => a | b,
2770 BinaryOperator::ShiftLeft => a
2771 .checked_mul(
2772 1u32.checked_shl(b)
2773 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2774 )
2775 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2776 BinaryOperator::ShiftRight => a
2777 .checked_shr(b)
2778 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2779 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2780 }),
2781 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2782 BinaryOperator::Add => a + b,
2783 BinaryOperator::Subtract => a - b,
2784 BinaryOperator::Multiply => a * b,
2785 BinaryOperator::Divide => a / b,
2786 BinaryOperator::Modulo => a % b,
2787 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2788 }),
2789 (Literal::AbstractInt(a), Literal::U32(b)) => {
2790 Literal::AbstractInt(match op {
2791 BinaryOperator::ShiftLeft => {
2792 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2793 return Err(ConstantEvaluatorError::Overflow(
2794 "<<".to_string(),
2795 ));
2796 }
2797 a.checked_shl(b).unwrap_or(0)
2798 }
2799 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2800 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2801 })
2802 }
2803 (Literal::F16(a), Literal::F16(b)) => {
2804 let result = match op {
2805 BinaryOperator::Add => a + b,
2806 BinaryOperator::Subtract => a - b,
2807 BinaryOperator::Multiply => a * b,
2808 BinaryOperator::Divide => a / b,
2809 BinaryOperator::Modulo => a % b,
2810 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2811 };
2812 if !result.is_finite() {
2813 return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2814 }
2815 Literal::F16(result)
2816 }
2817 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2818 Literal::AbstractInt(match op {
2819 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2820 ConstantEvaluatorError::Overflow("addition".into())
2821 })?,
2822 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2823 ConstantEvaluatorError::Overflow("subtraction".into())
2824 })?,
2825 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2826 ConstantEvaluatorError::Overflow("multiplication".into())
2827 })?,
2828 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2829 if b == 0 {
2830 ConstantEvaluatorError::DivisionByZero
2831 } else {
2832 ConstantEvaluatorError::Overflow("division".into())
2833 }
2834 })?,
2835 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2836 if b == 0 {
2837 ConstantEvaluatorError::RemainderByZero
2838 } else {
2839 ConstantEvaluatorError::Overflow("remainder".into())
2840 }
2841 })?,
2842 BinaryOperator::And => a & b,
2843 BinaryOperator::ExclusiveOr => a ^ b,
2844 BinaryOperator::InclusiveOr => a | b,
2845 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2846 })
2847 }
2848 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2849 let result = match op {
2850 BinaryOperator::Add => a + b,
2851 BinaryOperator::Subtract => a - b,
2852 BinaryOperator::Multiply => a * b,
2853 BinaryOperator::Divide => a / b,
2854 BinaryOperator::Modulo => a % b,
2855 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2856 };
2857 if !result.is_finite() {
2858 return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2859 }
2860 Literal::AbstractFloat(result)
2861 }
2862 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2863 BinaryOperator::LogicalAnd => a && b,
2864 BinaryOperator::LogicalOr => a || b,
2865 BinaryOperator::And => a & b,
2866 BinaryOperator::InclusiveOr => a | b,
2867 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2868 }),
2869 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2870 },
2871 };
2872 Expression::Literal(literal)
2873 }
2874 (
2875 &Expression::Compose {
2876 components: ref src_components,
2877 ty,
2878 },
2879 &Expression::Literal(_),
2880 ) => {
2881 if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2882 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2883 }
2884 let mut components = src_components.clone();
2885 for component in &mut components {
2886 *component = self.binary_op(op, *component, right, span)?;
2887 }
2888 Expression::Compose { ty, components }
2889 }
2890 (
2891 &Expression::Literal(_),
2892 &Expression::Compose {
2893 components: ref src_components,
2894 ty,
2895 },
2896 ) => {
2897 if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2898 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2899 }
2900 let mut components = src_components.clone();
2901 for component in &mut components {
2902 *component = self.binary_op(op, left, *component, span)?;
2903 }
2904 Expression::Compose { ty, components }
2905 }
2906 (
2907 &Expression::Compose {
2908 components: ref left_components,
2909 ty: left_ty,
2910 },
2911 &Expression::Compose {
2912 components: ref right_components,
2913 ty: right_ty,
2914 },
2915 ) => {
2916 let left_flattened = crate::proc::flatten_compose(
2920 left_ty,
2921 left_components,
2922 self.expressions,
2923 self.types,
2924 )
2925 .collect::<Vec<_>>();
2926 let right_flattened = crate::proc::flatten_compose(
2927 right_ty,
2928 right_components,
2929 self.expressions,
2930 self.types,
2931 )
2932 .collect::<Vec<_>>();
2933
2934 self.binary_op_compose(
2935 op,
2936 &left_flattened,
2937 &right_flattened,
2938 left_ty,
2939 right_ty,
2940 span,
2941 )?
2942 }
2943 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2944 };
2945
2946 return self.register_evaluated_expr(expr, span);
2947
2948 fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
2949 let is_numeric_vec = matches!(
2950 compose_ty, TypeInner::Vector { scalar, .. }
2951 if scalar.kind != ScalarKind::Bool
2952 );
2953 let is_allowed_vec_scalar_op = matches!(
2954 op,
2955 BinaryOperator::Add
2956 | BinaryOperator::Subtract
2957 | BinaryOperator::Multiply
2958 | BinaryOperator::Divide
2959 | BinaryOperator::Modulo
2960 );
2961 let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
2962 let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
2963 is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
2964 }
2965 }
2966
2967 fn binary_op_compose(
2968 &mut self,
2969 op: BinaryOperator,
2970 left_components: &[Handle<Expression>],
2971 right_components: &[Handle<Expression>],
2972 left_ty: Handle<Type>,
2973 right_ty: Handle<Type>,
2974 span: Span,
2975 ) -> Result<Expression, ConstantEvaluatorError> {
2976 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2977 (
2979 &TypeInner::Vector {
2980 size: left_size, ..
2981 },
2982 &TypeInner::Vector {
2983 size: right_size, ..
2984 },
2985 ) if left_size == right_size => self.binary_op_vector(
2986 op,
2987 left_size,
2988 left_components,
2989 right_components,
2990 left_ty,
2991 span,
2992 ),
2993 (
2995 &TypeInner::Vector { size, .. },
2996 &TypeInner::Matrix {
2997 columns,
2998 rows,
2999 scalar,
3000 },
3001 ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
3002 left_components,
3003 right_components,
3004 columns,
3005 scalar,
3006 span,
3007 ),
3008 (
3010 &TypeInner::Matrix {
3011 columns,
3012 rows,
3013 scalar,
3014 },
3015 &TypeInner::Vector { size, .. },
3016 ) if op == BinaryOperator::Multiply && size == columns => {
3017 self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
3018 }
3019 (
3021 &TypeInner::Matrix {
3022 columns: left_columns,
3023 rows: left_rows,
3024 scalar,
3025 },
3026 &TypeInner::Matrix {
3027 columns: right_columns,
3028 rows: right_rows,
3029 ..
3030 },
3031 ) => match op {
3032 BinaryOperator::Add | BinaryOperator::Subtract
3033 if left_columns == right_columns && left_rows == right_rows =>
3034 {
3035 let components = left_components
3036 .iter()
3037 .zip(right_components)
3038 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3039 .collect::<Result<Vec<_>, _>>()?;
3040 Ok(Expression::Compose {
3041 ty: left_ty,
3042 components,
3043 })
3044 }
3045 BinaryOperator::Multiply if left_columns == right_rows => self
3046 .multiply_matrix_matrix(
3047 left_components,
3048 right_components,
3049 left_rows,
3050 right_columns,
3051 scalar,
3052 span,
3053 ),
3054 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3055 },
3056 _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3057 }
3058 }
3059
3060 fn binary_op_vector(
3061 &mut self,
3062 op: BinaryOperator,
3063 size: crate::VectorSize,
3064 left_components: &[Handle<Expression>],
3065 right_components: &[Handle<Expression>],
3066 left_ty: Handle<Type>,
3067 span: Span,
3068 ) -> Result<Expression, ConstantEvaluatorError> {
3069 let ty = match op {
3070 BinaryOperator::Equal
3072 | BinaryOperator::NotEqual
3073 | BinaryOperator::Less
3074 | BinaryOperator::LessEqual
3075 | BinaryOperator::Greater
3076 | BinaryOperator::GreaterEqual => self.types.insert(
3077 Type {
3078 name: None,
3079 inner: TypeInner::Vector {
3080 size,
3081 scalar: crate::Scalar::BOOL,
3082 },
3083 },
3084 span,
3085 ),
3086
3087 BinaryOperator::Add
3090 | BinaryOperator::Subtract
3091 | BinaryOperator::Multiply
3092 | BinaryOperator::Divide
3093 | BinaryOperator::Modulo
3094 | BinaryOperator::And
3095 | BinaryOperator::ExclusiveOr
3096 | BinaryOperator::InclusiveOr
3097 | BinaryOperator::ShiftLeft
3098 | BinaryOperator::ShiftRight => left_ty,
3099
3100 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3101 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3103 }
3104 };
3105
3106 let components = left_components
3107 .iter()
3108 .zip(right_components)
3109 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3110 .collect::<Result<Vec<_>, _>>()?;
3111
3112 Ok(Expression::Compose { ty, components })
3113 }
3114
3115 fn multiply_vector_matrix(
3116 &mut self,
3117 vec_components: &[Handle<Expression>],
3118 mat_components: &[Handle<Expression>],
3119 mat_columns: crate::VectorSize,
3120 scalar: crate::Scalar,
3121 span: Span,
3122 ) -> Result<Expression, ConstantEvaluatorError> {
3123 let ty = self.types.insert(
3124 Type {
3125 name: None,
3126 inner: TypeInner::Vector {
3127 size: mat_columns,
3128 scalar,
3129 },
3130 },
3131 span,
3132 );
3133 let components = mat_components
3134 .iter()
3135 .map(|&column| {
3136 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3137 unreachable!()
3138 };
3139 self.dot_exprs(
3140 vec_components.iter().cloned(),
3141 components.clone().into_iter(),
3142 span,
3143 )
3144 })
3145 .collect::<Result<Vec<_>, _>>()?;
3146 Ok(Expression::Compose { ty, components })
3147 }
3148
3149 fn multiply_matrix_vector(
3150 &mut self,
3151 mat_components: &[Handle<Expression>],
3152 vec_components: &[Handle<Expression>],
3153 mat_rows: crate::VectorSize,
3154 scalar: crate::Scalar,
3155 span: Span,
3156 ) -> Result<Expression, ConstantEvaluatorError> {
3157 let ty = self.types.insert(
3158 Type {
3159 name: None,
3160 inner: TypeInner::Vector {
3161 size: mat_rows,
3162 scalar,
3163 },
3164 },
3165 span,
3166 );
3167
3168 let flatten = self.flatten_matrix(mat_components);
3169 let nr = mat_rows as usize;
3170 let components = (0..nr)
3171 .map(|r| {
3172 let row = flatten.iter().skip(r).step_by(nr).cloned();
3173 self.dot_exprs(row, vec_components.iter().cloned(), span)
3174 })
3175 .collect::<Result<Vec<_>, _>>()?;
3176 Ok(Expression::Compose { ty, components })
3177 }
3178
3179 fn multiply_matrix_matrix(
3180 &mut self,
3181 left_components: &[Handle<Expression>],
3182 right_components: &[Handle<Expression>],
3183 left_rows: crate::VectorSize,
3184 right_columns: crate::VectorSize,
3185 scalar: crate::Scalar,
3186 span: Span,
3187 ) -> Result<Expression, ConstantEvaluatorError> {
3188 let left_nc = left_components.len();
3189 let left_nr = left_rows as usize;
3190 let right_nc = right_columns as usize;
3191 let right_nr = left_nc;
3192
3193 let mut result = Vec::with_capacity(right_nc);
3194 let result_ty = self.types.insert(
3195 Type {
3196 name: None,
3197 inner: TypeInner::Matrix {
3198 columns: right_columns,
3199 rows: left_rows,
3200 scalar,
3201 },
3202 },
3203 span,
3204 );
3205 let result_column_ty = self.types.insert(
3206 Type {
3207 name: None,
3208 inner: TypeInner::Vector {
3209 size: left_rows,
3210 scalar,
3211 },
3212 },
3213 span,
3214 );
3215
3216 let left_flattened = self.flatten_matrix(left_components);
3217 let right_flattened = self.flatten_matrix(right_components);
3218 for c in 0..right_nc {
3219 let result_column = (0..left_nr)
3220 .map(|r| {
3221 let row = left_flattened.iter().skip(r).step_by(left_nr);
3222 let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3223 self.dot_exprs(row.cloned(), column.cloned(), span)
3224 })
3225 .collect::<Result<Vec<_>, _>>()?;
3226 let expr = Expression::Compose {
3227 ty: result_column_ty,
3228 components: result_column,
3229 };
3230 let handle = self.register_evaluated_expr(expr, span)?;
3231 result.push(handle);
3232 }
3233 Ok(Expression::Compose {
3234 ty: result_ty,
3235 components: result,
3236 })
3237 }
3238
3239 fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3240 let mut flattened = ArrayVec::<_, 16>::new();
3241 for &column in columns {
3242 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3243 unreachable!()
3244 };
3245 flattened.extend(components.iter().cloned());
3246 }
3247 flattened
3248 }
3249
3250 fn dot_exprs(
3251 &mut self,
3252 left: impl Iterator<Item = Handle<Expression>>,
3253 right: impl Iterator<Item = Handle<Expression>>,
3254 span: Span,
3255 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3256 let mut acc = None;
3257 for (l, r) in left.zip(right) {
3258 let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3259 match acc.as_mut() {
3260 Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3261 None => acc = Some(result),
3262 }
3263 }
3264 Ok(acc.unwrap())
3265 }
3266
3267 fn relational(
3268 &mut self,
3269 fun: RelationalFunction,
3270 arg: Handle<Expression>,
3271 span: Span,
3272 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3273 let arg = self.eval_zero_value_and_splat(arg, span)?;
3274 match fun {
3275 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3276 Expression::Literal(Literal::Bool(_)) => Ok(arg),
3277 Expression::Compose { ty, ref components }
3278 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3279 {
3280 let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3281 for component in
3282 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3283 {
3284 match self.expressions[component] {
3285 Expression::Literal(Literal::Bool(val)) => {
3286 bool_components.push(val);
3287 }
3288 _ => {
3289 return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3290 }
3291 }
3292 }
3293 let components = bool_components;
3294 let result = match fun {
3295 RelationalFunction::All => components.iter().all(|c| *c),
3296 RelationalFunction::Any => components.iter().any(|c| *c),
3297 _ => unreachable!(),
3298 };
3299 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3300 }
3301 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3302 },
3303 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3304 "{fun:?} built-in function"
3305 ))),
3306 }
3307 }
3308
3309 fn copy_from(
3317 &mut self,
3318 expr: Handle<Expression>,
3319 expressions: &Arena<Expression>,
3320 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3321 let span = expressions.get_span(expr);
3322 match expressions[expr] {
3323 ref expr @ (Expression::Literal(_)
3324 | Expression::Constant(_)
3325 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3326 Expression::Compose { ty, ref components } => {
3327 let mut components = components.clone();
3328 for component in &mut components {
3329 *component = self.copy_from(*component, expressions)?;
3330 }
3331 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3332 }
3333 Expression::Splat { size, value } => {
3334 let value = self.copy_from(value, expressions)?;
3335 self.register_evaluated_expr(Expression::Splat { size, value }, span)
3336 }
3337 _ => {
3338 log::debug!("copy_from: SubexpressionsAreNotConstant");
3339 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3340 }
3341 }
3342 }
3343
3344 fn vector_compose_flattened_size(
3346 &self,
3347 components: &[Handle<Expression>],
3348 ) -> Result<usize, ConstantEvaluatorError> {
3349 components
3350 .iter()
3351 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3352 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3353 TypeInner::Scalar(_) => 1,
3354 TypeInner::Vector { size, .. } => size as usize,
3358 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3359 };
3360 Ok(acc + size)
3361 })
3362 }
3363
3364 fn register_evaluated_expr(
3365 &mut self,
3366 expr: Expression,
3367 span: Span,
3368 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3369 if let Expression::Literal(literal) = expr {
3374 crate::valid::check_literal_value(literal)?;
3375 }
3376
3377 if let Expression::Compose { ty, ref components } = expr {
3381 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3382 let expected = size as usize;
3383 let actual = self.vector_compose_flattened_size(components)?;
3384 if expected != actual {
3385 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3386 expected,
3387 actual,
3388 });
3389 }
3390 }
3391 }
3392
3393 Ok(self.append_expr(expr, span, ExpressionKind::Const))
3394 }
3395
    /// Append `expr` to `self.expressions` and record `expr_type` for the new handle
    /// in `self.expression_kind_tracker`.
    ///
    /// Function-local data is present in three cases: WGSL runtime, WGSL const with
    /// `Some` function-local data, and GLSL runtime. In those cases, if the emitter is
    /// running and `expr.needs_pre_emit()` is true, the pending emit range is first
    /// finished into the function's block, and the emitter is restarted after the
    /// append. Presumably this keeps the new expression out of any `Emit` range.
    /// TODO: confirm against `Emitter` semantics.
    /// In every other case the expression is appended directly.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range (its result is pushed into the
                    // block), append, then restart the emitter after the new handle.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
3426
3427 fn resolve_type(
3432 &self,
3433 expr: Handle<Expression>,
3434 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3435 use crate::proc::TypeResolution as Tr;
3436 use crate::Expression as Ex;
3437 let resolution = match self.expressions[expr] {
3438 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3439 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3440 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3441 Ex::Splat { size, value } => {
3442 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3443 return Err(ConstantEvaluatorError::SplatScalarOnly);
3444 };
3445 Tr::Value(TypeInner::Vector { scalar, size })
3446 }
3447 _ => {
3448 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3449 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3450 }
3451 };
3452
3453 Ok(resolution)
3454 }
3455
3456 fn select(
3457 &mut self,
3458 reject: Handle<Expression>,
3459 accept: Handle<Expression>,
3460 condition: Handle<Expression>,
3461 span: Span,
3462 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3463 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3464
3465 let reject = arg(reject)?;
3466 let accept = arg(accept)?;
3467 let condition = arg(condition)?;
3468
3469 let select_single_component =
3470 |this: &mut Self, reject_scalar, reject, accept, condition| {
3471 let accept = this.cast(accept, reject_scalar, span)?;
3472 if condition {
3473 Ok(accept)
3474 } else {
3475 Ok(reject)
3476 }
3477 };
3478
3479 match (&self.expressions[reject], &self.expressions[accept]) {
3480 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3481 let reject_scalar = reject_lit.scalar();
3482 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3483 else {
3484 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3485 };
3486 select_single_component(self, reject_scalar, reject, accept, condition)
3487 }
3488 (
3489 &Expression::Compose {
3490 ty: reject_ty,
3491 components: ref reject_components,
3492 },
3493 &Expression::Compose {
3494 ty: accept_ty,
3495 components: ref accept_components,
3496 },
3497 ) => {
3498 let ty_deets = |ty| {
3499 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3500 (size.unwrap(), scalar)
3501 };
3502
3503 let expected_vec_size = {
3504 let [(reject_vec_size, _), (accept_vec_size, _)] =
3505 [reject_ty, accept_ty].map(ty_deets);
3506
3507 if reject_vec_size != accept_vec_size {
3508 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3509 reject: reject_vec_size,
3510 accept: accept_vec_size,
3511 });
3512 }
3513 reject_vec_size
3514 };
3515
3516 let condition_components = match self.expressions[condition] {
3517 Expression::Literal(Literal::Bool(condition)) => {
3518 vec![condition; (expected_vec_size as u8).into()]
3519 }
3520 Expression::Compose {
3521 ty: condition_ty,
3522 components: ref condition_components,
3523 } => {
3524 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3525 if condition_scalar.kind != ScalarKind::Bool {
3526 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3527 }
3528 if condition_vec_size != expected_vec_size {
3529 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3530 }
3531 condition_components
3532 .iter()
3533 .copied()
3534 .map(|component| match &self.expressions[component] {
3535 &Expression::Literal(Literal::Bool(condition)) => condition,
3536 _ => unreachable!(),
3537 })
3538 .collect()
3539 }
3540
3541 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3542 };
3543
3544 let evaluated = Expression::Compose {
3545 ty: reject_ty,
3546 components: reject_components
3547 .clone()
3548 .into_iter()
3549 .zip(accept_components.clone().into_iter())
3550 .zip(condition_components.into_iter())
3551 .map(|((reject, accept), condition)| {
3552 let reject_scalar = match &self.expressions[reject] {
3553 &Expression::Literal(lit) => lit.scalar(),
3554 _ => unreachable!(),
3555 };
3556 select_single_component(self, reject_scalar, reject, accept, condition)
3557 })
3558 .collect::<Result<_, _>>()?,
3559 };
3560 self.register_evaluated_expr(evaluated, span)
3561 }
3562 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3563 }
3564 }
3565}
3566
3567fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3568 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3572 match e {
3573 idx @ 0..=31 => idx,
3574 32 => u32::MAX,
3575 _ => unreachable!(),
3576 }
3577 };
3578 match concrete_int {
3579 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3580 ConcreteInt::I32([e]) => {
3581 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3582 }
3583 }
3584}
3585
/// Table-driven checks of `first_trailing_bit` for edge values and every power of two.
#[test]
fn first_trailing_bit_smoke() {
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3646
3647fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3648 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3652 match e {
3653 idx @ 0..=31 => 31 - idx,
3654 32 => u32::MAX,
3655 _ => unreachable!(),
3656 }
3657 };
3658 match concrete_int {
3659 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3660 let rtl_bit_index = if e.is_negative() {
3661 e.leading_ones()
3662 } else {
3663 e.leading_zeros()
3664 };
3665 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3666 }]),
3667 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3668 }
3669}
3670
/// Table-driven checks of `first_leading_bit` for edge values and (negated) powers of two.
#[test]
fn first_leading_bit_smoke() {
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the only set bit is the answer.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // `-(1 << idx)` has ones from bit `idx` upward, so the highest zero is `idx - 1`.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3734
/// Fallible conversion from an abstract value into a concrete scalar type.
///
/// `T` is the Rust type holding the abstract value. The impls in this file convert
/// from `i64` and `f64` (presumably WGSL's AbstractInt and AbstractFloat).
trait TryFromAbstract<T>: Sized {
    /// Convert `value` to `Self`.
    ///
    /// The impls here either report
    /// `ConstantEvaluatorError::AutomaticConversionLossy` when the value cannot be
    /// represented, or always succeed.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3760
3761impl TryFromAbstract<i64> for i32 {
3762 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3763 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3764 value: format!("{value:?}"),
3765 to_type: "i32",
3766 })
3767 }
3768}
3769
3770impl TryFromAbstract<i64> for u32 {
3771 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3772 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3773 value: format!("{value:?}"),
3774 to_type: "u32",
3775 })
3776 }
3777}
3778
3779impl TryFromAbstract<i64> for u64 {
3780 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3781 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3782 value: format!("{value:?}"),
3783 to_type: "u64",
3784 })
3785 }
3786}
3787
3788impl TryFromAbstract<i64> for i64 {
3789 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3790 Ok(value)
3791 }
3792}
3793
3794impl TryFromAbstract<i64> for f32 {
3795 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3796 let f = value as f32;
3797 Ok(f)
3801 }
3802}
3803
3804impl TryFromAbstract<f64> for f32 {
3805 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3806 let f = value as f32;
3807 if f.is_infinite() {
3808 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3809 value: format!("{value:?}"),
3810 to_type: "f32",
3811 });
3812 }
3813 Ok(f)
3814 }
3815}
3816
3817impl TryFromAbstract<i64> for f64 {
3818 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3819 let f = value as f64;
3820 Ok(f)
3824 }
3825}
3826
3827impl TryFromAbstract<f64> for f64 {
3828 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3829 Ok(value)
3830 }
3831}
3832
3833impl TryFromAbstract<f64> for i32 {
3834 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3835 Ok(value as i32)
3848 }
3849}
3850
3851impl TryFromAbstract<f64> for u32 {
3852 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3853 Ok(value as u32)
3856 }
3857}
3858
3859impl TryFromAbstract<f64> for i64 {
3860 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3861 use crate::proc::type_methods::IntFloatLimits;
3864 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3865 }
3866}
3867
3868impl TryFromAbstract<f64> for u64 {
3869 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3870 use crate::proc::type_methods::IntFloatLimits;
3873 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3874 }
3875}
3876
3877impl TryFromAbstract<f64> for f16 {
3878 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3879 let f = f16::from_f64(value);
3880 if f.is_infinite() {
3881 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3882 value: format!("{value:?}"),
3883 to_type: "f16",
3884 });
3885 }
3886 Ok(f)
3887 }
3888}
3889
3890impl TryFromAbstract<i64> for f16 {
3891 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3892 let f = f16::from_i64(value);
3893 if f.is_none() {
3894 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3895 value: format!("{value:?}"),
3896 to_type: "f16",
3897 });
3898 }
3899 Ok(f.unwrap())
3900 }
3901}
3902
/// Compute the cross product `a × b` of two 3-component vectors.
///
/// Component `i` of the result is `a[j] * b[k] - a[k] * b[j]`, where `(i, j, k)`
/// runs over the cyclic permutations of `(0, 1, 2)`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
3915
3916#[cfg(test)]
3917mod tests {
3918 use alloc::{vec, vec::Vec};
3919
3920 use crate::{
3921 Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
3922 Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
3923 };
3924
3925 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3926
    /// Checks that unary operators fold to literals. `-c` and `!c` are evaluated on the
    /// `i32` constant `c = 4`, and `!v` on the `vec2<i32>` constant `v` built from the
    /// initializers of constants `4` and `8`.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        // The vector constant reuses the scalar constants' initializer expressions.
        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        // The vector result must be a `Compose` of exactly two folded literals.
        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }
4059
4060 #[test]
4061 fn matrix_op() {
4062 let mut helper = MatrixTestHelper::new();
4063
4064 for nc in 2..=4 {
4065 for nr in 2..=4 {
4066 let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4069 let expected = (0..nc)
4070 .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4071 .collect::<Vec<f32>>();
4072 assert_eq!(evaluated, expected);
4073
4074 let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4077 let expected = (0..nr)
4078 .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4079 .collect::<Vec<f32>>();
4080 assert_eq!(evaluated, expected);
4081
4082 for k in 2..=4 {
4083 let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4086 let expected = (0..nc)
4087 .flat_map(|c| {
4088 (0..nr).map(move |r| {
4089 (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4090 })
4091 })
4092 .collect::<Vec<f32>>();
4093 assert_eq!(evaluated, expected);
4094 }
4095 }
4096 }
4097 }
4098
    /// Fixture for the matrix multiplication tests. An expression arena is
    /// pre-populated with the `f32` literals `0.0` through `15.0`, plus `f32`
    /// vectors and matrices of every size from 2 to 4 built from them.
    struct MatrixTestHelper {
        types: UniqueArena<Type>,
        expressions: Arena<Expression>,
        /// `vecN<f32>` composes keyed by `N`; lane `i` is the literal `i`.
        vec_exprs: FastHashMap<usize, Handle<Expression>>,
        /// `matCxR<f32>` composes keyed by `(C, R)`; column `c`, row `r` is the literal `c * R + r`.
        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
    }
4110
    impl MatrixTestHelper {
        /// Build the fixture: vector and matrix types for sizes 2..=4, the literals
        /// `0.0..=15.0`, one vector per size, and one matrix per (columns, rows) pair.
        fn new() -> Self {
            let mut types = UniqueArena::new();
            let mut expressions = Arena::new();
            let span = crate::Span::default();

            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
            for c in 2..=4 {
                let vec_ty = types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Vector {
                            size: Self::int_to_vector_size(c),
                            scalar: crate::Scalar::F32,
                        },
                    },
                    span,
                );
                vec_tys.insert(c, vec_ty);
                for r in 2..=4 {
                    let mat_ty = types.insert(
                        Type {
                            name: None,
                            inner: TypeInner::Matrix {
                                columns: Self::int_to_vector_size(c),
                                rows: Self::int_to_vector_size(r),
                                scalar: crate::Scalar::F32,
                            },
                        },
                        span,
                    );
                    mat_tys.insert((c, r), mat_ty);
                }
            }

            // 16 literals cover the largest matrix (4x4 uses indices 0..=15).
            let mut lit_exprs = FastHashMap::default();
            for i in 0..16 {
                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
                lit_exprs.insert(i, expr);
            }

            let mut vec_exprs = FastHashMap::default();
            for c in 2..=4 {
                let expr = expressions.append(
                    Expression::Compose {
                        ty: *vec_tys.get(&c).unwrap(),
                        components: (0..c)
                            .map(|i| *lit_exprs.get(&i).unwrap())
                            .collect::<Vec<_>>(),
                    },
                    span,
                );
                vec_exprs.insert(c, expr);
            }

            let mut mat_exprs = FastHashMap::default();
            for c in 2..=4 {
                for r in 2..=4 {
                    // Column `cc` holds the consecutive literals `cc * r .. cc * r + r`.
                    let mut columns = Vec::with_capacity(c);
                    for cc in 0..c {
                        let start = cc * r;
                        let expr = expressions.append(
                            Expression::Compose {
                                ty: *vec_tys.get(&r).unwrap(),
                                components: (start..start + r)
                                    .map(|i| *lit_exprs.get(&i).unwrap())
                                    .collect::<Vec<_>>(),
                            },
                            span,
                        );
                        columns.push(expr);
                    }

                    let expr = expressions.append(
                        Expression::Compose {
                            ty: *mat_tys.get(&(c, r)).unwrap(),
                            components: columns,
                        },
                        span,
                    );
                    mat_exprs.insert((c, r), expr);
                }
            }

            Self {
                types,
                expressions,
                vec_exprs,
                mat_exprs,
            }
        }

        /// Constant-evaluate `vec{nr} * mat{nc}x{nr}` and return the result's lanes.
        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };

            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.vec_exprs.get(&nr).unwrap(),
                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        /// Constant-evaluate `mat{nc}x{nr} * vec{nc}` and return the result's lanes.
        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };

            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
                        right: *self.vec_exprs.get(&nc).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        /// Constant-evaluate `mat{k}x{l_nr} * mat{r_nc}x{k}` and return the result
        /// flattened column by column.
        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };

            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        /// Flatten a vector or matrix `Compose` of `f32` literals into its lanes.
        /// Matrices are flattened column by column. Panics on any other shape.
        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
            let Expression::Compose {
                ref components,
                ref ty,
            } = self.expressions[expr]
            else {
                unreachable!()
            };

            match self.types[*ty].inner {
                TypeInner::Vector { .. } => components
                    .iter()
                    .map(|&comp| {
                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
                            unreachable!()
                        };
                        v
                    })
                    .collect(),
                TypeInner::Matrix { .. } => components
                    .iter()
                    .flat_map(|&comp| self.flatten(comp))
                    .collect(),
                _ => unreachable!(),
            }
        }

        /// Map 2, 3 or 4 to the corresponding `VectorSize`; panics otherwise.
        fn int_to_vector_size(int: usize) -> VectorSize {
            match int {
                2 => VectorSize::Bi,
                3 => VectorSize::Tri,
                4 => VectorSize::Quad,
                _ => unreachable!(),
            }
        }
    }
4318
4319 #[test]
4320 fn cast() {
4321 let mut types = UniqueArena::new();
4322 let mut constants = Arena::new();
4323 let overrides = Arena::new();
4324 let mut global_expressions = Arena::new();
4325
4326 let scalar_ty = types.insert(
4327 Type {
4328 name: None,
4329 inner: TypeInner::Scalar(crate::Scalar::I32),
4330 },
4331 Default::default(),
4332 );
4333
4334 let h = constants.append(
4335 Constant {
4336 name: None,
4337 ty: scalar_ty,
4338 init: global_expressions
4339 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4340 },
4341 Default::default(),
4342 );
4343
4344 let expr = global_expressions.append(Expression::Constant(h), Default::default());
4345
4346 let root = Expression::As {
4347 expr,
4348 kind: ScalarKind::Bool,
4349 convert: Some(crate::BOOL_WIDTH),
4350 };
4351
4352 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4353 let mut solver = ConstantEvaluator {
4354 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4355 types: &mut types,
4356 constants: &constants,
4357 overrides: &overrides,
4358 expressions: &mut global_expressions,
4359 expression_kind_tracker,
4360 layouter: &mut crate::proc::Layouter::default(),
4361 };
4362
4363 let res = solver
4364 .try_eval_and_append(root, Default::default())
4365 .unwrap();
4366
4367 assert_eq!(
4368 global_expressions[res],
4369 Expression::Literal(Literal::Bool(true))
4370 );
4371 }
4372
    /// Checks that `AccessIndex` folds through constants. Indexing a `mat2x3<f32>`
    /// constant with columns `(0, 1, 2)` and `(3, 4, 5)` at column 1 must give the
    /// vector `(3, 4, 5)`, and indexing that result at 2 must give the literal `5`.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        // First column: literals 0, 1, 2.
        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        // Second column: literals 3, 4, 5.
        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        // matrix[1] -> second column.
        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        // matrix[1][2] -> third lane of the second column.
        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }
4526
4527 #[test]
4528 fn compose_of_constants() {
4529 let mut types = UniqueArena::new();
4530 let mut constants = Arena::new();
4531 let overrides = Arena::new();
4532 let mut global_expressions = Arena::new();
4533
4534 let i32_ty = types.insert(
4535 Type {
4536 name: None,
4537 inner: TypeInner::Scalar(crate::Scalar::I32),
4538 },
4539 Default::default(),
4540 );
4541
4542 let vec2_i32_ty = types.insert(
4543 Type {
4544 name: None,
4545 inner: TypeInner::Vector {
4546 size: VectorSize::Bi,
4547 scalar: crate::Scalar::I32,
4548 },
4549 },
4550 Default::default(),
4551 );
4552
4553 let h = constants.append(
4554 Constant {
4555 name: None,
4556 ty: i32_ty,
4557 init: global_expressions
4558 .append(Expression::Literal(Literal::I32(4)), Default::default()),
4559 },
4560 Default::default(),
4561 );
4562
4563 let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4564
4565 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4566 let mut solver = ConstantEvaluator {
4567 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4568 types: &mut types,
4569 constants: &constants,
4570 overrides: &overrides,
4571 expressions: &mut global_expressions,
4572 expression_kind_tracker,
4573 layouter: &mut crate::proc::Layouter::default(),
4574 };
4575
4576 let solved_compose = solver
4577 .try_eval_and_append(
4578 Expression::Compose {
4579 ty: vec2_i32_ty,
4580 components: vec![h_expr, h_expr],
4581 },
4582 Default::default(),
4583 )
4584 .unwrap();
4585 let solved_negate = solver
4586 .try_eval_and_append(
4587 Expression::Unary {
4588 op: UnaryOperator::Negate,
4589 expr: solved_compose,
4590 },
4591 Default::default(),
4592 )
4593 .unwrap();
4594
4595 let pass = match global_expressions[solved_negate] {
4596 Expression::Compose { ty, ref components } => {
4597 ty == vec2_i32_ty
4598 && components.iter().all(|&component| {
4599 let component = &global_expressions[component];
4600 matches!(*component, Expression::Literal(Literal::I32(-4)))
4601 })
4602 }
4603 _ => false,
4604 };
4605 if !pass {
4606 panic!("unexpected evaluation result")
4607 }
4608 }
4609
/// Negating a `Splat` of a named scalar constant must be fully const-evaluated.
///
/// Builds `const h: i32 = 4;` and evaluates `-vec2(h)`. The `Splat` is first
/// evaluated on its own, and that result is then negated. The final expression
/// must be a `vec2<i32>` `Compose` with exactly two components, both the
/// literal `-4`.
#[test]
fn splat_of_constant() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_splat = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: h_expr,
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_splat,
            },
            Default::default(),
        )
        .unwrap();

    let result = &global_expressions[solved_negate];
    let pass = match *result {
        Expression::Compose { ty, ref components } => {
            // Check the component count explicitly. `all` is vacuously true on an
            // empty list, so without this a component-less `Compose` would pass.
            ty == vec2_i32_ty
                && components.len() == VectorSize::Bi as usize
                && components.iter().all(|&component| {
                    let component = &global_expressions[component];
                    matches!(*component, Expression::Literal(Literal::I32(-4)))
                })
        }
        _ => false,
    };
    assert!(pass, "unexpected evaluation result: {result:?}");
}
4692
/// Adding a `Splat` of a zero value to a `Splat` of a literal must const-evaluate.
///
/// Evaluates `vec2(f32()) + vec2(5.0)`, where the left operand splats a
/// `ZeroValue` of `f32`. The result must be a `vec2<f32>` `Compose` with
/// exactly two components, both the literal `5.0`.
#[test]
fn splat_of_zero_value() {
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );

    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    let five =
        global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
    let five_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: five,
        },
        Default::default(),
    );
    let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
    let zero_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: zero,
        },
        Default::default(),
    );

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_add = solver
        .try_eval_and_append(
            Expression::Binary {
                op: BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    let result = &global_expressions[solved_add];
    let pass = match *result {
        Expression::Compose { ty, ref components } => {
            // Check the component count explicitly. `all` is vacuously true on an
            // empty list, so without this a component-less `Compose` would pass.
            ty == vec2_f32_ty
                && components.len() == VectorSize::Bi as usize
                && components.iter().all(|&component| {
                    let component = &global_expressions[component];
                    matches!(*component, Expression::Literal(Literal::F32(5.0)))
                })
        }
        _ => false,
    };
    assert!(pass, "unexpected evaluation result: {result:?}");
}
4773}