1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Lets a macro-generated `macro_rules!` definition contain a literal `$`.
///
/// The invocation body must be a single `macro_rules!` arm whose matcher binds
/// one `tt` (conventionally `$d:tt`). That arm is instantiated with a bare `$`,
/// so inside it `$d` expands to `$`. This allows declaring metavariables
/// (`$d name:expr`) and repetitions (`$d (...)*`) in a macro that is itself
/// produced by another macro.
macro_rules! with_dollar_sign {
    ($($arm:tt)*) => {
        // Build a throwaway macro from the caller's arm, then hand it a lone
        // `$` so that its `$d` binding captures the dollar-sign token.
        macro_rules! __with_dollar_sign {
            $($arm)*
        }
        __with_dollar_sign!($);
    };
}
33
/// Generates a component-wise literal extractor for a family of literal types.
///
/// An invocation
/// `gen_component_wise_extractor! { $ident -> $target, literals: [...], scalar_kinds: [...] }`
/// produces:
///
/// - an enum `$target<const N: usize>` with one variant per listed literal, each
///   holding `N` values of the mapped Rust type;
/// - `impl From<$target<1>> for Expression`, turning one value back into an
///   `Expression::Literal`;
/// - a function `$ident` that reads `N` argument expressions as one `$target<N>`,
///   passes them to a handler and registers the handler's result;
/// - a macro `$ident!` that applies the same handler body to every variant.
///
/// `scalar_kinds` lists the vector scalar kinds accepted when the arguments are
/// `Compose` expressions; those are processed one component at a time.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// `N` values of one literal type, one per argument.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        ///
        /// Each argument is first passed through `eval_zero_value_and_splat`.
        ///
        /// - If the first argument is a literal, every argument must be a literal
        ///   of the same variant. The handler is called once and its result is
        ///   registered.
        /// - If the first argument is a `Compose` of a vector type whose scalar
        ///   kind is one of the listed `scalar_kinds`, every argument must be a
        ///   `Compose` with the same `TypeInner`. The handler is then applied
        ///   component by component (recursively through this function), and the
        ///   results are composed into a vector of the first argument's type.
        /// - Any other shape yields [`ConstantEvaluatorError::InvalidMathArg`].
        ///
        /// # Panics
        ///
        /// Panics if `N == 0`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs `eval_zero_value_and_splat` on an argument and borrows the
            // expression it yields.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One group of flattened components per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // The remaining `N - 1` arguments must be vectors of the
                            // same type. The intermediate collection is bounded by the
                            // argument count `N`, not by the vector width.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, N>, _>>()?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                // The `idx`-th component of every argument.
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or_else(|| err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
208gen_component_wise_extractor! {
209 component_wise_scalar -> Scalar,
210 literals: [
211 AbstractFloat => AbstractFloat: f64,
212 F32 => F32: f32,
213 F16 => F16: f16,
214 AbstractInt => AbstractInt: i64,
215 U32 => U32: u32,
216 I32 => I32: i32,
217 U64 => U64: u64,
218 I64 => I64: i64,
219 ],
220 scalar_kinds: [
221 Float,
222 AbstractFloat,
223 Sint,
224 Uint,
225 AbstractInt,
226 ],
227}
228
229gen_component_wise_extractor! {
230 component_wise_float -> Float,
231 literals: [
232 AbstractFloat => Abstract: f64,
233 F32 => F32: f32,
234 F16 => F16: f16,
235 ],
236 scalar_kinds: [
237 Float,
238 AbstractFloat,
239 ],
240}
241
242gen_component_wise_extractor! {
243 component_wise_concrete_int -> ConcreteInt,
244 literals: [
245 U32 => U32: u32,
246 I32 => I32: i32,
247 ],
248 scalar_kinds: [
249 Sint,
250 Uint,
251 ],
252}
253
254gen_component_wise_extractor! {
255 component_wise_signed -> Signed,
256 literals: [
257 AbstractFloat => AbstractFloat: f64,
258 AbstractInt => AbstractInt: i64,
259 F32 => F32: f32,
260 F16 => F16: f16,
261 I32 => I32: i32,
262 ],
263 scalar_kinds: [
264 Sint,
265 AbstractInt,
266 Float,
267 AbstractFloat,
268 ],
269}
270
271#[derive(Debug)]
272enum Behavior<'a> {
273 Wgsl(WgslRestrictions<'a>),
274 Glsl(GlslRestrictions<'a>),
275}
276
277impl Behavior<'_> {
278 const fn has_runtime_restrictions(&self) -> bool {
280 matches!(
281 self,
282 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
283 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
284 )
285 }
286}
287
/// Evaluates constant expressions and registers the results in an expression arena.
///
/// Build one with [`ConstantEvaluator::for_wgsl_module`],
/// [`ConstantEvaluator::for_glsl_module`], [`ConstantEvaluator::for_wgsl_function`]
/// or [`ConstantEvaluator::for_glsl_function`].
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Language rules and evaluation context.
    behavior: Behavior<'a>,

    /// The module's type arena; new vector types are inserted here when
    /// evaluation needs them (e.g. splat and swizzle results).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants; a constant reference resolves to its `init` expression.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena evaluated into: the module's global expressions at module scope,
    /// or the function's expression arena inside a function.
    expressions: &'a mut Arena<Expression>,

    /// The kind (const / override / runtime) of each expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// NOTE(review): not used in the visible part of this file; presumably
    /// provides type sizes/alignments to other evaluation steps — confirm.
    layouter: &'a mut crate::proc::Layouter,
}
332
333#[derive(Debug)]
334enum WgslRestrictions<'a> {
335 Const(Option<FunctionLocalData<'a>>),
337 Override,
340 Runtime(FunctionLocalData<'a>),
344}
345
346#[derive(Debug)]
347enum GlslRestrictions<'a> {
348 Const,
350 Runtime(FunctionLocalData<'a>),
354}
355
/// Extra state needed when evaluating inside a function body.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are copied
    /// from here into the function's arena when a constant is referenced.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): used outside this excerpt; presumably emits runtime
    /// expressions into `block` — confirm.
    emitter: &'a mut super::Emitter,
    /// The function block under construction. NOTE(review): presumably receives
    /// emitted statements — confirm.
    block: &'a mut crate::Block,
}
363
/// How "constant" an expression is.
///
/// The derived `Ord` follows declaration order (`Const < Override < Runtime`).
/// `ExpressionKindTracker` relies on this ordering: an expression's kind is
/// the `max` of its operands' kinds.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values, constants and const operands.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Anything else, e.g. loads, function arguments or call results.
    Runtime,
}
370
/// Records the [`ExpressionKind`] of every expression in an arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// Kind of each expression, indexed by its handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
375
376impl ExpressionKindTracker {
377 pub const fn new() -> Self {
378 Self {
379 inner: HandleVec::new(),
380 }
381 }
382
383 pub fn force_non_const(&mut self, value: Handle<Expression>) {
385 self.inner[value] = ExpressionKind::Runtime;
386 }
387
388 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
389 self.inner.insert(value, expr_type);
390 }
391
392 pub fn is_const(&self, h: Handle<Expression>) -> bool {
393 matches!(self.type_of(h), ExpressionKind::Const)
394 }
395
396 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
397 matches!(
398 self.type_of(h),
399 ExpressionKind::Const | ExpressionKind::Override
400 )
401 }
402
403 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
404 self.inner[value]
405 }
406
407 pub fn from_arena(arena: &Arena<Expression>) -> Self {
408 let mut tracker = Self {
409 inner: HandleVec::with_capacity(arena.len()),
410 };
411 for (handle, expr) in arena.iter() {
412 tracker
413 .inner
414 .insert(handle, tracker.type_of_with_expr(expr));
415 }
416 tracker
417 }
418
419 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
420 match *expr {
421 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
422 ExpressionKind::Const
423 }
424 Expression::Override(_) => ExpressionKind::Override,
425 Expression::Compose { ref components, .. } => {
426 let mut expr_type = ExpressionKind::Const;
427 for component in components {
428 expr_type = expr_type.max(self.type_of(*component))
429 }
430 expr_type
431 }
432 Expression::Splat { value, .. } => self.type_of(value),
433 Expression::AccessIndex { base, .. } => self.type_of(base),
434 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
435 Expression::Swizzle { vector, .. } => self.type_of(vector),
436 Expression::Unary { expr, .. } => self.type_of(expr),
437 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
438 Expression::Math {
439 arg,
440 arg1,
441 arg2,
442 arg3,
443 ..
444 } => self
445 .type_of(arg)
446 .max(
447 arg1.map(|arg| self.type_of(arg))
448 .unwrap_or(ExpressionKind::Const),
449 )
450 .max(
451 arg2.map(|arg| self.type_of(arg))
452 .unwrap_or(ExpressionKind::Const),
453 )
454 .max(
455 arg3.map(|arg| self.type_of(arg))
456 .unwrap_or(ExpressionKind::Const),
457 ),
458 Expression::As { expr, .. } => self.type_of(expr),
459 Expression::Select {
460 condition,
461 accept,
462 reject,
463 } => self
464 .type_of(condition)
465 .max(self.type_of(accept))
466 .max(self.type_of(reject)),
467 Expression::Relational { argument, .. } => self.type_of(argument),
468 Expression::ArrayLength(expr) => self.type_of(expr),
469 _ => ExpressionKind::Runtime,
470 }
471 }
472}
473
474#[derive(Clone, Debug, thiserror::Error)]
475#[cfg_attr(test, derive(PartialEq))]
476pub enum ConstantEvaluatorError {
477 #[error("Constants cannot access function arguments")]
478 FunctionArg,
479 #[error("Constants cannot access global variables")]
480 GlobalVariable,
481 #[error("Constants cannot access local variables")]
482 LocalVariable,
483 #[error("Cannot get the array length of a non array type")]
484 InvalidArrayLengthArg,
485 #[error("Constants cannot get the array length of a dynamically sized array")]
486 ArrayLengthDynamic,
487 #[error("Cannot call arrayLength on array sized by override-expression")]
488 ArrayLengthOverridden,
489 #[error("Constants cannot call functions")]
490 Call,
491 #[error("Constants don't support workGroupUniformLoad")]
492 WorkGroupUniformLoadResult,
493 #[error("Constants don't support atomic functions")]
494 Atomic,
495 #[error("Constants don't support derivative functions")]
496 Derivative,
497 #[error("Constants don't support load expressions")]
498 Load,
499 #[error("Constants don't support image expressions")]
500 ImageExpression,
501 #[error("Constants don't support ray query expressions")]
502 RayQueryExpression,
503 #[error("Constants don't support subgroup expressions")]
504 SubgroupExpression,
505 #[error("Cannot access the type")]
506 InvalidAccessBase,
507 #[error("Cannot access at the index")]
508 InvalidAccessIndex,
509 #[error("Cannot access with index of type")]
510 InvalidAccessIndexTy,
511 #[error("Constants don't support array length expressions")]
512 ArrayLength,
513 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
514 InvalidCastArg { from: String, to: String },
515 #[error("Cannot apply the unary op to the argument")]
516 InvalidUnaryOpArg,
517 #[error("Cannot apply the binary op to the arguments")]
518 InvalidBinaryOpArgs,
519 #[error("Cannot apply math function to type")]
520 InvalidMathArg,
521 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
522 InvalidMathArgCount(crate::MathFunction, usize, usize),
523 #[error("Cannot apply relational function to type")]
524 InvalidRelationalArg(RelationalFunction),
525 #[error("value of `low` is greater than `high` for clamp built-in function")]
526 InvalidClamp,
527 #[error("Constructor expects {expected} components, found {actual}")]
528 InvalidVectorComposeLength { expected: usize, actual: usize },
529 #[error("Constructor must only contain vector or scalar arguments")]
530 InvalidVectorComposeComponent,
531 #[error("Splat is defined only on scalar values")]
532 SplatScalarOnly,
533 #[error("Can only swizzle vector constants")]
534 SwizzleVectorOnly,
535 #[error("swizzle component not present in source expression")]
536 SwizzleOutOfBounds,
537 #[error("Type is not constructible")]
538 TypeNotConstructible,
539 #[error("Subexpression(s) are not constant")]
540 SubexpressionsAreNotConstant,
541 #[error("Not implemented as constant expression: {0}")]
542 NotImplemented(String),
543 #[error("{0} operation overflowed")]
544 Overflow(String),
545 #[error(
546 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
547 )]
548 AutomaticConversionLossy {
549 value: String,
550 to_type: &'static str,
551 },
552 #[error("Division by zero")]
553 DivisionByZero,
554 #[error("Remainder by zero")]
555 RemainderByZero,
556 #[error("RHS of shift operation is greater than or equal to 32")]
557 ShiftedMoreThan32Bits,
558 #[error(transparent)]
559 Literal(#[from] crate::valid::LiteralError),
560 #[error("Can't use pipeline-overridable constants in const-expressions")]
561 Override,
562 #[error("Unexpected runtime-expression")]
563 RuntimeExpr,
564 #[error("Unexpected override-expression")]
565 OverrideExpr,
566 #[error("Expected boolean expression for condition argument of `select`, got something else")]
567 SelectScalarConditionNotABool,
568 #[error(
569 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
570 reject,
571 accept
572 )]
573 SelectVecRejectAcceptSizeMismatch {
574 reject: crate::VectorSize,
575 accept: crate::VectorSize,
576 },
577 #[error("Expected boolean vector for condition arg., got something else")]
578 SelectConditionNotAVecBool,
579 #[error(
580 "Expected same number of vector components between condition, accept, and reject args., got something else",
581 )]
582 SelectConditionVecSizeMismatch,
583 #[error(
584 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
585 )]
586 SelectAcceptRejectTypeMismatch,
587}
588
589impl<'a> ConstantEvaluator<'a> {
590 pub fn for_wgsl_module(
595 module: &'a mut crate::Module,
596 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
597 layouter: &'a mut crate::proc::Layouter,
598 in_override_ctx: bool,
599 ) -> Self {
600 Self::for_module(
601 Behavior::Wgsl(if in_override_ctx {
602 WgslRestrictions::Override
603 } else {
604 WgslRestrictions::Const(None)
605 }),
606 module,
607 global_expression_kind_tracker,
608 layouter,
609 )
610 }
611
612 pub fn for_glsl_module(
617 module: &'a mut crate::Module,
618 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
619 layouter: &'a mut crate::proc::Layouter,
620 ) -> Self {
621 Self::for_module(
622 Behavior::Glsl(GlslRestrictions::Const),
623 module,
624 global_expression_kind_tracker,
625 layouter,
626 )
627 }
628
629 fn for_module(
630 behavior: Behavior<'a>,
631 module: &'a mut crate::Module,
632 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
633 layouter: &'a mut crate::proc::Layouter,
634 ) -> Self {
635 Self {
636 behavior,
637 types: &mut module.types,
638 constants: &module.constants,
639 overrides: &module.overrides,
640 expressions: &mut module.global_expressions,
641 expression_kind_tracker: global_expression_kind_tracker,
642 layouter,
643 }
644 }
645
646 pub fn for_wgsl_function(
651 module: &'a mut crate::Module,
652 expressions: &'a mut Arena<Expression>,
653 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
654 layouter: &'a mut crate::proc::Layouter,
655 emitter: &'a mut super::Emitter,
656 block: &'a mut crate::Block,
657 is_const: bool,
658 ) -> Self {
659 let local_data = FunctionLocalData {
660 global_expressions: &module.global_expressions,
661 emitter,
662 block,
663 };
664 Self {
665 behavior: Behavior::Wgsl(if is_const {
666 WgslRestrictions::Const(Some(local_data))
667 } else {
668 WgslRestrictions::Runtime(local_data)
669 }),
670 types: &mut module.types,
671 constants: &module.constants,
672 overrides: &module.overrides,
673 expressions,
674 expression_kind_tracker: local_expression_kind_tracker,
675 layouter,
676 }
677 }
678
679 pub fn for_glsl_function(
684 module: &'a mut crate::Module,
685 expressions: &'a mut Arena<Expression>,
686 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
687 layouter: &'a mut crate::proc::Layouter,
688 emitter: &'a mut super::Emitter,
689 block: &'a mut crate::Block,
690 ) -> Self {
691 Self {
692 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
693 global_expressions: &module.global_expressions,
694 emitter,
695 block,
696 })),
697 types: &mut module.types,
698 constants: &module.constants,
699 overrides: &module.overrides,
700 expressions,
701 expression_kind_tracker: local_expression_kind_tracker,
702 layouter,
703 }
704 }
705
706 pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
707 crate::proc::GlobalCtx {
708 types: self.types,
709 constants: self.constants,
710 overrides: self.overrides,
711 global_expressions: match self.function_local_data() {
712 Some(data) => data.global_expressions,
713 None => self.expressions,
714 },
715 }
716 }
717
718 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
719 if !self.expression_kind_tracker.is_const(expr) {
720 log::debug!("check: SubexpressionsAreNotConstant");
721 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
722 }
723 Ok(())
724 }
725
726 fn check_and_get(
727 &mut self,
728 expr: Handle<Expression>,
729 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
730 match self.expressions[expr] {
731 Expression::Constant(c) => {
732 if let Some(function_local_data) = self.function_local_data() {
735 self.copy_from(
737 self.constants[c].init,
738 function_local_data.global_expressions,
739 )
740 } else {
741 Ok(self.constants[c].init)
743 }
744 }
745 _ => {
746 self.check(expr)?;
747 Ok(expr)
748 }
749 }
750 }
751
752 pub fn try_eval_and_append(
776 &mut self,
777 expr: Expression,
778 span: Span,
779 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
780 match self.expression_kind_tracker.type_of_with_expr(&expr) {
781 ExpressionKind::Const => {
782 let eval_result = self.try_eval_and_append_impl(&expr, span);
783 if self.behavior.has_runtime_restrictions()
788 && matches!(
789 eval_result,
790 Err(ConstantEvaluatorError::NotImplemented(_)
791 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
792 )
793 {
794 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
795 } else {
796 eval_result
797 }
798 }
799 ExpressionKind::Override => match self.behavior {
800 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
801 Ok(self.append_expr(expr, span, ExpressionKind::Override))
802 }
803 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
804 Err(ConstantEvaluatorError::OverrideExpr)
805 }
806 Behavior::Glsl(_) => {
807 unreachable!()
808 }
809 },
810 ExpressionKind::Runtime => {
811 if self.behavior.has_runtime_restrictions() {
812 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
813 } else {
814 Err(ConstantEvaluatorError::RuntimeExpr)
815 }
816 }
817 }
818 }
819
820 const fn is_global_arena(&self) -> bool {
822 matches!(
823 self.behavior,
824 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
825 | Behavior::Glsl(GlslRestrictions::Const)
826 )
827 }
828
829 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
830 match self.behavior {
831 Behavior::Wgsl(
832 WgslRestrictions::Runtime(ref function_local_data)
833 | WgslRestrictions::Const(Some(ref function_local_data)),
834 )
835 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
836 Some(function_local_data)
837 }
838 _ => None,
839 }
840 }
841
/// Evaluates the const expression `expr` and registers the result.
///
/// Operands are first resolved with `check_and_get`, which fails with
/// `SubexpressionsAreNotConstant` for a non-const operand. Expression kinds
/// that can never be evaluated here return a dedicated error variant.
fn try_eval_and_append_impl(
    &mut self,
    expr: &Expression,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    log::trace!("try_eval_and_append: {expr:?}");
    match *expr {
        // In the global arena a constant is simply its initializer, which
        // already lives in that same arena.
        Expression::Constant(c) if self.is_global_arena() => {
            Ok(self.constants[c].init)
        }
        Expression::Override(_) => Err(ConstantEvaluatorError::Override),
        // Leaves (and, inside functions, constant references) are registered as-is.
        Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
            self.register_evaluated_expr(expr.clone(), span)
        }
        Expression::Compose { ty, ref components } => {
            let components = components
                .iter()
                .map(|component| self.check_and_get(*component))
                .collect::<Result<Vec<_>, _>>()?;
            self.register_evaluated_expr(Expression::Compose { ty, components }, span)
        }
        Expression::Splat { size, value } => {
            let value = self.check_and_get(value)?;
            self.register_evaluated_expr(Expression::Splat { size, value }, span)
        }
        Expression::AccessIndex { base, index } => {
            let base = self.check_and_get(base)?;

            self.access(base, index as usize, span)
        }
        Expression::Access { base, index } => {
            let base = self.check_and_get(base)?;
            let index = self.check_and_get(index)?;

            self.access(base, self.constant_index(index)?, span)
        }
        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            let vector = self.check_and_get(vector)?;

            self.swizzle(size, span, vector, pattern)
        }
        Expression::Unary { expr, op } => {
            let expr = self.check_and_get(expr)?;

            self.unary_op(op, expr, span)
        }
        Expression::Binary { left, right, op } => {
            let left = self.check_and_get(left)?;
            let right = self.check_and_get(right)?;

            self.binary_op(op, left, right, span)
        }
        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            let arg = self.check_and_get(arg)?;
            let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

            self.math(arg, arg1, arg2, arg3, fun, span)
        }
        Expression::As {
            convert,
            expr,
            kind,
        } => {
            let expr = self.check_and_get(expr)?;

            // `convert: None` is a bitcast, which is not evaluated here.
            match convert {
                Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                None => Err(ConstantEvaluatorError::NotImplemented(
                    "bitcast built-in function".into(),
                )),
            }
        }
        Expression::Select {
            reject,
            accept,
            condition,
        } => {
            // Operands are resolved in the order reject, accept, condition.
            let mut arg = |expr| self.check_and_get(expr);

            let reject = arg(reject)?;
            let accept = arg(accept)?;
            let condition = arg(condition)?;

            self.select(reject, accept, condition, span)
        }
        Expression::Relational { fun, argument } => {
            let argument = self.check_and_get(argument)?;
            self.relational(fun, argument, span)
        }
        // Only GLSL evaluates array lengths here; WGSL always reports `ArrayLength`.
        Expression::ArrayLength(expr) => match self.behavior {
            Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
            Behavior::Glsl(_) => {
                let expr = self.check_and_get(expr)?;
                self.array_length(expr, span)
            }
        },
        // Expressions that can never be part of a constant expression.
        Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
        Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
        Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
        Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
        Expression::WorkGroupUniformLoadResult { .. } => {
            Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
        }
        Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
        Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
        Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
        Expression::ImageSample { .. }
        | Expression::ImageLoad { .. }
        | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
        Expression::RayQueryProceedResult
        | Expression::RayQueryGetIntersection { .. }
        | Expression::RayQueryVertexPositions { .. } => {
            Err(ConstantEvaluatorError::RayQueryExpression)
        }
        Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
        Expression::SubgroupOperationResult { .. } => {
            Err(ConstantEvaluatorError::SubgroupExpression)
        }
    }
}
976
977 fn splat(
990 &mut self,
991 value: Handle<Expression>,
992 size: crate::VectorSize,
993 span: Span,
994 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
995 match self.expressions[value] {
996 Expression::Literal(literal) => {
997 let scalar = literal.scalar();
998 let ty = self.types.insert(
999 Type {
1000 name: None,
1001 inner: TypeInner::Vector { size, scalar },
1002 },
1003 span,
1004 );
1005 let expr = Expression::Compose {
1006 ty,
1007 components: vec![value; size as usize],
1008 };
1009 self.register_evaluated_expr(expr, span)
1010 }
1011 Expression::ZeroValue(ty) => {
1012 let inner = match self.types[ty].inner {
1013 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1014 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1015 };
1016 let res_ty = self.types.insert(Type { name: None, inner }, span);
1017 let expr = Expression::ZeroValue(res_ty);
1018 self.register_evaluated_expr(expr, span)
1019 }
1020 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1021 }
1022 }
1023
1024 fn swizzle(
1025 &mut self,
1026 size: crate::VectorSize,
1027 span: Span,
1028 src_constant: Handle<Expression>,
1029 pattern: [crate::SwizzleComponent; 4],
1030 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1031 let mut get_dst_ty = |ty| match self.types[ty].inner {
1032 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1033 Type {
1034 name: None,
1035 inner: TypeInner::Vector { size, scalar },
1036 },
1037 span,
1038 )),
1039 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1040 };
1041
1042 match self.expressions[src_constant] {
1043 Expression::ZeroValue(ty) => {
1044 let dst_ty = get_dst_ty(ty)?;
1045 let expr = Expression::ZeroValue(dst_ty);
1046 self.register_evaluated_expr(expr, span)
1047 }
1048 Expression::Splat { value, .. } => {
1049 let expr = Expression::Splat { size, value };
1050 self.register_evaluated_expr(expr, span)
1051 }
1052 Expression::Compose { ty, ref components } => {
1053 let dst_ty = get_dst_ty(ty)?;
1054
1055 let mut flattened = [src_constant; 4]; let len =
1057 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1058 .zip(flattened.iter_mut())
1059 .map(|(component, elt)| *elt = component)
1060 .count();
1061 let flattened = &flattened[..len];
1062
1063 let swizzled_components = pattern[..size as usize]
1064 .iter()
1065 .map(|&sc| {
1066 let sc = sc as usize;
1067 if let Some(elt) = flattened.get(sc) {
1068 Ok(*elt)
1069 } else {
1070 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1071 }
1072 })
1073 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1074 let expr = Expression::Compose {
1075 ty: dst_ty,
1076 components: swizzled_components,
1077 };
1078 self.register_evaluated_expr(expr, span)
1079 }
1080 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1081 }
1082 }
1083
1084 fn math(
1085 &mut self,
1086 arg: Handle<Expression>,
1087 arg1: Option<Handle<Expression>>,
1088 arg2: Option<Handle<Expression>>,
1089 arg3: Option<Handle<Expression>>,
1090 fun: crate::MathFunction,
1091 span: Span,
1092 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1093 let expected = fun.argument_count();
1094 let given = Some(arg)
1095 .into_iter()
1096 .chain(arg1)
1097 .chain(arg2)
1098 .chain(arg3)
1099 .count();
1100 if expected != given {
1101 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1102 fun, expected, given,
1103 ));
1104 }
1105
1106 match fun {
1108 crate::MathFunction::Abs => {
1110 component_wise_scalar(self, span, [arg], |args| match args {
1111 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1112 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1113 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1114 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1115 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1116 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1118 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1119 })
1120 }
1121 crate::MathFunction::Min => {
1122 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1123 Ok([e1.min(e2)])
1124 })
1125 }
1126 crate::MathFunction::Max => {
1127 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1128 Ok([e1.max(e2)])
1129 })
1130 }
1131 crate::MathFunction::Clamp => {
1132 component_wise_scalar!(
1133 self,
1134 span,
1135 [arg, arg1.unwrap(), arg2.unwrap()],
1136 |e, low, high| {
1137 if low > high {
1138 Err(ConstantEvaluatorError::InvalidClamp)
1139 } else {
1140 Ok([e.clamp(low, high)])
1141 }
1142 }
1143 )
1144 }
1145 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1146 Float::F16([e]) => Ok(Float::F16(
1147 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1148 )),
1149 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1150 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1151 }),
1152
1153 crate::MathFunction::Cos => {
1155 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1156 }
1157 crate::MathFunction::Cosh => {
1158 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1159 }
1160 crate::MathFunction::Sin => {
1161 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1162 }
1163 crate::MathFunction::Sinh => {
1164 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1165 }
1166 crate::MathFunction::Tan => {
1167 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1168 }
1169 crate::MathFunction::Tanh => {
1170 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1171 }
1172 crate::MathFunction::Acos => {
1173 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1174 }
1175 crate::MathFunction::Asin => {
1176 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1177 }
1178 crate::MathFunction::Atan => {
1179 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1180 }
1181 crate::MathFunction::Asinh => {
1182 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1183 }
1184 crate::MathFunction::Acosh => {
1185 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1186 }
1187 crate::MathFunction::Atanh => {
1188 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1189 }
1190 crate::MathFunction::Radians => {
1191 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1192 }
1193 crate::MathFunction::Degrees => {
1194 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1195 }
1196
1197 crate::MathFunction::Ceil => {
1199 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1200 }
1201 crate::MathFunction::Floor => {
1202 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1203 }
1204 crate::MathFunction::Round => {
1205 component_wise_float(self, span, [arg], |e| match e {
1206 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1207 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1208 Float::F16([e]) => {
1209 fn round_ties_even(x: f64) -> f64 {
1217 let i = x as i64;
1218 let f = (x - i as f64).abs();
1219 if f == 0.5 {
1220 if i & 1 == 1 {
1221 (x.abs() + 0.5).copysign(x)
1223 } else {
1224 (x.abs() - 0.5).copysign(x)
1225 }
1226 } else {
1227 x.round()
1228 }
1229 }
1230
1231 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1232 }
1233 })
1234 }
1235 crate::MathFunction::Fract => {
1236 component_wise_float!(self, span, [arg], |e| {
1237 Ok([e - e.floor()])
1240 })
1241 }
1242 crate::MathFunction::Trunc => {
1243 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1244 }
1245
1246 crate::MathFunction::Exp => {
1248 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1249 }
1250 crate::MathFunction::Exp2 => {
1251 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1252 }
1253 crate::MathFunction::Log => {
1254 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1255 }
1256 crate::MathFunction::Log2 => {
1257 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1258 }
1259 crate::MathFunction::Pow => {
1260 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1261 Ok([e1.powf(e2)])
1262 })
1263 }
1264
1265 crate::MathFunction::Sign => {
1267 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1268 }
1269 crate::MathFunction::Fma => {
1270 component_wise_float!(
1271 self,
1272 span,
1273 [arg, arg1.unwrap(), arg2.unwrap()],
1274 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1275 )
1276 }
1277 crate::MathFunction::Step => {
1278 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1279 Float::Abstract([edge, x]) => {
1280 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1281 }
1282 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1283 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1284 f16::one()
1285 } else {
1286 f16::zero()
1287 }])),
1288 })
1289 }
1290 crate::MathFunction::Sqrt => {
1291 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1292 }
1293 crate::MathFunction::InverseSqrt => {
1294 component_wise_float(self, span, [arg], |e| match e {
1295 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1296 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1297 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1298 })
1299 }
1300
1301 crate::MathFunction::CountTrailingZeros => {
1303 component_wise_concrete_int!(self, span, [arg], |e| {
1304 #[allow(clippy::useless_conversion)]
1305 Ok([e
1306 .trailing_zeros()
1307 .try_into()
1308 .expect("bit count overflowed 32 bits, somehow!?")])
1309 })
1310 }
1311 crate::MathFunction::CountLeadingZeros => {
1312 component_wise_concrete_int!(self, span, [arg], |e| {
1313 #[allow(clippy::useless_conversion)]
1314 Ok([e
1315 .leading_zeros()
1316 .try_into()
1317 .expect("bit count overflowed 32 bits, somehow!?")])
1318 })
1319 }
1320 crate::MathFunction::CountOneBits => {
1321 component_wise_concrete_int!(self, span, [arg], |e| {
1322 #[allow(clippy::useless_conversion)]
1323 Ok([e
1324 .count_ones()
1325 .try_into()
1326 .expect("bit count overflowed 32 bits, somehow!?")])
1327 })
1328 }
1329 crate::MathFunction::ReverseBits => {
1330 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1331 }
1332 crate::MathFunction::FirstTrailingBit => {
1333 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1334 }
1335 crate::MathFunction::FirstLeadingBit => {
1336 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1337 }
1338
1339 crate::MathFunction::Dot4I8Packed => {
1341 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1342 }
1343 crate::MathFunction::Dot4U8Packed => {
1344 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1345 }
1346 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1347
1348 crate::MathFunction::Atan2
1350 | crate::MathFunction::Modf
1351 | crate::MathFunction::Frexp
1352 | crate::MathFunction::Ldexp
1353 | crate::MathFunction::Dot
1354 | crate::MathFunction::Outer
1355 | crate::MathFunction::Distance
1356 | crate::MathFunction::Length
1357 | crate::MathFunction::Normalize
1358 | crate::MathFunction::FaceForward
1359 | crate::MathFunction::Reflect
1360 | crate::MathFunction::Refract
1361 | crate::MathFunction::Mix
1362 | crate::MathFunction::SmoothStep
1363 | crate::MathFunction::Inverse
1364 | crate::MathFunction::Transpose
1365 | crate::MathFunction::Determinant
1366 | crate::MathFunction::QuantizeToF16
1367 | crate::MathFunction::ExtractBits
1368 | crate::MathFunction::InsertBits
1369 | crate::MathFunction::Pack4x8snorm
1370 | crate::MathFunction::Pack4x8unorm
1371 | crate::MathFunction::Pack2x16snorm
1372 | crate::MathFunction::Pack2x16unorm
1373 | crate::MathFunction::Pack2x16float
1374 | crate::MathFunction::Pack4xI8
1375 | crate::MathFunction::Pack4xU8
1376 | crate::MathFunction::Pack4xI8Clamp
1377 | crate::MathFunction::Pack4xU8Clamp
1378 | crate::MathFunction::Unpack4x8snorm
1379 | crate::MathFunction::Unpack4x8unorm
1380 | crate::MathFunction::Unpack2x16snorm
1381 | crate::MathFunction::Unpack2x16unorm
1382 | crate::MathFunction::Unpack2x16float
1383 | crate::MathFunction::Unpack4xI8
1384 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1385 format!("{fun:?} built-in function"),
1386 )),
1387 }
1388 }
1389
1390 fn packed_dot_product(
1392 &mut self,
1393 a: Handle<Expression>,
1394 b: Handle<Expression>,
1395 span: Span,
1396 signed: bool,
1397 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1398 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1399 return Err(ConstantEvaluatorError::InvalidMathArg);
1400 };
1401 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1402 return Err(ConstantEvaluatorError::InvalidMathArg);
1403 };
1404
1405 let result = if signed {
1406 Literal::I32(
1407 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1408 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1409 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1410 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1411 )
1412 } else {
1413 Literal::U32(
1414 (a & 0xFF) * (b & 0xFF)
1415 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1416 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1417 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1418 )
1419 };
1420
1421 self.register_evaluated_expr(Expression::Literal(result), span)
1422 }
1423
1424 fn cross_product(
1426 &mut self,
1427 a: Handle<Expression>,
1428 b: Handle<Expression>,
1429 span: Span,
1430 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1431 use Literal as Li;
1432
1433 let (a, ty) = self.extract_vec::<3>(a)?;
1434 let (b, _) = self.extract_vec::<3>(b)?;
1435
1436 let product = match (a, b) {
1437 (
1438 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1439 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1440 ) => {
1441 let p = cross_product(
1446 [a0 as f64, a1 as f64, a2 as f64],
1447 [b0 as f64, b1 as f64, b2 as f64],
1448 );
1449 [
1450 Li::AbstractFloat(p[0]),
1451 Li::AbstractFloat(p[1]),
1452 Li::AbstractFloat(p[2]),
1453 ]
1454 }
1455 (
1456 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1457 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1458 ) => {
1459 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1460 [
1461 Li::AbstractFloat(p[0]),
1462 Li::AbstractFloat(p[1]),
1463 Li::AbstractFloat(p[2]),
1464 ]
1465 }
1466 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1467 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1468 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1469 }
1470 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1471 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1472 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1473 }
1474 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1475 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1476 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1477 }
1478 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1479 };
1480
1481 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1482 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1483 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1484
1485 self.register_evaluated_expr(
1486 Expression::Compose {
1487 ty,
1488 components: vec![p0, p1, p2],
1489 },
1490 span,
1491 )
1492 }
1493
1494 fn extract_vec<const N: usize>(
1502 &mut self,
1503 expr: Handle<Expression>,
1504 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1505 let span = self.expressions.get_span(expr);
1506 let expr = self.eval_zero_value_and_splat(expr, span)?;
1507 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1508 return Err(ConstantEvaluatorError::InvalidMathArg);
1509 };
1510
1511 let mut value = [Literal::Bool(false); N];
1512 for (component, elt) in
1513 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1514 .zip(value.iter_mut())
1515 {
1516 let Expression::Literal(literal) = self.expressions[component] else {
1517 return Err(ConstantEvaluatorError::InvalidMathArg);
1518 };
1519 *elt = literal;
1520 }
1521
1522 Ok((value, ty))
1523 }
1524
1525 fn array_length(
1526 &mut self,
1527 array: Handle<Expression>,
1528 span: Span,
1529 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1530 match self.expressions[array] {
1531 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1532 match self.types[ty].inner {
1533 TypeInner::Array { size, .. } => match size {
1534 ArraySize::Constant(len) => {
1535 let expr = Expression::Literal(Literal::U32(len.get()));
1536 self.register_evaluated_expr(expr, span)
1537 }
1538 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1539 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1540 },
1541 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1542 }
1543 }
1544 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1545 }
1546 }
1547
1548 fn access(
1549 &mut self,
1550 base: Handle<Expression>,
1551 index: usize,
1552 span: Span,
1553 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1554 match self.expressions[base] {
1555 Expression::ZeroValue(ty) => {
1556 let ty_inner = &self.types[ty].inner;
1557 let components = ty_inner
1558 .components()
1559 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1560
1561 if index >= components as usize {
1562 Err(ConstantEvaluatorError::InvalidAccessBase)
1563 } else {
1564 let ty_res = ty_inner
1565 .component_type(index)
1566 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1567 let ty = match ty_res {
1568 crate::proc::TypeResolution::Handle(ty) => ty,
1569 crate::proc::TypeResolution::Value(inner) => {
1570 self.types.insert(Type { name: None, inner }, span)
1571 }
1572 };
1573 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1574 }
1575 }
1576 Expression::Splat { size, value } => {
1577 if index >= size as usize {
1578 Err(ConstantEvaluatorError::InvalidAccessBase)
1579 } else {
1580 Ok(value)
1581 }
1582 }
1583 Expression::Compose { ty, ref components } => {
1584 let _ = self.types[ty]
1585 .inner
1586 .components()
1587 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1588
1589 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1590 .nth(index)
1591 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1592 }
1593 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1594 }
1595 }
1596
1597 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1598 match self.expressions[expr] {
1599 Expression::ZeroValue(ty)
1600 if matches!(
1601 self.types[ty].inner,
1602 TypeInner::Scalar(crate::Scalar {
1603 kind: ScalarKind::Uint,
1604 ..
1605 })
1606 ) =>
1607 {
1608 Ok(0)
1609 }
1610 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1611 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1612 }
1613 }
1614
1615 fn eval_zero_value_and_splat(
1622 &mut self,
1623 mut expr: Handle<Expression>,
1624 span: Span,
1625 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1626 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
1629 let components = components
1630 .clone()
1631 .iter()
1632 .map(|component| self.eval_zero_value_and_splat(*component, span))
1633 .collect::<Result<_, _>>()?;
1634 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
1635 }
1636
1637 if let Expression::Splat { size, value } = self.expressions[expr] {
1641 expr = self.splat(value, size, span)?;
1642 }
1643 if let Expression::ZeroValue(ty) = self.expressions[expr] {
1644 expr = self.eval_zero_value_impl(ty, span)?;
1645 }
1646 Ok(expr)
1647 }
1648
1649 fn eval_zero_value(
1655 &mut self,
1656 expr: Handle<Expression>,
1657 span: Span,
1658 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1659 match self.expressions[expr] {
1660 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1661 _ => Ok(expr),
1662 }
1663 }
1664
1665 fn eval_zero_value_impl(
1671 &mut self,
1672 ty: Handle<Type>,
1673 span: Span,
1674 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1675 match self.types[ty].inner {
1676 TypeInner::Scalar(scalar) => {
1677 let expr = Expression::Literal(
1678 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1679 );
1680 self.register_evaluated_expr(expr, span)
1681 }
1682 TypeInner::Vector { size, scalar } => {
1683 let scalar_ty = self.types.insert(
1684 Type {
1685 name: None,
1686 inner: TypeInner::Scalar(scalar),
1687 },
1688 span,
1689 );
1690 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1691 let expr = Expression::Compose {
1692 ty,
1693 components: vec![el; size as usize],
1694 };
1695 self.register_evaluated_expr(expr, span)
1696 }
1697 TypeInner::Matrix {
1698 columns,
1699 rows,
1700 scalar,
1701 } => {
1702 let vec_ty = self.types.insert(
1703 Type {
1704 name: None,
1705 inner: TypeInner::Vector { size: rows, scalar },
1706 },
1707 span,
1708 );
1709 let el = self.eval_zero_value_impl(vec_ty, span)?;
1710 let expr = Expression::Compose {
1711 ty,
1712 components: vec![el; columns as usize],
1713 };
1714 self.register_evaluated_expr(expr, span)
1715 }
1716 TypeInner::Array {
1717 base,
1718 size: ArraySize::Constant(size),
1719 ..
1720 } => {
1721 let el = self.eval_zero_value_impl(base, span)?;
1722 let expr = Expression::Compose {
1723 ty,
1724 components: vec![el; size.get() as usize],
1725 };
1726 self.register_evaluated_expr(expr, span)
1727 }
1728 TypeInner::Struct { ref members, .. } => {
1729 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1730 let mut components = Vec::with_capacity(members.len());
1731 for ty in types {
1732 components.push(self.eval_zero_value_impl(ty, span)?);
1733 }
1734 let expr = Expression::Compose { ty, components };
1735 self.register_evaluated_expr(expr, span)
1736 }
1737 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1738 }
1739 }
1740
/// Converts the constant expression `expr` to the scalar type `target`.
///
/// A top-level `ZeroValue` is expanded first. Then each form is handled:
/// - A `Literal` is converted according to the per-target tables below.
/// - A vector or matrix `Compose` has each component cast recursively. It gets
///   a newly inserted type of the same shape whose scalar is `target`. Arrays
///   are not handled here.
/// - A `Splat` keeps its size, and its value is cast using the value's own span.
/// - Anything else is rejected.
///
/// Unsupported conversions return [`ConstantEvaluatorError::InvalidCastArg`]:
/// - `from` is the debug form of the handle and expression.
/// - `to` is the WGSL spelling of `target` when the `wgsl-in` feature is
///   enabled, and its `Debug` form otherwise.
///
/// Conversions from abstract values go through `try_from_abstract`, whose errors
/// are propagated unchanged.
pub fn cast(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    use crate::Scalar as Sc;

    let expr = self.eval_zero_value(expr, span)?;

    // Builds the error returned for every unsupported source/target pair.
    let make_error = || -> Result<_, ConstantEvaluatorError> {
        let from = format!("{:?} {:?}", expr, self.expressions[expr]);

        #[cfg(feature = "wgsl-in")]
        let to = target.to_wgsl_for_diagnostics();

        #[cfg(not(feature = "wgsl-in"))]
        let to = format!("{target:?}");

        Err(ConstantEvaluatorError::InvalidCastArg { from, to })
    };

    use crate::proc::type_methods::IntFloatLimits;

    let expr = match self.expressions[expr] {
        Expression::Literal(literal) => {
            let literal = match target {
                Sc::I32 => Literal::I32(match literal {
                    Literal::I32(v) => v,
                    // Same-width integer conversion reinterprets the bits.
                    Literal::U32(v) => v as i32,
                    // Clamp into range first; `as` then truncates toward zero.
                    Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                    // No clamp needed: every finite f16 fits in i32.
                    // NOTE(review): `unwrap` panics for a NaN/infinite f16;
                    // presumably stored literals are always finite — confirm.
                    Literal::F16(v) => f16::to_i32(&v).unwrap(),
                    Literal::Bool(v) => v as i32,
                    // 64-bit sources are not narrowed to 32-bit targets.
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                }),
                Sc::U32 => Literal::U32(match literal {
                    Literal::I32(v) => v as u32,
                    Literal::U32(v) => v,
                    Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                    // Negative values are clamped to zero before converting so
                    // the conversion cannot fail on them; see the f16 note above.
                    Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                    Literal::Bool(v) => v as u32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                }),
                Sc::I64 => Literal::I64(match literal {
                    Literal::I32(v) => v as i64,
                    Literal::U32(v) => v as i64,
                    Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::Bool(v) => v as i64,
                    Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::I64(v) => v,
                    // Same-width conversion reinterprets the bits.
                    Literal::U64(v) => v as i64,
                    Literal::F16(v) => f16::to_i64(&v).unwrap(),
                    Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                }),
                Sc::U64 => Literal::U64(match literal {
                    // Signed sources are sign-extended, then reinterpreted.
                    Literal::I32(v) => v as u64,
                    Literal::U32(v) => v as u64,
                    Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::Bool(v) => v as u64,
                    Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::I64(v) => v as u64,
                    Literal::U64(v) => v,
                    Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                    Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                }),
                Sc::F16 => Literal::F16(match literal {
                    Literal::F16(v) => v,
                    Literal::F32(v) => f16::from_f32(v),
                    Literal::F64(v) => f16::from_f64(v),
                    // NOTE(review): these `unwrap`s assume `f16::from_*` never
                    // returns `None` for these inputs — confirm against `half`.
                    Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                    Literal::I64(v) => f16::from_i64(v).unwrap(),
                    Literal::U64(v) => f16::from_u64(v).unwrap(),
                    Literal::I32(v) => f16::from_i32(v).unwrap(),
                    Literal::U32(v) => f16::from_u32(v).unwrap(),
                    Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                    Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                }),
                Sc::F32 => Literal::F32(match literal {
                    Literal::I32(v) => v as f32,
                    Literal::U32(v) => v as f32,
                    Literal::F32(v) => v,
                    Literal::Bool(v) => v as u32 as f32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::F16(v) => f16::to_f32(v),
                    Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                }),
                Sc::F64 => Literal::F64(match literal {
                    Literal::I32(v) => v as f64,
                    Literal::U32(v) => v as f64,
                    Literal::F16(v) => f16::to_f64(v),
                    Literal::F32(v) => v as f64,
                    Literal::F64(v) => v,
                    Literal::Bool(v) => v as u32 as f64,
                    Literal::I64(_) | Literal::U64(_) => return make_error(),
                    Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                }),
                // Conversion to bool is "not equal to zero".
                Sc::BOOL => Literal::Bool(match literal {
                    Literal::I32(v) => v != 0,
                    Literal::U32(v) => v != 0,
                    Literal::F32(v) => v != 0.0,
                    Literal::F16(v) => v != f16::zero(),
                    Literal::Bool(v) => v,
                    Literal::AbstractInt(v) => v != 0,
                    Literal::AbstractFloat(v) => v != 0.0,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                }),
                // Only abstract sources may become abstract.
                Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                    Literal::AbstractInt(v) => {
                        // f64's range exceeds i64's, so this may round but
                        // never overflows.
                        v as f64
                    }
                    Literal::AbstractFloat(v) => v,
                    _ => return make_error(),
                }),
                Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                    Literal::AbstractInt(v) => v,
                    _ => return make_error(),
                }),
                _ => {
                    log::debug!("Constant evaluator refused to convert value to {target:?}");
                    return make_error();
                }
            };
            Expression::Literal(literal)
        }
        Expression::Compose {
            ty,
            components: ref src_components,
        } => {
            // Keep the vector/matrix shape, swapping in the target scalar.
            let ty_inner = match self.types[ty].inner {
                TypeInner::Vector { size, .. } => TypeInner::Vector {
                    size,
                    scalar: target,
                },
                TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: target,
                },
                _ => return make_error(),
            };

            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.cast(*component, target, span)?;
            }

            let ty = self.types.insert(
                Type {
                    name: None,
                    inner: ty_inner,
                },
                span,
            );

            Expression::Compose { ty, components }
        }
        Expression::Splat { size, value } => {
            let value_span = self.expressions.get_span(value);
            let cast_value = self.cast(value, target, value_span)?;
            Expression::Splat {
                size,
                value: cast_value,
            }
        }
        _ => return make_error(),
    };

    self.register_evaluated_expr(expr, span)
}
1935
/// Like `cast`, but also converts arrays, including arrays of arrays.
///
/// If `expr` (after `check_and_get`) is a `Compose` of array type:
/// - each element is converted with `cast_array` recursively;
/// - a new array type is created with the same size, whose base is the type of
///   the first converted element and whose stride comes from the updated layouter.
///
/// Every other expression is forwarded to `cast`.
fn cast_array(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    let expr = self.check_and_get(expr)?;

    let Expression::Compose { ty, ref components } = self.expressions[expr] else {
        return self.cast(expr, target, span);
    };

    let TypeInner::Array {
        base: _,
        size,
        stride: _,
    } = self.types[ty].inner
    else {
        return self.cast(expr, target, span);
    };

    let mut components = components.clone();
    for component in &mut components {
        *component = self.cast_array(*component, target, span)?;
    }

    // NOTE(review): assumes an array `Compose` always has at least one
    // component; `unwrap` panics otherwise.
    let first = components.first().unwrap();
    let new_base = match self.resolve_type(*first)? {
        crate::proc::TypeResolution::Handle(ty) => ty,
        crate::proc::TypeResolution::Value(inner) => {
            self.types.insert(Type { name: None, inner }, span)
        }
    };
    // The layouter is moved out while `update` runs, presumably because
    // `self.to_ctx()` borrows `self`, and then put back. It must run before
    // indexing it with `new_base`, which may be a freshly inserted type.
    // NOTE(review): `unwrap` panics if layout fails.
    let mut layouter = core::mem::take(self.layouter);
    layouter.update(self.to_ctx()).unwrap();
    *self.layouter = layouter;

    let new_base_stride = self.layouter[new_base].to_stride();
    let new_array_ty = self.types.insert(
        Type {
            name: None,
            inner: TypeInner::Array {
                base: new_base,
                size,
                stride: new_base_stride,
            },
        },
        span,
    );

    let compose = Expression::Compose {
        ty: new_array_ty,
        components,
    };
    self.register_evaluated_expr(compose, span)
}
2004
2005 fn unary_op(
2006 &mut self,
2007 op: UnaryOperator,
2008 expr: Handle<Expression>,
2009 span: Span,
2010 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2011 let expr = self.eval_zero_value_and_splat(expr, span)?;
2012
2013 let expr = match self.expressions[expr] {
2014 Expression::Literal(value) => Expression::Literal(match op {
2015 UnaryOperator::Negate => match value {
2016 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2017 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2018 Literal::F32(v) => Literal::F32(-v),
2019 Literal::F16(v) => Literal::F16(-v),
2020 Literal::F64(v) => Literal::F64(-v),
2021 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2022 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2023 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2024 },
2025 UnaryOperator::LogicalNot => match value {
2026 Literal::Bool(v) => Literal::Bool(!v),
2027 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2028 },
2029 UnaryOperator::BitwiseNot => match value {
2030 Literal::I32(v) => Literal::I32(!v),
2031 Literal::I64(v) => Literal::I64(!v),
2032 Literal::U32(v) => Literal::U32(!v),
2033 Literal::U64(v) => Literal::U64(!v),
2034 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2035 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2036 },
2037 }),
2038 Expression::Compose {
2039 ty,
2040 components: ref src_components,
2041 } => {
2042 match self.types[ty].inner {
2043 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2044 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2045 }
2046
2047 let mut components = src_components.clone();
2048 for component in &mut components {
2049 *component = self.unary_op(op, *component, span)?;
2050 }
2051
2052 Expression::Compose { ty, components }
2053 }
2054 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2055 };
2056
2057 self.register_evaluated_expr(expr, span)
2058 }
2059
2060 fn binary_op(
2061 &mut self,
2062 op: BinaryOperator,
2063 left: Handle<Expression>,
2064 right: Handle<Expression>,
2065 span: Span,
2066 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2067 let left = self.eval_zero_value_and_splat(left, span)?;
2068 let right = self.eval_zero_value_and_splat(right, span)?;
2069
2070 let expr = match (&self.expressions[left], &self.expressions[right]) {
2071 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2072 let literal = match op {
2073 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2074 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2075 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2076 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2077 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2078 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2079
2080 _ => match (left_value, right_value) {
2081 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2082 BinaryOperator::Add => a.wrapping_add(b),
2083 BinaryOperator::Subtract => a.wrapping_sub(b),
2084 BinaryOperator::Multiply => a.wrapping_mul(b),
2085 BinaryOperator::Divide => {
2086 if b == 0 {
2087 return Err(ConstantEvaluatorError::DivisionByZero);
2088 } else {
2089 a.wrapping_div(b)
2090 }
2091 }
2092 BinaryOperator::Modulo => {
2093 if b == 0 {
2094 return Err(ConstantEvaluatorError::RemainderByZero);
2095 } else {
2096 a.wrapping_rem(b)
2097 }
2098 }
2099 BinaryOperator::And => a & b,
2100 BinaryOperator::ExclusiveOr => a ^ b,
2101 BinaryOperator::InclusiveOr => a | b,
2102 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2103 }),
2104 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2105 BinaryOperator::ShiftLeft => {
2106 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2107 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2108 }
2109 a.checked_shl(b)
2110 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2111 }
2112 BinaryOperator::ShiftRight => a
2113 .checked_shr(b)
2114 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2115 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2116 }),
2117 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2118 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2119 ConstantEvaluatorError::Overflow("addition".into())
2120 })?,
2121 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2122 ConstantEvaluatorError::Overflow("subtraction".into())
2123 })?,
2124 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2125 ConstantEvaluatorError::Overflow("multiplication".into())
2126 })?,
2127 BinaryOperator::Divide => a
2128 .checked_div(b)
2129 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2130 BinaryOperator::Modulo => a
2131 .checked_rem(b)
2132 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2133 BinaryOperator::And => a & b,
2134 BinaryOperator::ExclusiveOr => a ^ b,
2135 BinaryOperator::InclusiveOr => a | b,
2136 BinaryOperator::ShiftLeft => a
2137 .checked_mul(
2138 1u32.checked_shl(b)
2139 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2140 )
2141 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2142 BinaryOperator::ShiftRight => a
2143 .checked_shr(b)
2144 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2145 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2146 }),
2147 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2148 BinaryOperator::Add => a + b,
2149 BinaryOperator::Subtract => a - b,
2150 BinaryOperator::Multiply => a * b,
2151 BinaryOperator::Divide => a / b,
2152 BinaryOperator::Modulo => a % b,
2153 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2154 }),
2155 (Literal::AbstractInt(a), Literal::U32(b)) => {
2156 Literal::AbstractInt(match op {
2157 BinaryOperator::ShiftLeft => {
2158 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2159 return Err(ConstantEvaluatorError::Overflow(
2160 "<<".to_string(),
2161 ));
2162 }
2163 a.checked_shl(b).unwrap_or(0)
2164 }
2165 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2166 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2167 })
2168 }
2169 (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2170 BinaryOperator::Add => a + b,
2171 BinaryOperator::Subtract => a - b,
2172 BinaryOperator::Multiply => a * b,
2173 BinaryOperator::Divide => a / b,
2174 BinaryOperator::Modulo => a % b,
2175 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2176 }),
2177 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2178 Literal::AbstractInt(match op {
2179 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2180 ConstantEvaluatorError::Overflow("addition".into())
2181 })?,
2182 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2183 ConstantEvaluatorError::Overflow("subtraction".into())
2184 })?,
2185 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2186 ConstantEvaluatorError::Overflow("multiplication".into())
2187 })?,
2188 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2189 if b == 0 {
2190 ConstantEvaluatorError::DivisionByZero
2191 } else {
2192 ConstantEvaluatorError::Overflow("division".into())
2193 }
2194 })?,
2195 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2196 if b == 0 {
2197 ConstantEvaluatorError::RemainderByZero
2198 } else {
2199 ConstantEvaluatorError::Overflow("remainder".into())
2200 }
2201 })?,
2202 BinaryOperator::And => a & b,
2203 BinaryOperator::ExclusiveOr => a ^ b,
2204 BinaryOperator::InclusiveOr => a | b,
2205 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2206 })
2207 }
2208 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2209 Literal::AbstractFloat(match op {
2210 BinaryOperator::Add => a + b,
2211 BinaryOperator::Subtract => a - b,
2212 BinaryOperator::Multiply => a * b,
2213 BinaryOperator::Divide => a / b,
2214 BinaryOperator::Modulo => a % b,
2215 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2216 })
2217 }
2218 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2219 BinaryOperator::LogicalAnd => a && b,
2220 BinaryOperator::LogicalOr => a || b,
2221 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2222 }),
2223 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2224 },
2225 };
2226 Expression::Literal(literal)
2227 }
2228 (
2229 &Expression::Compose {
2230 components: ref src_components,
2231 ty,
2232 },
2233 &Expression::Literal(_),
2234 ) => {
2235 let mut components = src_components.clone();
2236 for component in &mut components {
2237 *component = self.binary_op(op, *component, right, span)?;
2238 }
2239 Expression::Compose { ty, components }
2240 }
2241 (
2242 &Expression::Literal(_),
2243 &Expression::Compose {
2244 components: ref src_components,
2245 ty,
2246 },
2247 ) => {
2248 let mut components = src_components.clone();
2249 for component in &mut components {
2250 *component = self.binary_op(op, left, *component, span)?;
2251 }
2252 Expression::Compose { ty, components }
2253 }
2254 (
2255 &Expression::Compose {
2256 components: ref left_components,
2257 ty: left_ty,
2258 },
2259 &Expression::Compose {
2260 components: ref right_components,
2261 ty: right_ty,
2262 },
2263 ) => {
2264 let left_flattened = crate::proc::flatten_compose(
2268 left_ty,
2269 left_components,
2270 self.expressions,
2271 self.types,
2272 );
2273 let right_flattened = crate::proc::flatten_compose(
2274 right_ty,
2275 right_components,
2276 self.expressions,
2277 self.types,
2278 );
2279
2280 let mut flattened = Vec::with_capacity(left_components.len());
2283 flattened.extend(left_flattened.zip(right_flattened));
2284
2285 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2286 (
2287 &TypeInner::Vector {
2288 size: left_size, ..
2289 },
2290 &TypeInner::Vector {
2291 size: right_size, ..
2292 },
2293 ) if left_size == right_size => {
2294 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2295 }
2296 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2297 }
2298 }
2299 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2300 };
2301
2302 self.register_evaluated_expr(expr, span)
2303 }
2304
2305 fn binary_op_vector(
2306 &mut self,
2307 op: BinaryOperator,
2308 size: crate::VectorSize,
2309 components: &[(Handle<Expression>, Handle<Expression>)],
2310 left_ty: Handle<Type>,
2311 span: Span,
2312 ) -> Result<Expression, ConstantEvaluatorError> {
2313 let ty = match op {
2314 BinaryOperator::Equal
2316 | BinaryOperator::NotEqual
2317 | BinaryOperator::Less
2318 | BinaryOperator::LessEqual
2319 | BinaryOperator::Greater
2320 | BinaryOperator::GreaterEqual => self.types.insert(
2321 Type {
2322 name: None,
2323 inner: TypeInner::Vector {
2324 size,
2325 scalar: crate::Scalar::BOOL,
2326 },
2327 },
2328 span,
2329 ),
2330
2331 BinaryOperator::Add
2334 | BinaryOperator::Subtract
2335 | BinaryOperator::Multiply
2336 | BinaryOperator::Divide
2337 | BinaryOperator::Modulo
2338 | BinaryOperator::And
2339 | BinaryOperator::ExclusiveOr
2340 | BinaryOperator::InclusiveOr
2341 | BinaryOperator::ShiftLeft
2342 | BinaryOperator::ShiftRight => left_ty,
2343
2344 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2345 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2347 }
2348 };
2349
2350 let components = components
2351 .iter()
2352 .map(|&(left, right)| self.binary_op(op, left, right, span))
2353 .collect::<Result<Vec<_>, _>>()?;
2354
2355 Ok(Expression::Compose { ty, components })
2356 }
2357
2358 fn relational(
2359 &mut self,
2360 fun: RelationalFunction,
2361 arg: Handle<Expression>,
2362 span: Span,
2363 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2364 let arg = self.eval_zero_value_and_splat(arg, span)?;
2365 match fun {
2366 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2367 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2368 Expression::Compose { ty, ref components }
2369 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2370 {
2371 let components =
2372 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2373 .map(|component| match self.expressions[component] {
2374 Expression::Literal(Literal::Bool(val)) => Ok(val),
2375 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2376 })
2377 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2378 let result = match fun {
2379 RelationalFunction::All => components.iter().all(|c| *c),
2380 RelationalFunction::Any => components.iter().any(|c| *c),
2381 _ => unreachable!(),
2382 };
2383 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2384 }
2385 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2386 },
2387 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2388 "{fun:?} built-in function"
2389 ))),
2390 }
2391 }
2392
2393 fn copy_from(
2401 &mut self,
2402 expr: Handle<Expression>,
2403 expressions: &Arena<Expression>,
2404 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2405 let span = expressions.get_span(expr);
2406 match expressions[expr] {
2407 ref expr @ (Expression::Literal(_)
2408 | Expression::Constant(_)
2409 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2410 Expression::Compose { ty, ref components } => {
2411 let mut components = components.clone();
2412 for component in &mut components {
2413 *component = self.copy_from(*component, expressions)?;
2414 }
2415 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2416 }
2417 Expression::Splat { size, value } => {
2418 let value = self.copy_from(value, expressions)?;
2419 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2420 }
2421 _ => {
2422 log::debug!("copy_from: SubexpressionsAreNotConstant");
2423 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2424 }
2425 }
2426 }
2427
2428 fn vector_compose_flattened_size(
2430 &self,
2431 components: &[Handle<Expression>],
2432 ) -> Result<usize, ConstantEvaluatorError> {
2433 components
2434 .iter()
2435 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2436 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2437 TypeInner::Scalar(_) => 1,
2438 TypeInner::Vector { size, .. } => size as usize,
2442 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2443 };
2444 Ok(acc + size)
2445 })
2446 }
2447
2448 fn register_evaluated_expr(
2449 &mut self,
2450 expr: Expression,
2451 span: Span,
2452 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2453 if let Expression::Literal(literal) = expr {
2458 crate::valid::check_literal_value(literal)?;
2459 }
2460
2461 if let Expression::Compose { ty, ref components } = expr {
2465 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2466 let expected = size as usize;
2467 let actual = self.vector_compose_flattened_size(components)?;
2468 if expected != actual {
2469 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2470 expected,
2471 actual,
2472 });
2473 }
2474 }
2475 }
2476
2477 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2478 }
2479
    /// Appends `expr` to `self.expressions` and records it in
    /// `self.expression_kind_tracker` with kind `expr_type`.
    ///
    /// When evaluating inside a function body (WGSL runtime, WGSL const with
    /// function-local data, or GLSL runtime) and the emitter is running, an
    /// expression that reports `needs_pre_emit()` is appended *between* emit
    /// ranges. The pending range is flushed into the function's block first,
    /// and a fresh range is started right after the append.
    /// NOTE(review): presumably this keeps pre-emit expressions out of any
    /// `Emit` statement; confirm against `Expression::needs_pre_emit`.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range so the new expression is not inside it...
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    // ...then resume emitting for whatever follows.
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // No function-local emitter (e.g. module-scope evaluation): plain append.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2510
2511 fn resolve_type(
2512 &self,
2513 expr: Handle<Expression>,
2514 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2515 use crate::proc::TypeResolution as Tr;
2516 use crate::Expression as Ex;
2517 let resolution = match self.expressions[expr] {
2518 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2519 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2520 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2521 Ex::Splat { size, value } => {
2522 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2523 return Err(ConstantEvaluatorError::SplatScalarOnly);
2524 };
2525 Tr::Value(TypeInner::Vector { scalar, size })
2526 }
2527 _ => {
2528 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2529 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2530 }
2531 };
2532
2533 Ok(resolution)
2534 }
2535
    /// Constant-evaluates `select(reject, accept, condition)`.
    ///
    /// All three operands are first normalized with
    /// `eval_zero_value_and_splat`. Then:
    /// - scalar case: `reject` and `accept` are literals and `condition` must
    ///   be a `bool` literal (else `SelectScalarConditionNotABool`);
    /// - vector case: `reject` and `accept` are `Compose`s of equal vector
    ///   size (else `SelectVecRejectAcceptSizeMismatch`). `condition` is
    ///   either a scalar `bool` (broadcast to every lane) or a bool vector of
    ///   the same size (else `SelectConditionNotAVecBool` /
    ///   `SelectConditionVecSizeMismatch`);
    /// - any other operand combination is `SelectAcceptRejectTypeMismatch`.
    ///
    /// In every lane, the accept value is cast to the reject value's scalar
    /// type before choosing. The vector result is a new `Compose` with
    /// `reject`'s type.
    fn select(
        &mut self,
        reject: Handle<Expression>,
        accept: Handle<Expression>,
        condition: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

        let reject = arg(reject)?;
        let accept = arg(accept)?;
        let condition = arg(condition)?;

        // Picks one lane: casts `accept` to `reject_scalar` (even if it ends up
        // unused), then returns `accept` or `reject` depending on `condition`.
        let select_single_component =
            |this: &mut Self, reject_scalar, reject, accept, condition| {
                let accept = this.cast(accept, reject_scalar, span)?;
                if condition {
                    Ok(accept)
                } else {
                    Ok(reject)
                }
            };

        match (&self.expressions[reject], &self.expressions[accept]) {
            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
                let reject_scalar = reject_lit.scalar();
                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
                else {
                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
                };
                select_single_component(self, reject_scalar, reject, accept, condition)
            }
            (
                &Expression::Compose {
                    ty: reject_ty,
                    components: ref reject_components,
                },
                &Expression::Compose {
                    ty: accept_ty,
                    components: ref accept_components,
                },
            ) => {
                // Returns (vector size, scalar) for `ty`.
                // NOTE(review): both `unwrap`s panic if `ty` is not a vector
                // (e.g. a matrix or struct `Compose`); presumably validation
                // rules that out before we get here — confirm.
                let ty_deets = |ty| {
                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                    (size.unwrap(), scalar)
                };

                let expected_vec_size = {
                    let [(reject_vec_size, _), (accept_vec_size, _)] =
                        [reject_ty, accept_ty].map(ty_deets);

                    if reject_vec_size != accept_vec_size {
                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                            reject: reject_vec_size,
                            accept: accept_vec_size,
                        });
                    }
                    reject_vec_size
                };

                let condition_components = match self.expressions[condition] {
                    // A scalar condition applies to every lane.
                    Expression::Literal(Literal::Bool(condition)) => {
                        vec![condition; (expected_vec_size as u8).into()]
                    }
                    Expression::Compose {
                        ty: condition_ty,
                        components: ref condition_components,
                    } => {
                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                        if condition_scalar.kind != ScalarKind::Bool {
                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                        }
                        if condition_vec_size != expected_vec_size {
                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                        }
                        // NOTE(review): assumes every component is already a
                        // bool literal (no nested vectors); anything else panics.
                        condition_components
                            .iter()
                            .copied()
                            .map(|component| match &self.expressions[component] {
                                &Expression::Literal(Literal::Bool(condition)) => condition,
                                _ => unreachable!(),
                            })
                            .collect()
                    }

                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
                };

                // Lane-wise select; `zip` stops at the shortest component list.
                let evaluated = Expression::Compose {
                    ty: reject_ty,
                    components: reject_components
                        .clone()
                        .into_iter()
                        .zip(accept_components.clone().into_iter())
                        .zip(condition_components.into_iter())
                        .map(|((reject, accept), condition)| {
                            // NOTE(review): assumes reject lanes are literals; panics otherwise.
                            let reject_scalar = match &self.expressions[reject] {
                                &Expression::Literal(lit) => lit.scalar(),
                                _ => unreachable!(),
                            };
                            select_single_component(self, reject_scalar, reject, accept, condition)
                        })
                        .collect::<Result<_, _>>()?,
                };
                self.register_evaluated_expr(evaluated, span)
            }
            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
        }
    }
