1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Defines a throwaway macro from `$body` and immediately invokes it with a literal `$` token.
///
/// The caller writes `$body` as a macro rule of the form `($d:tt) => { ... }`; inside that
/// rule, `$d` then stands for a `$` token. This lets a macro transcriber emit another
/// `macro_rules!` definition that has its own metavariables and repetitions (written with
/// `$d` in place of `$`).
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // The helper is (re)defined on every expansion with the caller-provided rules...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and is then handed a bare `$` so the caller's `$d:tt` binds to it.
        __with_dollar_sign!($);
    }
}
33
/// Generates a component-wise evaluation helper for a chosen subset of [`Literal`] types.
///
/// An invocation `$ident -> $target` with a list of `literals` and `scalar_kinds` expands to:
///
/// - an enum `$target<N>` with one variant (`$mapping`) per listed literal, each holding `N`
///   values of the Rust type `$ty`;
/// - `impl From<$target<1>> for Expression`, which wraps the single value back into the
///   matching `Literal::$literal`;
/// - a function `$ident` that gathers `N` argument expressions into one `$target<N>`, passes
///   it to `handler`, and registers the resulting expression. When the first argument is a
///   `Compose` of a vector whose scalar kind is listed in `scalar_kinds`, the function calls
///   itself once per vector component and composes the per-component results;
/// - a macro, also named `$ident`, that applies one closure-like body to every variant.
///
/// Any argument combination that does not fit these shapes yields
/// [`ConstantEvaluatorError::InvalidMathArg`].
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // One variant per accepted literal type; `N` is the number of arguments gathered.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-element result converts straight back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first argument decides which shape is expected, so there must be one.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs `eval_zero_value_and_splat` on the handle and then borrows the resulting
            // expression from the evaluator's arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every remaining argument must be a literal of the same
                    // variant as the first one; anything else is `InvalidMathArg`.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: only `Compose` of a vector with an accepted scalar kind.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            // Remaining arguments must be `Compose`s whose
                                            // type has the same `TypeInner` as the first.
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, take that component from every
                            // argument and evaluate the group through this same function.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result reuses the type of the first argument.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` stands for a literal `$` here (see `with_dollar_sign!`), so the generated macro
        // can declare its own metavariables and repetitions.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                // The same body is pasted into every variant's arm and its
                                // result is re-wrapped in that same variant.
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Generates the `Scalar` enum plus the `component_wise_scalar` function and macro.
// Accepted literals: abstract and concrete floats (f64, f32, f16) and abstract and concrete
// integers (i64 abstract, u32, i32, u64, i64). Vectors are accepted when their scalar kind
// is one of the float or integer kinds listed below.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// Generates the `Float` enum plus the `component_wise_float` function and macro.
// Only floating-point literals are accepted; note that the abstract-float variant is named
// `Float::Abstract` rather than `Float::AbstractFloat`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// Generates the `ConcreteInt` enum plus the `component_wise_concrete_int` function and macro.
// Only the 32-bit concrete integer literals (`u32`, `i32`) are accepted.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// Generates the `Signed` enum plus the `component_wise_signed` function and macro.
// Accepted literals are the floats, the abstract integer and `i32`; no unsigned literal is
// listed.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// Front-end-specific mode a [`ConstantEvaluator`] operates in.
///
/// The WGSL constructors of `ConstantEvaluator` produce the `Wgsl` variant and the GLSL
/// constructors produce the `Glsl` variant.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules, together with the kind of context being evaluated.
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules, together with the kind of context being evaluated.
    Glsl(GlslRestrictions<'a>),
}
276
277impl Behavior<'_> {
278 const fn has_runtime_restrictions(&self) -> bool {
280 matches!(
281 self,
282 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
283 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
284 )
285 }
286}
287
/// Evaluates expressions and appends the results to an expression arena.
///
/// Instances are created through the `for_*_module` and `for_*_function` constructors, which
/// decide which arena is written to and which [`Behavior`] applies.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which front end's rules apply, and in what kind of context.
    behavior: Behavior<'a>,

    /// The module's type arena; evaluation may insert new types (e.g. vector types built by
    /// `splat` and `swizzle`).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants, used to look up a constant's initializer.
    constants: &'a Arena<Constant>,

    /// The module's overrides.
    overrides: &'a Arena<Override>,

    /// The arena that evaluated expressions are appended to: the module's global expression
    /// arena for module-level evaluators, or the caller-supplied arena for function-level
    /// ones.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the [`ExpressionKind`] of expressions; consulted before evaluating.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Layout helper supplied by the caller.
    // NOTE(review): its use is not visible in this part of the file.
    layouter: &'a mut crate::proc::Layouter,
}
332
/// The kind of WGSL context a [`ConstantEvaluator`] is evaluating in.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Const context. The payload is `None` for module-level evaluators and `Some` for
    /// function-level evaluators created with `is_const == true`.
    Const(Option<FunctionLocalData<'a>>),
    /// Module-level evaluator created with `in_override_ctx == true`.
    Override,
    /// Function-level evaluator created with `is_const == false`.
    Runtime(FunctionLocalData<'a>),
}
345
/// The kind of GLSL context a [`ConstantEvaluator`] is evaluating in.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module-level evaluator (see `ConstantEvaluator::for_glsl_module`).
    Const,
    /// Function-level evaluator (see `ConstantEvaluator::for_glsl_function`).
    Runtime(FunctionLocalData<'a>),
}
355
/// Extra state carried by evaluators that work on a function's expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena, kept so that constant initializers can be
    /// copied from it into the function's own arena.
    global_expressions: &'a Arena<Expression>,
    /// Emitter supplied by the caller of the function-level constructors.
    // NOTE(review): how it is driven is not visible in this part of the file.
    emitter: &'a mut super::Emitter,
    /// Block supplied by the caller of the function-level constructors.
    // NOTE(review): how it is written to is not visible in this part of the file.
    block: &'a mut crate::Block,
}
363
/// Classification of an expression by when its value can be known.
///
/// The derived ordering follows declaration order (`Const < Override < Runtime`), so the kind
/// of a compound expression can be computed as the `max` of its operands' kinds, as
/// `ExpressionKindTracker::type_of_with_expr` does.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values, constants, and other `Const` expressions.
    Const,
    /// Depends on at least one override, and on nothing classified as `Runtime`.
    Override,
    /// Anything else.
    Runtime,
}
370
/// Records the [`ExpressionKind`] of each expression in an arena, keyed by expression handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// Kind recorded for each expression handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
375
376impl ExpressionKindTracker {
377 pub const fn new() -> Self {
378 Self {
379 inner: HandleVec::new(),
380 }
381 }
382
383 pub fn force_non_const(&mut self, value: Handle<Expression>) {
385 self.inner[value] = ExpressionKind::Runtime;
386 }
387
388 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
389 self.inner.insert(value, expr_type);
390 }
391
392 pub fn is_const(&self, h: Handle<Expression>) -> bool {
393 matches!(self.type_of(h), ExpressionKind::Const)
394 }
395
396 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
397 matches!(
398 self.type_of(h),
399 ExpressionKind::Const | ExpressionKind::Override
400 )
401 }
402
403 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
404 self.inner[value]
405 }
406
407 pub fn from_arena(arena: &Arena<Expression>) -> Self {
408 let mut tracker = Self {
409 inner: HandleVec::with_capacity(arena.len()),
410 };
411 for (handle, expr) in arena.iter() {
412 tracker
413 .inner
414 .insert(handle, tracker.type_of_with_expr(expr));
415 }
416 tracker
417 }
418
419 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
420 match *expr {
421 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
422 ExpressionKind::Const
423 }
424 Expression::Override(_) => ExpressionKind::Override,
425 Expression::Compose { ref components, .. } => {
426 let mut expr_type = ExpressionKind::Const;
427 for component in components {
428 expr_type = expr_type.max(self.type_of(*component))
429 }
430 expr_type
431 }
432 Expression::Splat { value, .. } => self.type_of(value),
433 Expression::AccessIndex { base, .. } => self.type_of(base),
434 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
435 Expression::Swizzle { vector, .. } => self.type_of(vector),
436 Expression::Unary { expr, .. } => self.type_of(expr),
437 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
438 Expression::Math {
439 arg,
440 arg1,
441 arg2,
442 arg3,
443 ..
444 } => self
445 .type_of(arg)
446 .max(
447 arg1.map(|arg| self.type_of(arg))
448 .unwrap_or(ExpressionKind::Const),
449 )
450 .max(
451 arg2.map(|arg| self.type_of(arg))
452 .unwrap_or(ExpressionKind::Const),
453 )
454 .max(
455 arg3.map(|arg| self.type_of(arg))
456 .unwrap_or(ExpressionKind::Const),
457 ),
458 Expression::As { expr, .. } => self.type_of(expr),
459 Expression::Select {
460 condition,
461 accept,
462 reject,
463 } => self
464 .type_of(condition)
465 .max(self.type_of(accept))
466 .max(self.type_of(reject)),
467 Expression::Relational { argument, .. } => self.type_of(argument),
468 Expression::ArrayLength(expr) => self.type_of(expr),
469 _ => ExpressionKind::Runtime,
470 }
471 }
472}
473
/// Errors reported by [`ConstantEvaluator`].
///
/// Each variant's user-facing message is given by its `#[error(...)]` attribute.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    // Expression kinds that have no constant value.
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    // Invalid operands for an otherwise supported operation.
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// Fields are the math function, the expected argument count and the supplied count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("Cannot apply relational function to type")]
    InvalidRelationalArg(RelationalFunction),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Constructor expects {expected} components, found {actual}")]
    InvalidVectorComposeLength { expected: usize, actual: usize },
    #[error("Constructor must only contain vector or scalar arguments")]
    InvalidVectorComposeComponent,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// The payload describes the unsupported feature.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    /// The payload names the operation that overflowed.
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    /// Wraps a literal validation error; its message is forwarded unchanged.
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    // An expression of the wrong kind for the current context.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    #[error("Unexpected override-expression")]
    OverrideExpr,
    // Errors specific to `select`.
    #[error("Expected boolean expression for condition argument of `select`, got something else")]
    SelectScalarConditionNotABool,
    #[error(
        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
        reject,
        accept
    )]
    SelectVecRejectAcceptSizeMismatch {
        reject: crate::VectorSize,
        accept: crate::VectorSize,
    },
    #[error("Expected boolean vector for condition arg., got something else")]
    SelectConditionNotAVecBool,
    #[error(
        "Expected same number of vector components between condition, accept, and reject args., got something else",
    )]
    SelectConditionVecSizeMismatch,
    // NOTE(review): "scalars of vectors" below presumably means "scalars or vectors"; the
    // message is left unchanged here because it is user-visible text.
    #[error(
        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
    )]
    SelectAcceptRejectTypeMismatch,
}
588
589impl<'a> ConstantEvaluator<'a> {
590 pub fn for_wgsl_module(
595 module: &'a mut crate::Module,
596 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
597 layouter: &'a mut crate::proc::Layouter,
598 in_override_ctx: bool,
599 ) -> Self {
600 Self::for_module(
601 Behavior::Wgsl(if in_override_ctx {
602 WgslRestrictions::Override
603 } else {
604 WgslRestrictions::Const(None)
605 }),
606 module,
607 global_expression_kind_tracker,
608 layouter,
609 )
610 }
611
612 pub fn for_glsl_module(
617 module: &'a mut crate::Module,
618 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
619 layouter: &'a mut crate::proc::Layouter,
620 ) -> Self {
621 Self::for_module(
622 Behavior::Glsl(GlslRestrictions::Const),
623 module,
624 global_expression_kind_tracker,
625 layouter,
626 )
627 }
628
629 fn for_module(
630 behavior: Behavior<'a>,
631 module: &'a mut crate::Module,
632 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
633 layouter: &'a mut crate::proc::Layouter,
634 ) -> Self {
635 Self {
636 behavior,
637 types: &mut module.types,
638 constants: &module.constants,
639 overrides: &module.overrides,
640 expressions: &mut module.global_expressions,
641 expression_kind_tracker: global_expression_kind_tracker,
642 layouter,
643 }
644 }
645
646 pub fn for_wgsl_function(
651 module: &'a mut crate::Module,
652 expressions: &'a mut Arena<Expression>,
653 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
654 layouter: &'a mut crate::proc::Layouter,
655 emitter: &'a mut super::Emitter,
656 block: &'a mut crate::Block,
657 is_const: bool,
658 ) -> Self {
659 let local_data = FunctionLocalData {
660 global_expressions: &module.global_expressions,
661 emitter,
662 block,
663 };
664 Self {
665 behavior: Behavior::Wgsl(if is_const {
666 WgslRestrictions::Const(Some(local_data))
667 } else {
668 WgslRestrictions::Runtime(local_data)
669 }),
670 types: &mut module.types,
671 constants: &module.constants,
672 overrides: &module.overrides,
673 expressions,
674 expression_kind_tracker: local_expression_kind_tracker,
675 layouter,
676 }
677 }
678
679 pub fn for_glsl_function(
684 module: &'a mut crate::Module,
685 expressions: &'a mut Arena<Expression>,
686 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
687 layouter: &'a mut crate::proc::Layouter,
688 emitter: &'a mut super::Emitter,
689 block: &'a mut crate::Block,
690 ) -> Self {
691 Self {
692 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
693 global_expressions: &module.global_expressions,
694 emitter,
695 block,
696 })),
697 types: &mut module.types,
698 constants: &module.constants,
699 overrides: &module.overrides,
700 expressions,
701 expression_kind_tracker: local_expression_kind_tracker,
702 layouter,
703 }
704 }
705
706 pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
707 crate::proc::GlobalCtx {
708 types: self.types,
709 constants: self.constants,
710 overrides: self.overrides,
711 global_expressions: match self.function_local_data() {
712 Some(data) => data.global_expressions,
713 None => self.expressions,
714 },
715 }
716 }
717
718 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
719 if !self.expression_kind_tracker.is_const(expr) {
720 log::debug!("check: SubexpressionsAreNotConstant");
721 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
722 }
723 Ok(())
724 }
725
726 fn check_and_get(
727 &mut self,
728 expr: Handle<Expression>,
729 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
730 match self.expressions[expr] {
731 Expression::Constant(c) => {
732 if let Some(function_local_data) = self.function_local_data() {
735 self.copy_from(
737 self.constants[c].init,
738 function_local_data.global_expressions,
739 )
740 } else {
741 Ok(self.constants[c].init)
743 }
744 }
745 _ => {
746 self.check(expr)?;
747 Ok(expr)
748 }
749 }
750 }
751
752 pub fn try_eval_and_append(
776 &mut self,
777 expr: Expression,
778 span: Span,
779 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
780 match self.expression_kind_tracker.type_of_with_expr(&expr) {
781 ExpressionKind::Const => {
782 let eval_result = self.try_eval_and_append_impl(&expr, span);
783 if self.behavior.has_runtime_restrictions()
788 && matches!(
789 eval_result,
790 Err(ConstantEvaluatorError::NotImplemented(_)
791 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
792 )
793 {
794 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
795 } else {
796 eval_result
797 }
798 }
799 ExpressionKind::Override => match self.behavior {
800 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
801 Ok(self.append_expr(expr, span, ExpressionKind::Override))
802 }
803 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
804 Err(ConstantEvaluatorError::OverrideExpr)
805 }
806 Behavior::Glsl(_) => {
807 unreachable!()
808 }
809 },
810 ExpressionKind::Runtime => {
811 if self.behavior.has_runtime_restrictions() {
812 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
813 } else {
814 Err(ConstantEvaluatorError::RuntimeExpr)
815 }
816 }
817 }
818 }
819
820 const fn is_global_arena(&self) -> bool {
822 matches!(
823 self.behavior,
824 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
825 | Behavior::Glsl(GlslRestrictions::Const)
826 )
827 }
828
829 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
830 match self.behavior {
831 Behavior::Wgsl(
832 WgslRestrictions::Runtime(ref function_local_data)
833 | WgslRestrictions::Const(Some(ref function_local_data)),
834 )
835 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
836 Some(function_local_data)
837 }
838 _ => None,
839 }
840 }
841
    /// Evaluates `expr` and appends the result to the expression arena.
    ///
    /// Operand handles are passed through [`Self::check_and_get`] before the operation's
    /// dedicated helper (`access`, `swizzle`, `unary_op`, `binary_op`, `math`, `cast`,
    /// `select`, `relational`, `array_length`) is called.
    ///
    /// # Errors
    ///
    /// Expression kinds that have no constant value produce the matching
    /// [`ConstantEvaluatorError`] variant; errors from operand checks and from the helpers
    /// are propagated.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            // Without function-local data, a constant is replaced by its initializer.
            Expression::Constant(c) if self.is_global_arena() => {
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Already-evaluated forms (and constants when function-local data is present)
            // are registered as they are.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The dynamic index is converted through `constant_index` first.
                self.access(base, self.constant_index(index)?, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // Only `As` with a conversion width is evaluated; `convert == None` is
                // reported as an unimplemented bitcast.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            // Array length is only evaluated under GLSL behavior.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below has no constant value; each kind maps to its own error.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
        }
    }