2645}
2646
2647fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2648 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2652 match e {
2653 idx @ 0..=31 => idx,
2654 32 => u32::MAX,
2655 _ => unreachable!(),
2656 }
2657 };
2658 match concrete_int {
2659 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2660 ConcreteInt::I32([e]) => {
2661 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2662 }
2663 }
2664}
2665
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected LSB-based index of the lowest set bit)
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
2726
2727fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2728 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2732 match e {
2733 idx @ 0..=31 => 31 - idx,
2734 32 => u32::MAX,
2735 _ => unreachable!(),
2736 }
2737 };
2738 match concrete_int {
2739 ConcreteInt::I32([e]) => ConcreteInt::I32([{
2740 let rtl_bit_index = if e.is_negative() {
2741 e.leading_ones()
2742 } else {
2743 e.leading_zeros()
2744 };
2745 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2746 }]),
2747 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2748 }
2749}
2750
#[test]
fn first_leading_bit_smoke() {
    // (input, expected LSB-based index of the first bit differing from the sign)
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the single set bit is the leading bit.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negative powers of two: ones down to bit `idx`, so bit `idx - 1` is the first zero.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
2814
2815trait TryFromAbstract<T>: Sized {
2817 fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2839}
2840
2841impl TryFromAbstract<i64> for i32 {
2842 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2843 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2844 value: format!("{value:?}"),
2845 to_type: "i32",
2846 })
2847 }
2848}
2849
2850impl TryFromAbstract<i64> for u32 {
2851 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2852 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2853 value: format!("{value:?}"),
2854 to_type: "u32",
2855 })
2856 }
2857}
2858
2859impl TryFromAbstract<i64> for u64 {
2860 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2861 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2862 value: format!("{value:?}"),
2863 to_type: "u64",
2864 })
2865 }
2866}
2867
2868impl TryFromAbstract<i64> for i64 {
2869 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2870 Ok(value)
2871 }
2872}
2873
2874impl TryFromAbstract<i64> for f32 {
2875 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2876 let f = value as f32;
2877 Ok(f)
2881 }
2882}
2883
2884impl TryFromAbstract<f64> for f32 {
2885 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2886 let f = value as f32;
2887 if f.is_infinite() {
2888 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2889 value: format!("{value:?}"),
2890 to_type: "f32",
2891 });
2892 }
2893 Ok(f)
2894 }
2895}
2896
2897impl TryFromAbstract<i64> for f64 {
2898 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2899 let f = value as f64;
2900 Ok(f)
2904 }
2905}
2906
2907impl TryFromAbstract<f64> for f64 {
2908 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2909 Ok(value)
2910 }
2911}
2912
2913impl TryFromAbstract<f64> for i32 {
2914 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2915 Ok(value as i32)
2928 }
2929}
2930
2931impl TryFromAbstract<f64> for u32 {
2932 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2933 Ok(value as u32)
2936 }
2937}
2938
2939impl TryFromAbstract<f64> for i64 {
2940 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2941 use crate::proc::type_methods::IntFloatLimits;
2944 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
2945 }
2946}
2947
2948impl TryFromAbstract<f64> for u64 {
2949 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2950 use crate::proc::type_methods::IntFloatLimits;
2953 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
2954 }
2955}
2956
2957impl TryFromAbstract<f64> for f16 {
2958 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
2959 let f = f16::from_f64(value);
2960 if f.is_infinite() {
2961 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2962 value: format!("{value:?}"),
2963 to_type: "f16",
2964 });
2965 }
2966 Ok(f)
2967 }
2968}
2969
2970impl TryFromAbstract<i64> for f16 {
2971 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
2972 let f = f16::from_i64(value);
2973 if f.is_none() {
2974 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2975 value: format!("{value:?}"),
2976 to_type: "f16",
2977 });
2978 }
2979 Ok(f.unwrap())
2980 }
2981}
2982
/// Computes the 3-component cross product `a × b`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
2995
2996#[cfg(test)]
2997mod tests {
2998 use alloc::{vec, vec::Vec};
2999
3000 use crate::{
3001 Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
3002 UniqueArena, VectorSize,
3003 };
3004
3005 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3006
3007 #[test]
3008 fn unary_op() {
3009 let mut types = UniqueArena::new();
3010 let mut constants = Arena::new();
3011 let overrides = Arena::new();
3012 let mut global_expressions = Arena::new();
3013
3014 let scalar_ty = types.insert(
3015 Type {
3016 name: None,
3017 inner: TypeInner::Scalar(crate::Scalar::I32),
3018 },
3019 Default::default(),
3020 );
3021
3022 let vec_ty = types.insert(
3023 Type {
3024 name: None,
3025 inner: TypeInner::Vector {
3026 size: VectorSize::Bi,
3027 scalar: crate::Scalar::I32,
3028 },
3029 },
3030 Default::default(),
3031 );
3032
3033 let h = constants.append(
3034 Constant {
3035 name: None,
3036 ty: scalar_ty,
3037 init: global_expressions
3038 .append(Expression::Literal(Literal::I32(4)), Default::default()),
3039 },
3040 Default::default(),
3041 );
3042
3043 let h1 = constants.append(
3044 Constant {
3045 name: None,
3046 ty: scalar_ty,
3047 init: global_expressions
3048 .append(Expression::Literal(Literal::I32(8)), Default::default()),
3049 },
3050 Default::default(),
3051 );
3052
3053 let vec_h = constants.append(
3054 Constant {
3055 name: None,
3056 ty: vec_ty,
3057 init: global_expressions.append(
3058 Expression::Compose {
3059 ty: vec_ty,
3060 components: vec![constants[h].init, constants[h1].init],
3061 },
3062 Default::default(),
3063 ),
3064 },
3065 Default::default(),
3066 );
3067
3068 let expr = global_expressions.append(Expression::Constant(h), Default::default());
3069 let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3070
3071 let expr2 = Expression::Unary {
3072 op: UnaryOperator::Negate,
3073 expr,
3074 };
3075
3076 let expr3 = Expression::Unary {
3077 op: UnaryOperator::BitwiseNot,
3078 expr,
3079 };
3080
3081 let expr4 = Expression::Unary {
3082 op: UnaryOperator::BitwiseNot,
3083 expr: expr1,
3084 };
3085
3086 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3087 let mut solver = ConstantEvaluator {
3088 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3089 types: &mut types,
3090 constants: &constants,
3091 overrides: &overrides,
3092 expressions: &mut global_expressions,
3093 expression_kind_tracker,
3094 layouter: &mut crate::proc::Layouter::default(),
3095 };
3096
3097 let res1 = solver
3098 .try_eval_and_append(expr2, Default::default())
3099 .unwrap();
3100 let res2 = solver
3101 .try_eval_and_append(expr3, Default::default())
3102 .unwrap();
3103 let res3 = solver
3104 .try_eval_and_append(expr4, Default::default())
3105 .unwrap();
3106
3107 assert_eq!(
3108 global_expressions[res1],
3109 Expression::Literal(Literal::I32(-4))
3110 );
3111
3112 assert_eq!(
3113 global_expressions[res2],
3114 Expression::Literal(Literal::I32(!4))
3115 );
3116
3117 let res3_inner = &global_expressions[res3];
3118
3119 match *res3_inner {
3120 Expression::Compose {
3121 ref ty,
3122 ref components,
3123 } => {
3124 assert_eq!(*ty, vec_ty);
3125 let mut components_iter = components.iter().copied();
3126 assert_eq!(
3127 global_expressions[components_iter.next().unwrap()],
3128 Expression::Literal(Literal::I32(!4))
3129 );
3130 assert_eq!(
3131 global_expressions[components_iter.next().unwrap()],
3132 Expression::Literal(Literal::I32(!8))
3133 );
3134 assert!(components_iter.next().is_none());
3135 }
3136 _ => panic!("Expected vector"),
3137 }
3138 }
3139
3140 #[test]
3141 fn cast() {
3142 let mut types = UniqueArena::new();
3143 let mut constants = Arena::new();
3144 let overrides = Arena::new();
3145 let mut global_expressions = Arena::new();
3146
3147 let scalar_ty = types.insert(
3148 Type {
3149 name: None,
3150 inner: TypeInner::Scalar(crate::Scalar::I32),
3151 },
3152 Default::default(),
3153 );
3154
3155 let h = constants.append(
3156 Constant {
3157 name: None,
3158 ty: scalar_ty,
3159 init: global_expressions
3160 .append(Expression::Literal(Literal::I32(4)), Default::default()),
3161 },
3162 Default::default(),
3163 );
3164
3165 let expr = global_expressions.append(Expression::Constant(h), Default::default());
3166
3167 let root = Expression::As {
3168 expr,
3169 kind: ScalarKind::Bool,
3170 convert: Some(crate::BOOL_WIDTH),
3171 };
3172
3173 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3174 let mut solver = ConstantEvaluator {
3175 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3176 types: &mut types,
3177 constants: &constants,
3178 overrides: &overrides,
3179 expressions: &mut global_expressions,
3180 expression_kind_tracker,
3181 layouter: &mut crate::proc::Layouter::default(),
3182 };
3183
3184 let res = solver
3185 .try_eval_and_append(root, Default::default())
3186 .unwrap();
3187
3188 assert_eq!(
3189 global_expressions[res],
3190 Expression::Literal(Literal::Bool(true))
3191 );
3192 }
3193
3194 #[test]
3195 fn access() {
3196 let mut types = UniqueArena::new();
3197 let mut constants = Arena::new();
3198 let overrides = Arena::new();
3199 let mut global_expressions = Arena::new();
3200
3201 let matrix_ty = types.insert(
3202 Type {
3203 name: None,
3204 inner: TypeInner::Matrix {
3205 columns: VectorSize::Bi,
3206 rows: VectorSize::Tri,
3207 scalar: crate::Scalar::F32,
3208 },
3209 },
3210 Default::default(),
3211 );
3212
3213 let vec_ty = types.insert(
3214 Type {
3215 name: None,
3216 inner: TypeInner::Vector {
3217 size: VectorSize::Tri,
3218 scalar: crate::Scalar::F32,
3219 },
3220 },
3221 Default::default(),
3222 );
3223
3224 let mut vec1_components = Vec::with_capacity(3);
3225 let mut vec2_components = Vec::with_capacity(3);
3226
3227 for i in 0..3 {
3228 let h = global_expressions.append(
3229 Expression::Literal(Literal::F32(i as f32)),
3230 Default::default(),
3231 );
3232
3233 vec1_components.push(h)
3234 }
3235
3236 for i in 3..6 {
3237 let h = global_expressions.append(
3238 Expression::Literal(Literal::F32(i as f32)),
3239 Default::default(),
3240 );
3241
3242 vec2_components.push(h)
3243 }
3244
3245 let vec1 = constants.append(
3246 Constant {
3247 name: None,
3248 ty: vec_ty,
3249 init: global_expressions.append(
3250 Expression::Compose {
3251 ty: vec_ty,
3252 components: vec1_components,
3253 },
3254 Default::default(),
3255 ),
3256 },
3257 Default::default(),
3258 );
3259
3260 let vec2 = constants.append(
3261 Constant {
3262 name: None,
3263 ty: vec_ty,
3264 init: global_expressions.append(
3265 Expression::Compose {
3266 ty: vec_ty,
3267 components: vec2_components,
3268 },
3269 Default::default(),
3270 ),
3271 },
3272 Default::default(),
3273 );
3274
3275 let h = constants.append(
3276 Constant {
3277 name: None,
3278 ty: matrix_ty,
3279 init: global_expressions.append(
3280 Expression::Compose {
3281 ty: matrix_ty,
3282 components: vec![constants[vec1].init, constants[vec2].init],
3283 },
3284 Default::default(),
3285 ),
3286 },
3287 Default::default(),
3288 );
3289
3290 let base = global_expressions.append(Expression::Constant(h), Default::default());
3291
3292 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3293 let mut solver = ConstantEvaluator {
3294 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3295 types: &mut types,
3296 constants: &constants,
3297 overrides: &overrides,
3298 expressions: &mut global_expressions,
3299 expression_kind_tracker,
3300 layouter: &mut crate::proc::Layouter::default(),
3301 };
3302
3303 let root1 = Expression::AccessIndex { base, index: 1 };
3304
3305 let res1 = solver
3306 .try_eval_and_append(root1, Default::default())
3307 .unwrap();
3308
3309 let root2 = Expression::AccessIndex {
3310 base: res1,
3311 index: 2,
3312 };
3313
3314 let res2 = solver
3315 .try_eval_and_append(root2, Default::default())
3316 .unwrap();
3317
3318 match global_expressions[res1] {
3319 Expression::Compose {
3320 ref ty,
3321 ref components,
3322 } => {
3323 assert_eq!(*ty, vec_ty);
3324 let mut components_iter = components.iter().copied();
3325 assert_eq!(
3326 global_expressions[components_iter.next().unwrap()],
3327 Expression::Literal(Literal::F32(3.))
3328 );
3329 assert_eq!(
3330 global_expressions[components_iter.next().unwrap()],
3331 Expression::Literal(Literal::F32(4.))
3332 );
3333 assert_eq!(
3334 global_expressions[components_iter.next().unwrap()],
3335 Expression::Literal(Literal::F32(5.))
3336 );
3337 assert!(components_iter.next().is_none());
3338 }
3339 _ => panic!("Expected vector"),
3340 }
3341
3342 assert_eq!(
3343 global_expressions[res2],
3344 Expression::Literal(Literal::F32(5.))
3345 );
3346 }
3347
3348 #[test]
3349 fn compose_of_constants() {
3350 let mut types = UniqueArena::new();
3351 let mut constants = Arena::new();
3352 let overrides = Arena::new();
3353 let mut global_expressions = Arena::new();
3354
3355 let i32_ty = types.insert(
3356 Type {
3357 name: None,
3358 inner: TypeInner::Scalar(crate::Scalar::I32),
3359 },
3360 Default::default(),
3361 );
3362
3363 let vec2_i32_ty = types.insert(
3364 Type {
3365 name: None,
3366 inner: TypeInner::Vector {
3367 size: VectorSize::Bi,
3368 scalar: crate::Scalar::I32,
3369 },
3370 },
3371 Default::default(),
3372 );
3373
3374 let h = constants.append(
3375 Constant {
3376 name: None,
3377 ty: i32_ty,
3378 init: global_expressions
3379 .append(Expression::Literal(Literal::I32(4)), Default::default()),
3380 },
3381 Default::default(),
3382 );
3383
3384 let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3385
3386 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3387 let mut solver = ConstantEvaluator {
3388 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3389 types: &mut types,
3390 constants: &constants,
3391 overrides: &overrides,
3392 expressions: &mut global_expressions,
3393 expression_kind_tracker,
3394 layouter: &mut crate::proc::Layouter::default(),
3395 };
3396
3397 let solved_compose = solver
3398 .try_eval_and_append(
3399 Expression::Compose {
3400 ty: vec2_i32_ty,
3401 components: vec![h_expr, h_expr],
3402 },
3403 Default::default(),
3404 )
3405 .unwrap();
3406 let solved_negate = solver
3407 .try_eval_and_append(
3408 Expression::Unary {
3409 op: UnaryOperator::Negate,
3410 expr: solved_compose,
3411 },
3412 Default::default(),
3413 )
3414 .unwrap();
3415
3416 let pass = match global_expressions[solved_negate] {
3417 Expression::Compose { ty, ref components } => {
3418 ty == vec2_i32_ty
3419 && components.iter().all(|&component| {
3420 let component = &global_expressions[component];
3421 matches!(*component, Expression::Literal(Literal::I32(-4)))
3422 })
3423 }
3424 _ => false,
3425 };
3426 if !pass {
3427 panic!("unexpected evaluation result")
3428 }
3429 }
3430
#[test]
fn splat_of_constant() {
    // Builds `splat(h)` over `vec2<i32>`, where `h` is a `Constant` whose
    // initializer is the literal `4`. Then it checks that negating the splat
    // const-evaluates to a fully folded `Compose` of two `-4` literals.
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let h = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: global_expressions
                .append(Expression::Literal(Literal::I32(4)), Default::default()),
        },
        Default::default(),
    );

    let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

    // The tracker is built after all input expressions exist, so it knows
    // about every handle the evaluator is asked to read.
    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_compose = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: h_expr,
            },
            Default::default(),
        )
        .unwrap();
    let solved_negate = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: solved_compose,
            },
            Default::default(),
        )
        .unwrap();

    // Expect exactly two components. `Iterator::all` returns `true` for an
    // empty iterator, so without the length check an empty `Compose` would
    // pass vacuously.
    let pass = match global_expressions[solved_negate] {
        Expression::Compose { ty, ref components } => {
            ty == vec2_i32_ty
                && components.len() == 2
                && components.iter().all(|&component| {
                    let component = &global_expressions[component];
                    matches!(*component, Expression::Literal(Literal::I32(-4)))
                })
        }
        _ => false,
    };
    assert!(pass, "unexpected evaluation result");
}
3513
#[test]
fn splat_of_zero_value() {
    // Adds `splat(ZeroValue(f32))` to `splat(5.0)` over `vec2<f32>` and checks
    // that the result const-evaluates to a `Compose` of two `5.0` literals.
    // This requires the evaluator to expand both the zero value and the
    // splats before doing the component-wise addition.
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );

    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    let five =
        global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
    let five_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: five,
        },
        Default::default(),
    );
    let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
    let zero_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: zero,
        },
        Default::default(),
    );

    // The tracker is built after all input expressions exist, so it knows
    // about every handle the evaluator is asked to read.
    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let solved_add = solver
        .try_eval_and_append(
            Expression::Binary {
                op: crate::BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    // Expect exactly two components. `Iterator::all` returns `true` for an
    // empty iterator, so without the length check an empty `Compose` would
    // pass vacuously. `0.0 + 5.0` is exactly `5.0`, so an exact float
    // pattern is safe here.
    let pass = match global_expressions[solved_add] {
        Expression::Compose { ty, ref components } => {
            ty == vec2_f32_ty
                && components.len() == 2
                && components.iter().all(|&component| {
                    let component = &global_expressions[component];
                    matches!(*component, Expression::Literal(Literal::F32(5.0)))
                })
        }
        _ => false,
    };
    assert!(pass, "unexpected evaluation result");
}
3594}