976
977 fn splat(
990 &mut self,
991 value: Handle<Expression>,
992 size: crate::VectorSize,
993 span: Span,
994 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
995 match self.expressions[value] {
996 Expression::Literal(literal) => {
997 let scalar = literal.scalar();
998 let ty = self.types.insert(
999 Type {
1000 name: None,
1001 inner: TypeInner::Vector { size, scalar },
1002 },
1003 span,
1004 );
1005 let expr = Expression::Compose {
1006 ty,
1007 components: vec![value; size as usize],
1008 };
1009 self.register_evaluated_expr(expr, span)
1010 }
1011 Expression::ZeroValue(ty) => {
1012 let inner = match self.types[ty].inner {
1013 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1014 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1015 };
1016 let res_ty = self.types.insert(Type { name: None, inner }, span);
1017 let expr = Expression::ZeroValue(res_ty);
1018 self.register_evaluated_expr(expr, span)
1019 }
1020 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1021 }
1022 }
1023
1024 fn swizzle(
1025 &mut self,
1026 size: crate::VectorSize,
1027 span: Span,
1028 src_constant: Handle<Expression>,
1029 pattern: [crate::SwizzleComponent; 4],
1030 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1031 let mut get_dst_ty = |ty| match self.types[ty].inner {
1032 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1033 Type {
1034 name: None,
1035 inner: TypeInner::Vector { size, scalar },
1036 },
1037 span,
1038 )),
1039 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1040 };
1041
1042 match self.expressions[src_constant] {
1043 Expression::ZeroValue(ty) => {
1044 let dst_ty = get_dst_ty(ty)?;
1045 let expr = Expression::ZeroValue(dst_ty);
1046 self.register_evaluated_expr(expr, span)
1047 }
1048 Expression::Splat { value, .. } => {
1049 let expr = Expression::Splat { size, value };
1050 self.register_evaluated_expr(expr, span)
1051 }
1052 Expression::Compose { ty, ref components } => {
1053 let dst_ty = get_dst_ty(ty)?;
1054
1055 let mut flattened = [src_constant; 4]; let len =
1057 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1058 .zip(flattened.iter_mut())
1059 .map(|(component, elt)| *elt = component)
1060 .count();
1061 let flattened = &flattened[..len];
1062
1063 let swizzled_components = pattern[..size as usize]
1064 .iter()
1065 .map(|&sc| {
1066 let sc = sc as usize;
1067 if let Some(elt) = flattened.get(sc) {
1068 Ok(*elt)
1069 } else {
1070 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1071 }
1072 })
1073 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1074 let expr = Expression::Compose {
1075 ty: dst_ty,
1076 components: swizzled_components,
1077 };
1078 self.register_evaluated_expr(expr, span)
1079 }
1080 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1081 }
1082 }
1083
1084 fn math(
1085 &mut self,
1086 arg: Handle<Expression>,
1087 arg1: Option<Handle<Expression>>,
1088 arg2: Option<Handle<Expression>>,
1089 arg3: Option<Handle<Expression>>,
1090 fun: crate::MathFunction,
1091 span: Span,
1092 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1093 let expected = fun.argument_count();
1094 let given = Some(arg)
1095 .into_iter()
1096 .chain(arg1)
1097 .chain(arg2)
1098 .chain(arg3)
1099 .count();
1100 if expected != given {
1101 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1102 fun, expected, given,
1103 ));
1104 }
1105
1106 match fun {
1108 crate::MathFunction::Abs => {
1110 component_wise_scalar(self, span, [arg], |args| match args {
1111 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1112 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1113 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1114 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1115 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1116 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1118 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1119 })
1120 }
1121 crate::MathFunction::Min => {
1122 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1123 Ok([e1.min(e2)])
1124 })
1125 }
1126 crate::MathFunction::Max => {
1127 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1128 Ok([e1.max(e2)])
1129 })
1130 }
1131 crate::MathFunction::Clamp => {
1132 component_wise_scalar!(
1133 self,
1134 span,
1135 [arg, arg1.unwrap(), arg2.unwrap()],
1136 |e, low, high| {
1137 if low > high {
1138 Err(ConstantEvaluatorError::InvalidClamp)
1139 } else {
1140 Ok([e.clamp(low, high)])
1141 }
1142 }
1143 )
1144 }
1145 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1146 Float::F16([e]) => Ok(Float::F16(
1147 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1148 )),
1149 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1150 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1151 }),
1152
1153 crate::MathFunction::Cos => {
1155 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1156 }
1157 crate::MathFunction::Cosh => {
1158 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1159 }
1160 crate::MathFunction::Sin => {
1161 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1162 }
1163 crate::MathFunction::Sinh => {
1164 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1165 }
1166 crate::MathFunction::Tan => {
1167 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1168 }
1169 crate::MathFunction::Tanh => {
1170 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1171 }
1172 crate::MathFunction::Acos => {
1173 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1174 }
1175 crate::MathFunction::Asin => {
1176 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1177 }
1178 crate::MathFunction::Atan => {
1179 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1180 }
1181 crate::MathFunction::Atan2 => {
1182 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1183 Ok([y.atan2(x)])
1184 })
1185 }
1186 crate::MathFunction::Asinh => {
1187 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1188 }
1189 crate::MathFunction::Acosh => {
1190 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1191 }
1192 crate::MathFunction::Atanh => {
1193 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1194 }
1195 crate::MathFunction::Radians => {
1196 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1197 }
1198 crate::MathFunction::Degrees => {
1199 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1200 }
1201
1202 crate::MathFunction::Ceil => {
1204 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1205 }
1206 crate::MathFunction::Floor => {
1207 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1208 }
1209 crate::MathFunction::Round => {
1210 component_wise_float(self, span, [arg], |e| match e {
1211 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1212 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1213 Float::F16([e]) => {
1214 fn round_ties_even(x: f64) -> f64 {
1222 let i = x as i64;
1223 let f = (x - i as f64).abs();
1224 if f == 0.5 {
1225 if i & 1 == 1 {
1226 (x.abs() + 0.5).copysign(x)
1228 } else {
1229 (x.abs() - 0.5).copysign(x)
1230 }
1231 } else {
1232 x.round()
1233 }
1234 }
1235
1236 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1237 }
1238 })
1239 }
1240 crate::MathFunction::Fract => {
1241 component_wise_float!(self, span, [arg], |e| {
1242 Ok([e - e.floor()])
1245 })
1246 }
1247 crate::MathFunction::Trunc => {
1248 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1249 }
1250
1251 crate::MathFunction::Exp => {
1253 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1254 }
1255 crate::MathFunction::Exp2 => {
1256 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1257 }
1258 crate::MathFunction::Log => {
1259 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1260 }
1261 crate::MathFunction::Log2 => {
1262 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1263 }
1264 crate::MathFunction::Pow => {
1265 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1266 Ok([e1.powf(e2)])
1267 })
1268 }
1269
1270 crate::MathFunction::Sign => {
1272 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1273 }
1274 crate::MathFunction::Fma => {
1275 component_wise_float!(
1276 self,
1277 span,
1278 [arg, arg1.unwrap(), arg2.unwrap()],
1279 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1280 )
1281 }
1282 crate::MathFunction::Step => {
1283 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1284 Float::Abstract([edge, x]) => {
1285 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1286 }
1287 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1288 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1289 f16::one()
1290 } else {
1291 f16::zero()
1292 }])),
1293 })
1294 }
1295 crate::MathFunction::Sqrt => {
1296 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1297 }
1298 crate::MathFunction::InverseSqrt => {
1299 component_wise_float(self, span, [arg], |e| match e {
1300 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1301 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1302 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1303 })
1304 }
1305
1306 crate::MathFunction::CountTrailingZeros => {
1308 component_wise_concrete_int!(self, span, [arg], |e| {
1309 #[allow(clippy::useless_conversion)]
1310 Ok([e
1311 .trailing_zeros()
1312 .try_into()
1313 .expect("bit count overflowed 32 bits, somehow!?")])
1314 })
1315 }
1316 crate::MathFunction::CountLeadingZeros => {
1317 component_wise_concrete_int!(self, span, [arg], |e| {
1318 #[allow(clippy::useless_conversion)]
1319 Ok([e
1320 .leading_zeros()
1321 .try_into()
1322 .expect("bit count overflowed 32 bits, somehow!?")])
1323 })
1324 }
1325 crate::MathFunction::CountOneBits => {
1326 component_wise_concrete_int!(self, span, [arg], |e| {
1327 #[allow(clippy::useless_conversion)]
1328 Ok([e
1329 .count_ones()
1330 .try_into()
1331 .expect("bit count overflowed 32 bits, somehow!?")])
1332 })
1333 }
1334 crate::MathFunction::ReverseBits => {
1335 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1336 }
1337 crate::MathFunction::FirstTrailingBit => {
1338 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1339 }
1340 crate::MathFunction::FirstLeadingBit => {
1341 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1342 }
1343
1344 crate::MathFunction::Dot4I8Packed => {
1346 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1347 }
1348 crate::MathFunction::Dot4U8Packed => {
1349 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1350 }
1351 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1352
1353 crate::MathFunction::Modf
1355 | crate::MathFunction::Frexp
1356 | crate::MathFunction::Ldexp
1357 | crate::MathFunction::Dot
1358 | crate::MathFunction::Outer
1359 | crate::MathFunction::Distance
1360 | crate::MathFunction::Length
1361 | crate::MathFunction::Normalize
1362 | crate::MathFunction::FaceForward
1363 | crate::MathFunction::Reflect
1364 | crate::MathFunction::Refract
1365 | crate::MathFunction::Mix
1366 | crate::MathFunction::SmoothStep
1367 | crate::MathFunction::Inverse
1368 | crate::MathFunction::Transpose
1369 | crate::MathFunction::Determinant
1370 | crate::MathFunction::QuantizeToF16
1371 | crate::MathFunction::ExtractBits
1372 | crate::MathFunction::InsertBits
1373 | crate::MathFunction::Pack4x8snorm
1374 | crate::MathFunction::Pack4x8unorm
1375 | crate::MathFunction::Pack2x16snorm
1376 | crate::MathFunction::Pack2x16unorm
1377 | crate::MathFunction::Pack2x16float
1378 | crate::MathFunction::Pack4xI8
1379 | crate::MathFunction::Pack4xU8
1380 | crate::MathFunction::Pack4xI8Clamp
1381 | crate::MathFunction::Pack4xU8Clamp
1382 | crate::MathFunction::Unpack4x8snorm
1383 | crate::MathFunction::Unpack4x8unorm
1384 | crate::MathFunction::Unpack2x16snorm
1385 | crate::MathFunction::Unpack2x16unorm
1386 | crate::MathFunction::Unpack2x16float
1387 | crate::MathFunction::Unpack4xI8
1388 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1389 format!("{fun:?} built-in function"),
1390 )),
1391 }
1392 }
1393
1394 fn packed_dot_product(
1396 &mut self,
1397 a: Handle<Expression>,
1398 b: Handle<Expression>,
1399 span: Span,
1400 signed: bool,
1401 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1402 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1403 return Err(ConstantEvaluatorError::InvalidMathArg);
1404 };
1405 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1406 return Err(ConstantEvaluatorError::InvalidMathArg);
1407 };
1408
1409 let result = if signed {
1410 Literal::I32(
1411 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1412 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1413 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1414 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1415 )
1416 } else {
1417 Literal::U32(
1418 (a & 0xFF) * (b & 0xFF)
1419 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1420 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1421 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1422 )
1423 };
1424
1425 self.register_evaluated_expr(Expression::Literal(result), span)
1426 }
1427
1428 fn cross_product(
1430 &mut self,
1431 a: Handle<Expression>,
1432 b: Handle<Expression>,
1433 span: Span,
1434 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1435 use Literal as Li;
1436
1437 let (a, ty) = self.extract_vec::<3>(a)?;
1438 let (b, _) = self.extract_vec::<3>(b)?;
1439
1440 let product = match (a, b) {
1441 (
1442 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1443 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1444 ) => {
1445 let p = cross_product(
1450 [a0 as f64, a1 as f64, a2 as f64],
1451 [b0 as f64, b1 as f64, b2 as f64],
1452 );
1453 [
1454 Li::AbstractFloat(p[0]),
1455 Li::AbstractFloat(p[1]),
1456 Li::AbstractFloat(p[2]),
1457 ]
1458 }
1459 (
1460 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1461 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1462 ) => {
1463 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1464 [
1465 Li::AbstractFloat(p[0]),
1466 Li::AbstractFloat(p[1]),
1467 Li::AbstractFloat(p[2]),
1468 ]
1469 }
1470 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1471 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1472 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1473 }
1474 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1475 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1476 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1477 }
1478 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1479 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1480 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1481 }
1482 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1483 };
1484
1485 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1486 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1487 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1488
1489 self.register_evaluated_expr(
1490 Expression::Compose {
1491 ty,
1492 components: vec![p0, p1, p2],
1493 },
1494 span,
1495 )
1496 }
1497
1498 fn extract_vec<const N: usize>(
1506 &mut self,
1507 expr: Handle<Expression>,
1508 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1509 let span = self.expressions.get_span(expr);
1510 let expr = self.eval_zero_value_and_splat(expr, span)?;
1511 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1512 return Err(ConstantEvaluatorError::InvalidMathArg);
1513 };
1514
1515 let mut value = [Literal::Bool(false); N];
1516 for (component, elt) in
1517 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1518 .zip(value.iter_mut())
1519 {
1520 let Expression::Literal(literal) = self.expressions[component] else {
1521 return Err(ConstantEvaluatorError::InvalidMathArg);
1522 };
1523 *elt = literal;
1524 }
1525
1526 Ok((value, ty))
1527 }
1528
1529 fn array_length(
1530 &mut self,
1531 array: Handle<Expression>,
1532 span: Span,
1533 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1534 match self.expressions[array] {
1535 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1536 match self.types[ty].inner {
1537 TypeInner::Array { size, .. } => match size {
1538 ArraySize::Constant(len) => {
1539 let expr = Expression::Literal(Literal::U32(len.get()));
1540 self.register_evaluated_expr(expr, span)
1541 }
1542 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1543 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1544 },
1545 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1546 }
1547 }
1548 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1549 }
1550 }
1551
1552 fn access(
1553 &mut self,
1554 base: Handle<Expression>,
1555 index: usize,
1556 span: Span,
1557 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1558 match self.expressions[base] {
1559 Expression::ZeroValue(ty) => {
1560 let ty_inner = &self.types[ty].inner;
1561 let components = ty_inner
1562 .components()
1563 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1564
1565 if index >= components as usize {
1566 Err(ConstantEvaluatorError::InvalidAccessBase)
1567 } else {
1568 let ty_res = ty_inner
1569 .component_type(index)
1570 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1571 let ty = match ty_res {
1572 crate::proc::TypeResolution::Handle(ty) => ty,
1573 crate::proc::TypeResolution::Value(inner) => {
1574 self.types.insert(Type { name: None, inner }, span)
1575 }
1576 };
1577 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1578 }
1579 }
1580 Expression::Splat { size, value } => {
1581 if index >= size as usize {
1582 Err(ConstantEvaluatorError::InvalidAccessBase)
1583 } else {
1584 Ok(value)
1585 }
1586 }
1587 Expression::Compose { ty, ref components } => {
1588 let _ = self.types[ty]
1589 .inner
1590 .components()
1591 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1592
1593 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1594 .nth(index)
1595 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1596 }
1597 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1598 }
1599 }
1600
1601 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1602 match self.expressions[expr] {
1603 Expression::ZeroValue(ty)
1604 if matches!(
1605 self.types[ty].inner,
1606 TypeInner::Scalar(crate::Scalar {
1607 kind: ScalarKind::Uint,
1608 ..
1609 })
1610 ) =>
1611 {
1612 Ok(0)
1613 }
1614 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1615 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1616 }
1617 }
1618
1619 fn eval_zero_value_and_splat(
1626 &mut self,
1627 mut expr: Handle<Expression>,
1628 span: Span,
1629 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1630 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
1633 let components = components
1634 .clone()
1635 .iter()
1636 .map(|component| self.eval_zero_value_and_splat(*component, span))
1637 .collect::<Result<_, _>>()?;
1638 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
1639 }
1640
1641 if let Expression::Splat { size, value } = self.expressions[expr] {
1645 expr = self.splat(value, size, span)?;
1646 }
1647 if let Expression::ZeroValue(ty) = self.expressions[expr] {
1648 expr = self.eval_zero_value_impl(ty, span)?;
1649 }
1650 Ok(expr)
1651 }
1652
1653 fn eval_zero_value(
1659 &mut self,
1660 expr: Handle<Expression>,
1661 span: Span,
1662 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1663 match self.expressions[expr] {
1664 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1665 _ => Ok(expr),
1666 }
1667 }
1668
1669 fn eval_zero_value_impl(
1675 &mut self,
1676 ty: Handle<Type>,
1677 span: Span,
1678 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1679 match self.types[ty].inner {
1680 TypeInner::Scalar(scalar) => {
1681 let expr = Expression::Literal(
1682 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1683 );
1684 self.register_evaluated_expr(expr, span)
1685 }
1686 TypeInner::Vector { size, scalar } => {
1687 let scalar_ty = self.types.insert(
1688 Type {
1689 name: None,
1690 inner: TypeInner::Scalar(scalar),
1691 },
1692 span,
1693 );
1694 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1695 let expr = Expression::Compose {
1696 ty,
1697 components: vec![el; size as usize],
1698 };
1699 self.register_evaluated_expr(expr, span)
1700 }
1701 TypeInner::Matrix {
1702 columns,
1703 rows,
1704 scalar,
1705 } => {
1706 let vec_ty = self.types.insert(
1707 Type {
1708 name: None,
1709 inner: TypeInner::Vector { size: rows, scalar },
1710 },
1711 span,
1712 );
1713 let el = self.eval_zero_value_impl(vec_ty, span)?;
1714 let expr = Expression::Compose {
1715 ty,
1716 components: vec![el; columns as usize],
1717 };
1718 self.register_evaluated_expr(expr, span)
1719 }
1720 TypeInner::Array {
1721 base,
1722 size: ArraySize::Constant(size),
1723 ..
1724 } => {
1725 let el = self.eval_zero_value_impl(base, span)?;
1726 let expr = Expression::Compose {
1727 ty,
1728 components: vec![el; size.get() as usize],
1729 };
1730 self.register_evaluated_expr(expr, span)
1731 }
1732 TypeInner::Struct { ref members, .. } => {
1733 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1734 let mut components = Vec::with_capacity(members.len());
1735 for ty in types {
1736 components.push(self.eval_zero_value_impl(ty, span)?);
1737 }
1738 let expr = Expression::Compose { ty, components };
1739 self.register_evaluated_expr(expr, span)
1740 }
1741 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1742 }
1743 }
1744
1745 pub fn cast(
1749 &mut self,
1750 expr: Handle<Expression>,
1751 target: crate::Scalar,
1752 span: Span,
1753 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1754 use crate::Scalar as Sc;
1755
1756 let expr = self.eval_zero_value(expr, span)?;
1757
1758 let make_error = || -> Result<_, ConstantEvaluatorError> {
1759 let from = format!("{:?} {:?}", expr, self.expressions[expr]);
1760
1761 #[cfg(feature = "wgsl-in")]
1762 let to = target.to_wgsl_for_diagnostics();
1763
1764 #[cfg(not(feature = "wgsl-in"))]
1765 let to = format!("{target:?}");
1766
1767 Err(ConstantEvaluatorError::InvalidCastArg { from, to })
1768 };
1769
1770 use crate::proc::type_methods::IntFloatLimits;
1771
1772 let expr = match self.expressions[expr] {
1773 Expression::Literal(literal) => {
1774 let literal = match target {
1775 Sc::I32 => Literal::I32(match literal {
1776 Literal::I32(v) => v,
1777 Literal::U32(v) => v as i32,
1778 Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
1779 Literal::F16(v) => f16::to_i32(&v).unwrap(), Literal::Bool(v) => v as i32,
1781 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1782 return make_error();
1783 }
1784 Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
1785 Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
1786 }),
1787 Sc::U32 => Literal::U32(match literal {
1788 Literal::I32(v) => v as u32,
1789 Literal::U32(v) => v,
1790 Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
1791 Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
1793 Literal::Bool(v) => v as u32,
1794 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1795 return make_error();
1796 }
1797 Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
1798 Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
1799 }),
1800 Sc::I64 => Literal::I64(match literal {
1801 Literal::I32(v) => v as i64,
1802 Literal::U32(v) => v as i64,
1803 Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
1804 Literal::Bool(v) => v as i64,
1805 Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
1806 Literal::I64(v) => v,
1807 Literal::U64(v) => v as i64,
1808 Literal::F16(v) => f16::to_i64(&v).unwrap(), Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
1810 Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
1811 }),
1812 Sc::U64 => Literal::U64(match literal {
1813 Literal::I32(v) => v as u64,
1814 Literal::U32(v) => v as u64,
1815 Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
1816 Literal::Bool(v) => v as u64,
1817 Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
1818 Literal::I64(v) => v as u64,
1819 Literal::U64(v) => v,
1820 Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
1822 Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
1823 Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
1824 }),
1825 Sc::F16 => Literal::F16(match literal {
1826 Literal::F16(v) => v,
1827 Literal::F32(v) => f16::from_f32(v),
1828 Literal::F64(v) => f16::from_f64(v),
1829 Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
1830 Literal::I64(v) => f16::from_i64(v).unwrap(),
1831 Literal::U64(v) => f16::from_u64(v).unwrap(),
1832 Literal::I32(v) => f16::from_i32(v).unwrap(),
1833 Literal::U32(v) => f16::from_u32(v).unwrap(),
1834 Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
1835 Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
1836 }),
1837 Sc::F32 => Literal::F32(match literal {
1838 Literal::I32(v) => v as f32,
1839 Literal::U32(v) => v as f32,
1840 Literal::F32(v) => v,
1841 Literal::Bool(v) => v as u32 as f32,
1842 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1843 return make_error();
1844 }
1845 Literal::F16(v) => f16::to_f32(v),
1846 Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
1847 Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
1848 }),
1849 Sc::F64 => Literal::F64(match literal {
1850 Literal::I32(v) => v as f64,
1851 Literal::U32(v) => v as f64,
1852 Literal::F16(v) => f16::to_f64(v),
1853 Literal::F32(v) => v as f64,
1854 Literal::F64(v) => v,
1855 Literal::Bool(v) => v as u32 as f64,
1856 Literal::I64(_) | Literal::U64(_) => return make_error(),
1857 Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
1858 Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
1859 }),
1860 Sc::BOOL => Literal::Bool(match literal {
1861 Literal::I32(v) => v != 0,
1862 Literal::U32(v) => v != 0,
1863 Literal::F32(v) => v != 0.0,
1864 Literal::F16(v) => v != f16::zero(),
1865 Literal::Bool(v) => v,
1866 Literal::AbstractInt(v) => v != 0,
1867 Literal::AbstractFloat(v) => v != 0.0,
1868 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1869 return make_error();
1870 }
1871 }),
1872 Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
1873 Literal::AbstractInt(v) => {
1874 v as f64
1879 }
1880 Literal::AbstractFloat(v) => v,
1881 _ => return make_error(),
1882 }),
1883 Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
1884 Literal::AbstractInt(v) => v,
1885 _ => return make_error(),
1886 }),
1887 _ => {
1888 log::debug!("Constant evaluator refused to convert value to {target:?}");
1889 return make_error();
1890 }
1891 };
1892 Expression::Literal(literal)
1893 }
1894 Expression::Compose {
1895 ty,
1896 components: ref src_components,
1897 } => {
1898 let ty_inner = match self.types[ty].inner {
1899 TypeInner::Vector { size, .. } => TypeInner::Vector {
1900 size,
1901 scalar: target,
1902 },
1903 TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
1904 columns,
1905 rows,
1906 scalar: target,
1907 },
1908 _ => return make_error(),
1909 };
1910
1911 let mut components = src_components.clone();
1912 for component in &mut components {
1913 *component = self.cast(*component, target, span)?;
1914 }
1915
1916 let ty = self.types.insert(
1917 Type {
1918 name: None,
1919 inner: ty_inner,
1920 },
1921 span,
1922 );
1923
1924 Expression::Compose { ty, components }
1925 }
1926 Expression::Splat { size, value } => {
1927 let value_span = self.expressions.get_span(value);
1928 let cast_value = self.cast(value, target, value_span)?;
1929 Expression::Splat {
1930 size,
1931 value: cast_value,
1932 }
1933 }
1934 _ => return make_error(),
1935 };
1936
1937 self.register_evaluated_expr(expr, span)
1938 }
1939
1940 pub fn cast_array(
1953 &mut self,
1954 expr: Handle<Expression>,
1955 target: crate::Scalar,
1956 span: Span,
1957 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1958 let expr = self.check_and_get(expr)?;
1959
1960 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1961 return self.cast(expr, target, span);
1962 };
1963
1964 let TypeInner::Array {
1965 base: _,
1966 size,
1967 stride: _,
1968 } = self.types[ty].inner
1969 else {
1970 return self.cast(expr, target, span);
1971 };
1972
1973 let mut components = components.clone();
1974 for component in &mut components {
1975 *component = self.cast_array(*component, target, span)?;
1976 }
1977
1978 let first = components.first().unwrap();
1979 let new_base = match self.resolve_type(*first)? {
1980 crate::proc::TypeResolution::Handle(ty) => ty,
1981 crate::proc::TypeResolution::Value(inner) => {
1982 self.types.insert(Type { name: None, inner }, span)
1983 }
1984 };
1985 let mut layouter = core::mem::take(self.layouter);
1986 layouter.update(self.to_ctx()).unwrap();
1987 *self.layouter = layouter;
1988
1989 let new_base_stride = self.layouter[new_base].to_stride();
1990 let new_array_ty = self.types.insert(
1991 Type {
1992 name: None,
1993 inner: TypeInner::Array {
1994 base: new_base,
1995 size,
1996 stride: new_base_stride,
1997 },
1998 },
1999 span,
2000 );
2001
2002 let compose = Expression::Compose {
2003 ty: new_array_ty,
2004 components,
2005 };
2006 self.register_evaluated_expr(compose, span)
2007 }
2008
2009 fn unary_op(
2010 &mut self,
2011 op: UnaryOperator,
2012 expr: Handle<Expression>,
2013 span: Span,
2014 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2015 let expr = self.eval_zero_value_and_splat(expr, span)?;
2016
2017 let expr = match self.expressions[expr] {
2018 Expression::Literal(value) => Expression::Literal(match op {
2019 UnaryOperator::Negate => match value {
2020 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2021 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2022 Literal::F32(v) => Literal::F32(-v),
2023 Literal::F16(v) => Literal::F16(-v),
2024 Literal::F64(v) => Literal::F64(-v),
2025 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2026 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2027 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2028 },
2029 UnaryOperator::LogicalNot => match value {
2030 Literal::Bool(v) => Literal::Bool(!v),
2031 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2032 },
2033 UnaryOperator::BitwiseNot => match value {
2034 Literal::I32(v) => Literal::I32(!v),
2035 Literal::I64(v) => Literal::I64(!v),
2036 Literal::U32(v) => Literal::U32(!v),
2037 Literal::U64(v) => Literal::U64(!v),
2038 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2039 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2040 },
2041 }),
2042 Expression::Compose {
2043 ty,
2044 components: ref src_components,
2045 } => {
2046 match self.types[ty].inner {
2047 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2048 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2049 }
2050
2051 let mut components = src_components.clone();
2052 for component in &mut components {
2053 *component = self.unary_op(op, *component, span)?;
2054 }
2055
2056 Expression::Compose { ty, components }
2057 }
2058 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2059 };
2060
2061 self.register_evaluated_expr(expr, span)
2062 }
2063
2064 fn binary_op(
2065 &mut self,
2066 op: BinaryOperator,
2067 left: Handle<Expression>,
2068 right: Handle<Expression>,
2069 span: Span,
2070 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2071 let left = self.eval_zero_value_and_splat(left, span)?;
2072 let right = self.eval_zero_value_and_splat(right, span)?;
2073
2074 let expr = match (&self.expressions[left], &self.expressions[right]) {
2075 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2076 let literal = match op {
2077 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2078 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2079 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2080 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2081 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2082 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2083
2084 _ => match (left_value, right_value) {
2085 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2086 BinaryOperator::Add => a.wrapping_add(b),
2087 BinaryOperator::Subtract => a.wrapping_sub(b),
2088 BinaryOperator::Multiply => a.wrapping_mul(b),
2089 BinaryOperator::Divide => {
2090 if b == 0 {
2091 return Err(ConstantEvaluatorError::DivisionByZero);
2092 } else {
2093 a.wrapping_div(b)
2094 }
2095 }
2096 BinaryOperator::Modulo => {
2097 if b == 0 {
2098 return Err(ConstantEvaluatorError::RemainderByZero);
2099 } else {
2100 a.wrapping_rem(b)
2101 }
2102 }
2103 BinaryOperator::And => a & b,
2104 BinaryOperator::ExclusiveOr => a ^ b,
2105 BinaryOperator::InclusiveOr => a | b,
2106 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2107 }),
2108 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2109 BinaryOperator::ShiftLeft => {
2110 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2111 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2112 }
2113 a.checked_shl(b)
2114 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2115 }
2116 BinaryOperator::ShiftRight => a
2117 .checked_shr(b)
2118 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2119 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2120 }),
2121 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2122 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2123 ConstantEvaluatorError::Overflow("addition".into())
2124 })?,
2125 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2126 ConstantEvaluatorError::Overflow("subtraction".into())
2127 })?,
2128 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2129 ConstantEvaluatorError::Overflow("multiplication".into())
2130 })?,
2131 BinaryOperator::Divide => a
2132 .checked_div(b)
2133 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2134 BinaryOperator::Modulo => a
2135 .checked_rem(b)
2136 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2137 BinaryOperator::And => a & b,
2138 BinaryOperator::ExclusiveOr => a ^ b,
2139 BinaryOperator::InclusiveOr => a | b,
2140 BinaryOperator::ShiftLeft => a
2141 .checked_mul(
2142 1u32.checked_shl(b)
2143 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2144 )
2145 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2146 BinaryOperator::ShiftRight => a
2147 .checked_shr(b)
2148 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2149 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2150 }),
2151 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2152 BinaryOperator::Add => a + b,
2153 BinaryOperator::Subtract => a - b,
2154 BinaryOperator::Multiply => a * b,
2155 BinaryOperator::Divide => a / b,
2156 BinaryOperator::Modulo => a % b,
2157 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2158 }),
2159 (Literal::AbstractInt(a), Literal::U32(b)) => {
2160 Literal::AbstractInt(match op {
2161 BinaryOperator::ShiftLeft => {
2162 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2163 return Err(ConstantEvaluatorError::Overflow(
2164 "<<".to_string(),
2165 ));
2166 }
2167 a.checked_shl(b).unwrap_or(0)
2168 }
2169 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2170 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2171 })
2172 }
2173 (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2174 BinaryOperator::Add => a + b,
2175 BinaryOperator::Subtract => a - b,
2176 BinaryOperator::Multiply => a * b,
2177 BinaryOperator::Divide => a / b,
2178 BinaryOperator::Modulo => a % b,
2179 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2180 }),
2181 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2182 Literal::AbstractInt(match op {
2183 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2184 ConstantEvaluatorError::Overflow("addition".into())
2185 })?,
2186 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2187 ConstantEvaluatorError::Overflow("subtraction".into())
2188 })?,
2189 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2190 ConstantEvaluatorError::Overflow("multiplication".into())
2191 })?,
2192 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2193 if b == 0 {
2194 ConstantEvaluatorError::DivisionByZero
2195 } else {
2196 ConstantEvaluatorError::Overflow("division".into())
2197 }
2198 })?,
2199 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2200 if b == 0 {
2201 ConstantEvaluatorError::RemainderByZero
2202 } else {
2203 ConstantEvaluatorError::Overflow("remainder".into())
2204 }
2205 })?,
2206 BinaryOperator::And => a & b,
2207 BinaryOperator::ExclusiveOr => a ^ b,
2208 BinaryOperator::InclusiveOr => a | b,
2209 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2210 })
2211 }
2212 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2213 Literal::AbstractFloat(match op {
2214 BinaryOperator::Add => a + b,
2215 BinaryOperator::Subtract => a - b,
2216 BinaryOperator::Multiply => a * b,
2217 BinaryOperator::Divide => a / b,
2218 BinaryOperator::Modulo => a % b,
2219 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2220 })
2221 }
2222 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2223 BinaryOperator::LogicalAnd => a && b,
2224 BinaryOperator::LogicalOr => a || b,
2225 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2226 }),
2227 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2228 },
2229 };
2230 Expression::Literal(literal)
2231 }
2232 (
2233 &Expression::Compose {
2234 components: ref src_components,
2235 ty,
2236 },
2237 &Expression::Literal(_),
2238 ) => {
2239 let mut components = src_components.clone();
2240 for component in &mut components {
2241 *component = self.binary_op(op, *component, right, span)?;
2242 }
2243 Expression::Compose { ty, components }
2244 }
2245 (
2246 &Expression::Literal(_),
2247 &Expression::Compose {
2248 components: ref src_components,
2249 ty,
2250 },
2251 ) => {
2252 let mut components = src_components.clone();
2253 for component in &mut components {
2254 *component = self.binary_op(op, left, *component, span)?;
2255 }
2256 Expression::Compose { ty, components }
2257 }
2258 (
2259 &Expression::Compose {
2260 components: ref left_components,
2261 ty: left_ty,
2262 },
2263 &Expression::Compose {
2264 components: ref right_components,
2265 ty: right_ty,
2266 },
2267 ) => {
2268 let left_flattened = crate::proc::flatten_compose(
2272 left_ty,
2273 left_components,
2274 self.expressions,
2275 self.types,
2276 );
2277 let right_flattened = crate::proc::flatten_compose(
2278 right_ty,
2279 right_components,
2280 self.expressions,
2281 self.types,
2282 );
2283
2284 let mut flattened = Vec::with_capacity(left_components.len());
2287 flattened.extend(left_flattened.zip(right_flattened));
2288
2289 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2290 (
2291 &TypeInner::Vector {
2292 size: left_size, ..
2293 },
2294 &TypeInner::Vector {
2295 size: right_size, ..
2296 },
2297 ) if left_size == right_size => {
2298 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2299 }
2300 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2301 }
2302 }
2303 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2304 };
2305
2306 self.register_evaluated_expr(expr, span)
2307 }
2308
2309 fn binary_op_vector(
2310 &mut self,
2311 op: BinaryOperator,
2312 size: crate::VectorSize,
2313 components: &[(Handle<Expression>, Handle<Expression>)],
2314 left_ty: Handle<Type>,
2315 span: Span,
2316 ) -> Result<Expression, ConstantEvaluatorError> {
2317 let ty = match op {
2318 BinaryOperator::Equal
2320 | BinaryOperator::NotEqual
2321 | BinaryOperator::Less
2322 | BinaryOperator::LessEqual
2323 | BinaryOperator::Greater
2324 | BinaryOperator::GreaterEqual => self.types.insert(
2325 Type {
2326 name: None,
2327 inner: TypeInner::Vector {
2328 size,
2329 scalar: crate::Scalar::BOOL,
2330 },
2331 },
2332 span,
2333 ),
2334
2335 BinaryOperator::Add
2338 | BinaryOperator::Subtract
2339 | BinaryOperator::Multiply
2340 | BinaryOperator::Divide
2341 | BinaryOperator::Modulo
2342 | BinaryOperator::And
2343 | BinaryOperator::ExclusiveOr
2344 | BinaryOperator::InclusiveOr
2345 | BinaryOperator::ShiftLeft
2346 | BinaryOperator::ShiftRight => left_ty,
2347
2348 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2349 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2351 }
2352 };
2353
2354 let components = components
2355 .iter()
2356 .map(|&(left, right)| self.binary_op(op, left, right, span))
2357 .collect::<Result<Vec<_>, _>>()?;
2358
2359 Ok(Expression::Compose { ty, components })
2360 }
2361
2362 fn relational(
2363 &mut self,
2364 fun: RelationalFunction,
2365 arg: Handle<Expression>,
2366 span: Span,
2367 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2368 let arg = self.eval_zero_value_and_splat(arg, span)?;
2369 match fun {
2370 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2371 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2372 Expression::Compose { ty, ref components }
2373 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2374 {
2375 let components =
2376 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2377 .map(|component| match self.expressions[component] {
2378 Expression::Literal(Literal::Bool(val)) => Ok(val),
2379 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2380 })
2381 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2382 let result = match fun {
2383 RelationalFunction::All => components.iter().all(|c| *c),
2384 RelationalFunction::Any => components.iter().any(|c| *c),
2385 _ => unreachable!(),
2386 };
2387 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2388 }
2389 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2390 },
2391 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2392 "{fun:?} built-in function"
2393 ))),
2394 }
2395 }
2396
2397 fn copy_from(
2405 &mut self,
2406 expr: Handle<Expression>,
2407 expressions: &Arena<Expression>,
2408 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2409 let span = expressions.get_span(expr);
2410 match expressions[expr] {
2411 ref expr @ (Expression::Literal(_)
2412 | Expression::Constant(_)
2413 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2414 Expression::Compose { ty, ref components } => {
2415 let mut components = components.clone();
2416 for component in &mut components {
2417 *component = self.copy_from(*component, expressions)?;
2418 }
2419 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2420 }
2421 Expression::Splat { size, value } => {
2422 let value = self.copy_from(value, expressions)?;
2423 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2424 }
2425 _ => {
2426 log::debug!("copy_from: SubexpressionsAreNotConstant");
2427 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2428 }
2429 }
2430 }
2431
2432 fn vector_compose_flattened_size(
2434 &self,
2435 components: &[Handle<Expression>],
2436 ) -> Result<usize, ConstantEvaluatorError> {
2437 components
2438 .iter()
2439 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2440 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2441 TypeInner::Scalar(_) => 1,
2442 TypeInner::Vector { size, .. } => size as usize,
2446 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2447 };
2448 Ok(acc + size)
2449 })
2450 }
2451
2452 fn register_evaluated_expr(
2453 &mut self,
2454 expr: Expression,
2455 span: Span,
2456 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2457 if let Expression::Literal(literal) = expr {
2462 crate::valid::check_literal_value(literal)?;
2463 }
2464
2465 if let Expression::Compose { ty, ref components } = expr {
2469 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2470 let expected = size as usize;
2471 let actual = self.vector_compose_flattened_size(components)?;
2472 if expected != actual {
2473 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2474 expected,
2475 actual,
2476 });
2477 }
2478 }
2479 }
2480
2481 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2482 }
2483
2484 fn append_expr(
2485 &mut self,
2486 expr: Expression,
2487 span: Span,
2488 expr_type: ExpressionKind,
2489 ) -> Handle<Expression> {
2490 let h = match self.behavior {
2491 Behavior::Wgsl(
2492 WgslRestrictions::Runtime(ref mut function_local_data)
2493 | WgslRestrictions::Const(Some(ref mut function_local_data)),
2494 )
2495 | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
2496 let is_running = function_local_data.emitter.is_running();
2497 let needs_pre_emit = expr.needs_pre_emit();
2498 if is_running && needs_pre_emit {
2499 function_local_data
2500 .block
2501 .extend(function_local_data.emitter.finish(self.expressions));
2502 let h = self.expressions.append(expr, span);
2503 function_local_data.emitter.start(self.expressions);
2504 h
2505 } else {
2506 self.expressions.append(expr, span)
2507 }
2508 }
2509 _ => self.expressions.append(expr, span),
2510 };
2511 self.expression_kind_tracker.insert(h, expr_type);
2512 h
2513 }
2514
2515 fn resolve_type(
2516 &self,
2517 expr: Handle<Expression>,
2518 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2519 use crate::proc::TypeResolution as Tr;
2520 use crate::Expression as Ex;
2521 let resolution = match self.expressions[expr] {
2522 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2523 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2524 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2525 Ex::Splat { size, value } => {
2526 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2527 return Err(ConstantEvaluatorError::SplatScalarOnly);
2528 };
2529 Tr::Value(TypeInner::Vector { scalar, size })
2530 }
2531 _ => {
2532 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2533 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2534 }
2535 };
2536
2537 Ok(resolution)
2538 }
2539
2540 fn select(
2541 &mut self,
2542 reject: Handle<Expression>,
2543 accept: Handle<Expression>,
2544 condition: Handle<Expression>,
2545 span: Span,
2546 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2547 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
2548
2549 let reject = arg(reject)?;
2550 let accept = arg(accept)?;
2551 let condition = arg(condition)?;
2552
2553 let select_single_component =
2554 |this: &mut Self, reject_scalar, reject, accept, condition| {
2555 let accept = this.cast(accept, reject_scalar, span)?;
2556 if condition {
2557 Ok(accept)
2558 } else {
2559 Ok(reject)
2560 }
2561 };
2562
2563 match (&self.expressions[reject], &self.expressions[accept]) {
2564 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
2565 let reject_scalar = reject_lit.scalar();
2566 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
2567 else {
2568 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
2569 };
2570 select_single_component(self, reject_scalar, reject, accept, condition)
2571 }
2572 (
2573 &Expression::Compose {
2574 ty: reject_ty,
2575 components: ref reject_components,
2576 },
2577 &Expression::Compose {
2578 ty: accept_ty,
2579 components: ref accept_components,
2580 },
2581 ) => {
2582 let ty_deets = |ty| {
2583 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
2584 (size.unwrap(), scalar)
2585 };
2586
2587 let expected_vec_size = {
2588 let [(reject_vec_size, _), (accept_vec_size, _)] =
2589 [reject_ty, accept_ty].map(ty_deets);
2590
2591 if reject_vec_size != accept_vec_size {
2592 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
2593 reject: reject_vec_size,
2594 accept: accept_vec_size,
2595 });
2596 }
2597 reject_vec_size
2598 };
2599
2600 let condition_components = match self.expressions[condition] {
2601 Expression::Literal(Literal::Bool(condition)) => {
2602 vec![condition; (expected_vec_size as u8).into()]
2603 }
2604 Expression::Compose {
2605 ty: condition_ty,
2606 components: ref condition_components,
2607 } => {
2608 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
2609 if condition_scalar.kind != ScalarKind::Bool {
2610 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
2611 }
2612 if condition_vec_size != expected_vec_size {
2613 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
2614 }
2615 condition_components
2616 .iter()
2617 .copied()
2618 .map(|component| match &self.expressions[component] {
2619 &Expression::Literal(Literal::Bool(condition)) => condition,
2620 _ => unreachable!(),
2621 })
2622 .collect()
2623 }
2624
2625 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
2626 };
2627
2628 let evaluated = Expression::Compose {
2629 ty: reject_ty,
2630 components: reject_components
2631 .clone()
2632 .into_iter()
2633 .zip(accept_components.clone().into_iter())
2634 .zip(condition_components.into_iter())
2635 .map(|((reject, accept), condition)| {
2636 let reject_scalar = match &self.expressions[reject] {
2637 &Expression::Literal(lit) => lit.scalar(),
2638 _ => unreachable!(),
2639 };
2640 select_single_component(self, reject_scalar, reject, accept, condition)
2641 })
2642 .collect::<Result<_, _>>()?,
2643 };
2644 self.register_evaluated_expr(evaluated, span)
2645 }
2646 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
2647 }
2648 }
2649}
2650
2651fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2652 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2656 match e {
2657 idx @ 0..=31 => idx,
2658 32 => u32::MAX,
2659 _ => unreachable!(),
2660 }
2661 };
2662 match concrete_int {
2663 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2664 ConcreteInt::I32([e]) => {
2665 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2666 }
2667 }
2668}
2669
/// Smoke test for `first_trailing_bit` on both integer variants.
#[test]
fn first_trailing_bit_smoke() {
    // `(input, expected bit index)`; zero has no set bit and maps to -1.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Every single-bit value reports the position of that bit.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // `(input, expected bit index)`; zero maps to `u32::MAX`.
    let unsigned_cases = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2730
2731fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2732 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2736 match e {
2737 idx @ 0..=31 => 31 - idx,
2738 32 => u32::MAX,
2739 _ => unreachable!(),
2740 }
2741 };
2742 match concrete_int {
2743 ConcreteInt::I32([e]) => ConcreteInt::I32([{
2744 let rtl_bit_index = if e.is_negative() {
2745 e.leading_ones()
2746 } else {
2747 e.leading_zeros()
2748 };
2749 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2750 }]),
2751 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2752 }
2753}
2754
/// Smoke test for `first_leading_bit` on both integer variants.
#[test]
fn first_leading_bit_smoke() {
    // `(input, expected bit index)`; -1 and 0 have no qualifying bit.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two (bit 31 would make the value negative).
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negated powers of two: the highest cleared bit sits just below `idx`.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2818
2819trait TryFromAbstract<T>: Sized {
2821 fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2843}
2844
2845impl TryFromAbstract<i64> for i32 {
2846 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2847 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2848 value: format!("{value:?}"),
2849 to_type: "i32",
2850 })
2851 }
2852}
2853
2854impl TryFromAbstract<i64> for u32 {
2855 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2856 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2857 value: format!("{value:?}"),
2858 to_type: "u32",
2859 })
2860 }
2861}
2862
2863impl TryFromAbstract<i64> for u64 {
2864 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2865 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2866 value: format!("{value:?}"),
2867 to_type: "u64",
2868 })
2869 }
2870}
2871
2872impl TryFromAbstract<i64> for i64 {
2873 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2874 Ok(value)
2875 }
2876}
2877
2878impl TryFromAbstract<i64> for f32 {
2879 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2880 let f = value as f32;
2881 Ok(f)
2885 }
2886}
2887
2888impl TryFromAbstract<f64> for f32 {
2889 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2890 let f = value as f32;
2891 if f.is_infinite() {
2892 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2893 value: format!("{value:?}"),
2894 to_type: "f32",
2895 });
2896 }
2897 Ok(f)
2898 }
2899}
2900
2901impl TryFromAbstract<i64> for f64 {
2902 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2903 let f = value as f64;
2904 Ok(f)
2908 }
2909}
2910
2911impl TryFromAbstract<f64> for f64 {
2912 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2913 Ok(value)
2914 }
2915}
2916
2917impl TryFromAbstract<f64> for i32 {
2918 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2919 Ok(value as i32)
2932 }
2933}
2934
2935impl TryFromAbstract<f64> for u32 {
2936 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2937 Ok(value as u32)
2940 }
2941}
2942
2943impl TryFromAbstract<f64> for i64 {
2944 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2945 use crate::proc::type_methods::IntFloatLimits;
2948 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
2949 }
2950}
2951
2952impl TryFromAbstract<f64> for u64 {
2953 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2954 use crate::proc::type_methods::IntFloatLimits;
2957 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
2958 }
2959}
2960
2961impl TryFromAbstract<f64> for f16 {
2962 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
2963 let f = f16::from_f64(value);
2964 if f.is_infinite() {
2965 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2966 value: format!("{value:?}"),
2967 to_type: "f16",
2968 });
2969 }
2970 Ok(f)
2971 }
2972}
2973
2974impl TryFromAbstract<i64> for f16 {
2975 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
2976 let f = f16::from_i64(value);
2977 if f.is_none() {
2978 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2979 value: format!("{value:?}"),
2980 to_type: "f16",
2981 });
2982 }
2983 Ok(f.unwrap())
2984 }
2985}
2986
/// Computes the cross product `a × b` of two three-component vectors.
///
/// Works for any `Copy` component type that supports multiplication and
/// subtraction yielding the same type.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
2999
3000#[cfg(test)]
3001mod tests {
3002 use alloc::{vec, vec::Vec};
3003
3004 use crate::{
3005 Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
3006 UniqueArena, VectorSize,
3007 };
3008
3009 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3010
/// Negation and bitwise-not of a scalar constant, and bitwise-not of a vec2
/// constant, are folded into literals / a compose of literals.
#[test]
fn unary_op() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    // Scalar constants 4 and 8, plus a vec2 constant built from their
    // initializers.
    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let eight_init =
        global_expressions.append(Expression::Literal(Literal::I32(8)), Default::default());
    let eight = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: eight_init,
        },
        Default::default(),
    );
    let pair_components = vec![constants[four].init, constants[eight].init];
    let pair_init = global_expressions.append(
        Expression::Compose {
            ty: vec2_i32_ty,
            components: pair_components,
        },
        Default::default(),
    );
    let pair = constants.append(
        Constant {
            name: None,
            ty: vec2_i32_ty,
            init: pair_init,
        },
        Default::default(),
    );

    let scalar_ref = global_expressions.append(Expression::Constant(four), Default::default());
    let vector_ref = global_expressions.append(Expression::Constant(pair), Default::default());

    let negate_scalar = Expression::Unary {
        op: UnaryOperator::Negate,
        expr: scalar_ref,
    };
    let not_scalar = Expression::Unary {
        op: UnaryOperator::BitwiseNot,
        expr: scalar_ref,
    };
    let not_vector = Expression::Unary {
        op: UnaryOperator::BitwiseNot,
        expr: vector_ref,
    };

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let negated = solver
        .try_eval_and_append(negate_scalar, Default::default())
        .unwrap();
    let inverted = solver
        .try_eval_and_append(not_scalar, Default::default())
        .unwrap();
    let inverted_pair = solver
        .try_eval_and_append(not_vector, Default::default())
        .unwrap();

    assert_eq!(
        global_expressions[negated],
        Expression::Literal(Literal::I32(-4))
    );
    assert_eq!(
        global_expressions[inverted],
        Expression::Literal(Literal::I32(!4))
    );

    let Expression::Compose { ty, ref components } = global_expressions[inverted_pair] else {
        panic!("Expected vector")
    };
    assert_eq!(ty, vec2_i32_ty);
    for (&component, expected) in components.iter().zip([!4, !8]) {
        assert_eq!(
            global_expressions[component],
            Expression::Literal(Literal::I32(expected))
        );
    }
    assert_eq!(components.len(), 2);
}
3143
/// Converting the `i32` constant 4 to `bool` folds to the literal `true`.
#[test]
fn cast() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_ref = global_expressions.append(Expression::Constant(four), Default::default());

    let to_bool = Expression::As {
        expr: four_ref,
        kind: ScalarKind::Bool,
        convert: Some(crate::BOOL_WIDTH),
    };

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let folded = solver
        .try_eval_and_append(to_bool, Default::default())
        .unwrap();

    assert_eq!(
        global_expressions[folded],
        Expression::Literal(Literal::Bool(true))
    );
}
3197
/// Indexing a constant 2x3 matrix yields its second column, and indexing
/// that column yields its last element.
#[test]
fn access() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let mat2x3_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Matrix {
                columns: VectorSize::Bi,
                rows: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let vec3_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Columns hold the literals 0, 1, 2 and 3, 4, 5 respectively.
    let first_column: Vec<_> = (0..3)
        .map(|i| {
            global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            )
        })
        .collect();
    let second_column: Vec<_> = (3..6)
        .map(|i| {
            global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            )
        })
        .collect();

    let first_init = global_expressions.append(
        Expression::Compose {
            ty: vec3_ty,
            components: first_column,
        },
        Default::default(),
    );
    let first = constants.append(
        Constant {
            name: None,
            ty: vec3_ty,
            init: first_init,
        },
        Default::default(),
    );

    let second_init = global_expressions.append(
        Expression::Compose {
            ty: vec3_ty,
            components: second_column,
        },
        Default::default(),
    );
    let second = constants.append(
        Constant {
            name: None,
            ty: vec3_ty,
            init: second_init,
        },
        Default::default(),
    );

    let matrix_columns = vec![constants[first].init, constants[second].init];
    let matrix_init = global_expressions.append(
        Expression::Compose {
            ty: mat2x3_ty,
            components: matrix_columns,
        },
        Default::default(),
    );
    let matrix = constants.append(
        Constant {
            name: None,
            ty: mat2x3_ty,
            init: matrix_init,
        },
        Default::default(),
    );

    let base = global_expressions.append(Expression::Constant(matrix), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let column = solver
        .try_eval_and_append(Expression::AccessIndex { base, index: 1 }, Default::default())
        .unwrap();
    let element = solver
        .try_eval_and_append(
            Expression::AccessIndex {
                base: column,
                index: 2,
            },
            Default::default(),
        )
        .unwrap();

    let Expression::Compose { ty, ref components } = global_expressions[column] else {
        panic!("Expected vector")
    };
    assert_eq!(ty, vec3_ty);
    for (&component, expected) in components.iter().zip([3., 4., 5.]) {
        assert_eq!(
            global_expressions[component],
            Expression::Literal(Literal::F32(expected))
        );
    }
    assert_eq!(components.len(), 3);

    assert_eq!(
        global_expressions[element],
        Expression::Literal(Literal::F32(5.))
    );
}
3351
/// Negating a vec2 composed from two references to the constant 4 yields a
/// vec2 whose components are all the literal -4.
#[test]
fn compose_of_constants() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_ref = global_expressions.append(Expression::Constant(four), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let composed = solver
        .try_eval_and_append(
            Expression::Compose {
                ty: vec2_i32_ty,
                components: vec![four_ref, four_ref],
            },
            Default::default(),
        )
        .unwrap();
    let negated = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: composed,
            },
            Default::default(),
        )
        .unwrap();

    let Expression::Compose { ty, ref components } = global_expressions[negated] else {
        panic!("unexpected evaluation result")
    };
    let is_negative_four = |&component| {
        matches!(
            global_expressions[component],
            Expression::Literal(Literal::I32(-4))
        )
    };
    if ty != vec2_i32_ty || !components.iter().all(is_negative_four) {
        panic!("unexpected evaluation result")
    }
}
3434
/// Negating a vec2 splat of the constant 4 yields a vec2 whose components
/// are all the literal -4.
#[test]
fn splat_of_constant() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_ref = global_expressions.append(Expression::Constant(four), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let splatted = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: four_ref,
            },
            Default::default(),
        )
        .unwrap();
    let negated = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: splatted,
            },
            Default::default(),
        )
        .unwrap();

    let Expression::Compose { ty, ref components } = global_expressions[negated] else {
        panic!("unexpected evaluation result")
    };
    let is_negative_four = |&component| {
        matches!(
            global_expressions[component],
            Expression::Literal(Literal::I32(-4))
        )
    };
    if ty != vec2_i32_ty || !components.iter().all(is_negative_four) {
        panic!("unexpected evaluation result")
    }
}
3517
/// Checks that a splat of a `ZeroValue` participates in constant folding.
///
/// Adds a `VectorSize::Bi` splat of an `f32` `ZeroValue` to a splat of the
/// literal `5.0` and expects a `Compose` of the matching `f32` vector type
/// whose components are all the literal `5.0`.
#[test]
fn splat_of_zero_value() {
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );
    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Appends a scalar expression followed by its two-wide splat and hands
    // back the splat's handle.
    let mut append_splat = |scalar_expr: Expression| {
        let value = global_expressions.append(scalar_expr, Default::default());
        global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value,
            },
            Default::default(),
        )
    };
    let five_splat = append_splat(Expression::Literal(Literal::F32(5.0)));
    let zero_splat = append_splat(Expression::ZeroValue(f32_ty));

    let mut tracker = ExpressionKindTracker::from_arena(&global_expressions);
    let mut layouter = crate::proc::Layouter::default();
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker: &mut tracker,
        layouter: &mut layouter,
    };

    let sum = solver
        .try_eval_and_append(
            Expression::Binary {
                op: crate::BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    // The evaluator is no longer used past this point, so the arena can be
    // inspected directly.
    let is_expected = match global_expressions[sum] {
        Expression::Compose { ty, ref components } if ty == vec2_f32_ty => {
            components.iter().all(|&handle| {
                matches!(
                    global_expressions[handle],
                    Expression::Literal(Literal::F32(5.0))
                )
            })
        }
        _ => false,
    };
    assert!(is_expected, "unexpected evaluation result");
}
3598}