1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Lets a macro that *defines* another macro emit a literal `$` token.
///
/// A `macro_rules!` transcriber cannot directly write the `$` needed by the
/// inner macro's own metavariables. Calling
/// `with_dollar_sign! { ($d:tt) => { ... } }` defines a throwaway helper
/// macro `__with_dollar_sign` from the given rule and immediately invokes it
/// with a single `$` token. Inside `...`, the metavariable `$d` therefore
/// expands to `$`, so `$d name:expr` becomes `$name:expr` in the generated
/// macro.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper from the caller-supplied rule(s)...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and invoke it with a bare `$`, which binds to `$d`.
        __with_dollar_sign!($);
    }
}
33
/// Generates component-wise evaluation support for one family of scalar
/// literal kinds.
///
/// `gen_component_wise_extractor! { ident -> Target, literals: [...], scalar_kinds: [...] }`
/// expands to:
///
/// * `enum Target<const N: usize>`: one variant per `Literal => Variant: ty`
///   entry, each holding `[ty; N]` (the `N` operand values of that kind).
/// * `impl From<Target<1>> for Expression`: turns a single value back into an
///   `Expression::Literal`.
/// * `fn ident(eval, span, exprs: [Handle<Expression>; N], handler)`:
///   - If every expression in `exprs` is a literal of one listed kind, it
///     calls `handler` once with their values.
///   - If every expression is a `Compose` of the same vector type whose
///     scalar kind is in `scalar_kinds`, it calls itself once per vector lane
///     and rebuilds a `Compose` of the results.
///   - Anything else yields `ConstantEvaluatorError::InvalidMathArg`.
///   - The result is registered via `eval.register_evaluated_expr`.
/// * A convenience macro, also named `ident`, that applies one closure body
///   to every `Target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // The `N` operands of a component-wise call, grouped by literal kind.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-valued result converts straight back to a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            // `Clone` because the vector path below calls the handler once per lane.
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalizes an operand through `eval_zero_value_and_splat` (defined
            // elsewhere; presumably expands `ZeroValue`/`Splat` into literal or
            // `Compose` form — confirm) and borrows the resulting expression.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first operand decides which path is taken; every other
            // operand must have the same shape, or the result is `err`.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar path: all operands are literals of this same kind.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector path: all operands are `Compose`s of the same vector type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per operand.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // NOTE(review): the remaining N - 1 groups are collected
                            // into an `ArrayVec` of capacity `VectorSize::MAX`, which
                            // assumes N - 1 <= that capacity — confirm for new callers.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            // Transpose: for each lane `idx`, gather that lane from
                            // every operand and evaluate it recursively as scalars.
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // Also define a same-named macro that wraps the function and applies
        // one closure body to every variant (`$d` expands to `$`; see
        // `with_dollar_sign`).
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    // The body yields `Result<[T; M], _>`; rewrap it
                                    // in the variant it came from.
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// `Scalar<N>` / `component_wise_scalar`: every numeric literal kind except
// `F64`. A vector operand is accepted when its scalar kind is listed under
// `scalar_kinds`. NOTE(review): a `Float` vector of `f64` passes that check,
// but its components then match no literal arm, so evaluation presumably
// ends in `InvalidMathArg` — confirm this is intended.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// `Float<N>` / `component_wise_float`: floating-point literals only
// (abstract, `f32`, `f16`; no `f64`), plus vectors of float scalar kind.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// `ConcreteInt<N>` / `component_wise_concrete_int`: 32-bit concrete integer
// literals (`u32`, `i32`) and vectors of signed/unsigned integer scalar kind.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// `Signed<N>` / `component_wise_signed`: literal kinds that can be negative
// (abstract float/int, `f32`, `f16`, `i32`). NOTE(review): `I64` is not
// listed although `Sint` vectors are accepted, so `i64` operands presumably
// fail with `InvalidMathArg` — confirm this is intended.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// A vector of literal values that all share one literal kind; the variant
/// records that kind.
///
/// Each variant holds at most `crate::VectorSize::MAX` components. The
/// constructors in this file never build an empty one: `from_literal` makes
/// a single component, and `from_literal_vec` asserts a non-empty input.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288 fn len(&self) -> usize {
289 match *self {
290 LiteralVector::F64(ref v) => v.len(),
291 LiteralVector::F32(ref v) => v.len(),
292 LiteralVector::F16(ref v) => v.len(),
293 LiteralVector::U32(ref v) => v.len(),
294 LiteralVector::I32(ref v) => v.len(),
295 LiteralVector::U64(ref v) => v.len(),
296 LiteralVector::I64(ref v) => v.len(),
297 LiteralVector::Bool(ref v) => v.len(),
298 LiteralVector::AbstractInt(ref v) => v.len(),
299 LiteralVector::AbstractFloat(ref v) => v.len(),
300 }
301 }
302
303 fn from_literal(literal: Literal) -> Self {
305 match literal {
306 Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307 Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308 Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309 Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310 Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311 Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312 Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313 Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314 Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315 Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316 }
317 }
318
319 fn from_literal_vec(
324 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325 ) -> Result<Self, ConstantEvaluatorError> {
326 assert!(!components.is_empty());
327 Ok(match components[0] {
328 Literal::I32(_) => Self::I32(
329 components
330 .iter()
331 .map(|l| match l {
332 &Literal::I32(v) => Ok(v),
333 _ => Err(ConstantEvaluatorError::InvalidMathArg),
335 })
336 .collect::<Result<_, _>>()?,
337 ),
338 Literal::U32(_) => Self::U32(
339 components
340 .iter()
341 .map(|l| match l {
342 &Literal::U32(v) => Ok(v),
343 _ => Err(ConstantEvaluatorError::InvalidMathArg),
344 })
345 .collect::<Result<_, _>>()?,
346 ),
347 Literal::I64(_) => Self::I64(
348 components
349 .iter()
350 .map(|l| match l {
351 &Literal::I64(v) => Ok(v),
352 _ => Err(ConstantEvaluatorError::InvalidMathArg),
353 })
354 .collect::<Result<_, _>>()?,
355 ),
356 Literal::U64(_) => Self::U64(
357 components
358 .iter()
359 .map(|l| match l {
360 &Literal::U64(v) => Ok(v),
361 _ => Err(ConstantEvaluatorError::InvalidMathArg),
362 })
363 .collect::<Result<_, _>>()?,
364 ),
365 Literal::F32(_) => Self::F32(
366 components
367 .iter()
368 .map(|l| match l {
369 &Literal::F32(v) => Ok(v),
370 _ => Err(ConstantEvaluatorError::InvalidMathArg),
371 })
372 .collect::<Result<_, _>>()?,
373 ),
374 Literal::F64(_) => Self::F64(
375 components
376 .iter()
377 .map(|l| match l {
378 &Literal::F64(v) => Ok(v),
379 _ => Err(ConstantEvaluatorError::InvalidMathArg),
380 })
381 .collect::<Result<_, _>>()?,
382 ),
383 Literal::Bool(_) => Self::Bool(
384 components
385 .iter()
386 .map(|l| match l {
387 &Literal::Bool(v) => Ok(v),
388 _ => Err(ConstantEvaluatorError::InvalidMathArg),
389 })
390 .collect::<Result<_, _>>()?,
391 ),
392 Literal::AbstractInt(_) => Self::AbstractInt(
393 components
394 .iter()
395 .map(|l| match l {
396 &Literal::AbstractInt(v) => Ok(v),
397 _ => Err(ConstantEvaluatorError::InvalidMathArg),
398 })
399 .collect::<Result<_, _>>()?,
400 ),
401 Literal::AbstractFloat(_) => Self::AbstractFloat(
402 components
403 .iter()
404 .map(|l| match l {
405 &Literal::AbstractFloat(v) => Ok(v),
406 _ => Err(ConstantEvaluatorError::InvalidMathArg),
407 })
408 .collect::<Result<_, _>>()?,
409 ),
410 Literal::F16(_) => Self::F16(
411 components
412 .iter()
413 .map(|l| match l {
414 &Literal::F16(v) => Ok(v),
415 _ => Err(ConstantEvaluatorError::InvalidMathArg),
416 })
417 .collect::<Result<_, _>>()?,
418 ),
419 })
420 }
421
422 #[allow(dead_code)]
423 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425 match *self {
426 LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427 LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428 LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429 LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430 LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431 LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432 LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433 LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434 LiteralVector::AbstractInt(ref v) => {
435 v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436 }
437 LiteralVector::AbstractFloat(ref v) => {
438 v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439 }
440 }
441 }
442
443 #[allow(dead_code)]
444 fn register_as_evaluated_expr(
446 &self,
447 eval: &mut ConstantEvaluator<'_>,
448 span: Span,
449 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450 let lit_vec = self.to_literal_vec();
451 assert!(!lit_vec.is_empty());
452 let expr = if lit_vec.len() == 1 {
453 Expression::Literal(lit_vec[0])
454 } else {
455 Expression::Compose {
456 ty: eval.types.insert(
457 Type {
458 name: None,
459 inner: TypeInner::Vector {
460 size: match lit_vec.len() {
461 2 => crate::VectorSize::Bi,
462 3 => crate::VectorSize::Tri,
463 4 => crate::VectorSize::Quad,
464 _ => unreachable!(),
465 },
466 scalar: lit_vec[0].scalar(),
467 },
468 },
469 Span::UNDEFINED,
470 ),
471 components: lit_vec
472 .iter()
473 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474 .collect::<Result<_, _>>()?,
475 }
476 };
477 eval.register_evaluated_expr(expr, span)
478 }
479}
480
/// Dispatches on the variant(s) of one or more `LiteralVector`s and wraps the
/// per-variant result in an output enum.
///
/// ```text
/// match_literal_vector!(match (a, b) => Out {
///     Float => |x, y| { body },
///     Integer => |x, y| -> Ret { body },
///     Bool => |x, y| { body },
/// })
/// ```
///
/// Arm names:
///
/// * Each arm names a `LiteralVector` variant, or one of two groups that
///   expand to several variants sharing the same closure:
///   - `Integer` becomes `U32, I32, U64, I64, AbstractInt`;
///   - `Float` becomes `F16, F32, F64, AbstractFloat`.
///
/// Bindings:
///
/// * Each binding is a `ref` to the matched variant's inner `ArrayVec`.
/// * With several bindings, the scrutinee must be a tuple whose elements all
///   have the same variant.
///
/// Result:
///
/// * A matching arm evaluates to `Ok(Out::Variant(body))`.
/// * `-> Ret` replaces `Variant` with `Ret`, e.g. to produce `Bool` results
///   from float inputs.
/// * Any other combination evaluates to
///   `Err(ConstantEvaluatorError::InvalidMathArg)`.
///
/// Internal rules (`@...`):
///
/// * `@inner` peels the type list one entry at a time, accumulating
///   `{ Type; vars; ret; body }` records.
/// * `@inner_finish` emits the final `match`.
/// * `@expand_ret` picks the output variant.
macro_rules! match_literal_vector {
    // Public entry point: split arm names from their closures.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start with an empty accumulator (left of `<>`).
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next type name to the front so the rules below can inspect it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` group: replace it with its five concrete variants, duplicating
    // the closure once per variant.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body } // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` group: replace it with its four concrete variants.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body } // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete variant with more types remaining: move it (with its closure)
    // into the accumulator and continue with the next type.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete variant: accumulate it and emit the final match.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the `match`. With a single binding the pattern is a parenthesized
    // (non-tuple) pattern, hence `unused_parens` is allowed.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // No `-> Ret`: the output variant has the same name as the input variant.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // Explicit `-> Ret`: use `Ret` as the output variant.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
657#[derive(Debug)]
658enum Behavior<'a> {
659 Wgsl(WgslRestrictions<'a>),
660 Glsl(GlslRestrictions<'a>),
661}
662
663impl Behavior<'_> {
664 const fn has_runtime_restrictions(&self) -> bool {
666 matches!(
667 self,
668 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670 )
671 }
672}
673
/// Evaluates expressions at compile time and appends the results to an
/// expression arena.
///
/// Construct one with `for_wgsl_module`, `for_glsl_module`,
/// `for_wgsl_function` or `for_glsl_function`, then pass expressions to
/// `try_eval_and_append`.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's rules apply, and in which context (const, override or
    /// function runtime) evaluation is happening.
    behavior: Behavior<'a>,

    /// The module's type arena. Mutable because evaluation may insert new
    /// types (e.g. `LiteralVector::register_as_evaluated_expr` inserts vector
    /// types).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants; `Expression::Constant` is resolved to the
    /// constant's `init` expression.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena being evaluated into:
    /// - the module's `global_expressions` when built by a `for_*_module`
    ///   constructor;
    /// - the function's own expression arena when built by a
    ///   `for_*_function` constructor.
    expressions: &'a mut Arena<Expression>,

    /// The `ExpressionKind` of each expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Type layout calculator. It is not read in the visible part of this
    /// file; presumably it is used where evaluation needs type sizes or
    /// alignments — confirm.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// The WGSL evaluation context.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// A const-expression context.
    /// - `None` when evaluating into the module's global expression arena.
    /// - `Some` when evaluating a const-expression inside a function body.
    Const(Option<FunctionLocalData<'a>>),
    /// An override-expression context in the module's global expression
    /// arena. Override-dependent expressions are appended as
    /// `ExpressionKind::Override` rather than rejected.
    Override,
    /// A function body's runtime context. Runtime-kind expressions, and
    /// const expressions whose evaluation reports `NotImplemented` or
    /// `InvalidBinaryOpArgs`, are appended as runtime expressions rather
    /// than rejected.
    Runtime(FunctionLocalData<'a>),
}
731
/// The GLSL evaluation context.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Const context, evaluating into the module's global expression arena.
    Const,
    /// A function body's runtime context; expressions that cannot be folded
    /// are appended as runtime expressions instead of being rejected.
    Runtime(FunctionLocalData<'a>),
}
741
/// State that is only needed when evaluating inside a function body.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are
    /// copied from it into the function's arena, and it backs `to_ctx`.
    global_expressions: &'a Arena<Expression>,
    /// Not read in the visible part of this file; presumably used to emit
    /// runtime expressions into `block` — confirm.
    emitter: &'a mut super::Emitter,
    /// The function block that emitted statements are presumably pushed
    /// into — confirm.
    block: &'a mut crate::Block,
}
749
/// How "constant" an expression is.
///
/// Variant order is significant:
/// - the derived `Ord` gives `Const < Override < Runtime`;
/// - `ExpressionKindTracker` combines operand kinds with `max`, so a compound
///   expression is only as constant as its least-constant operand.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Fully known at shader-creation time.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Only known at runtime.
    Runtime,
}
756
/// Records the `ExpressionKind` of each expression in an arena, indexed by
/// the expression's handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763 pub const fn new() -> Self {
764 Self {
765 inner: HandleVec::new(),
766 }
767 }
768
769 pub fn force_non_const(&mut self, value: Handle<Expression>) {
771 self.inner[value] = ExpressionKind::Runtime;
772 }
773
774 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775 self.inner.insert(value, expr_type);
776 }
777
778 pub fn is_const(&self, h: Handle<Expression>) -> bool {
779 matches!(self.type_of(h), ExpressionKind::Const)
780 }
781
782 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783 matches!(
784 self.type_of(h),
785 ExpressionKind::Const | ExpressionKind::Override
786 )
787 }
788
789 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790 self.inner[value]
791 }
792
793 pub fn from_arena(arena: &Arena<Expression>) -> Self {
794 let mut tracker = Self {
795 inner: HandleVec::with_capacity(arena.len()),
796 };
797 for (handle, expr) in arena.iter() {
798 tracker
799 .inner
800 .insert(handle, tracker.type_of_with_expr(expr));
801 }
802 tracker
803 }
804
805 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806 match *expr {
807 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808 ExpressionKind::Const
809 }
810 Expression::Override(_) => ExpressionKind::Override,
811 Expression::Compose { ref components, .. } => {
812 let mut expr_type = ExpressionKind::Const;
813 for component in components {
814 expr_type = expr_type.max(self.type_of(*component))
815 }
816 expr_type
817 }
818 Expression::Splat { value, .. } => self.type_of(value),
819 Expression::AccessIndex { base, .. } => self.type_of(base),
820 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821 Expression::Swizzle { vector, .. } => self.type_of(vector),
822 Expression::Unary { expr, .. } => self.type_of(expr),
823 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824 Expression::Math {
825 arg,
826 arg1,
827 arg2,
828 arg3,
829 ..
830 } => self
831 .type_of(arg)
832 .max(
833 arg1.map(|arg| self.type_of(arg))
834 .unwrap_or(ExpressionKind::Const),
835 )
836 .max(
837 arg2.map(|arg| self.type_of(arg))
838 .unwrap_or(ExpressionKind::Const),
839 )
840 .max(
841 arg3.map(|arg| self.type_of(arg))
842 .unwrap_or(ExpressionKind::Const),
843 ),
844 Expression::As { expr, .. } => self.type_of(expr),
845 Expression::Select {
846 condition,
847 accept,
848 reject,
849 } => self
850 .type_of(condition)
851 .max(self.type_of(accept))
852 .max(self.type_of(reject)),
853 Expression::Relational { argument, .. } => self.type_of(argument),
854 Expression::ArrayLength(expr) => self.type_of(expr),
855 _ => ExpressionKind::Runtime,
856 }
857 }
858}
859
860#[derive(Clone, Debug, thiserror::Error)]
861#[cfg_attr(test, derive(PartialEq))]
862pub enum ConstantEvaluatorError {
863 #[error("Constants cannot access function arguments")]
864 FunctionArg,
865 #[error("Constants cannot access global variables")]
866 GlobalVariable,
867 #[error("Constants cannot access local variables")]
868 LocalVariable,
869 #[error("Cannot get the array length of a non array type")]
870 InvalidArrayLengthArg,
871 #[error("Constants cannot get the array length of a dynamically sized array")]
872 ArrayLengthDynamic,
873 #[error("Cannot call arrayLength on array sized by override-expression")]
874 ArrayLengthOverridden,
875 #[error("Constants cannot call functions")]
876 Call,
877 #[error("Constants don't support workGroupUniformLoad")]
878 WorkGroupUniformLoadResult,
879 #[error("Constants don't support atomic functions")]
880 Atomic,
881 #[error("Constants don't support derivative functions")]
882 Derivative,
883 #[error("Constants don't support load expressions")]
884 Load,
885 #[error("Constants don't support image expressions")]
886 ImageExpression,
887 #[error("Constants don't support ray query expressions")]
888 RayQueryExpression,
889 #[error("Constants don't support subgroup expressions")]
890 SubgroupExpression,
891 #[error("Cannot access the type")]
892 InvalidAccessBase,
893 #[error("Cannot access at the index")]
894 InvalidAccessIndex,
895 #[error("Cannot access with index of type")]
896 InvalidAccessIndexTy,
897 #[error("Constants don't support array length expressions")]
898 ArrayLength,
899 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
900 InvalidCastArg { from: String, to: String },
901 #[error("Cannot apply the unary op to the argument")]
902 InvalidUnaryOpArg,
903 #[error("Cannot apply the binary op to the arguments")]
904 InvalidBinaryOpArgs,
905 #[error("Cannot apply math function to type")]
906 InvalidMathArg,
907 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
908 InvalidMathArgCount(crate::MathFunction, usize, usize),
909 #[error("Cannot apply relational function to type")]
910 InvalidRelationalArg(RelationalFunction),
911 #[error("value of `low` is greater than `high` for clamp built-in function")]
912 InvalidClamp,
913 #[error("Constructor expects {expected} components, found {actual}")]
914 InvalidVectorComposeLength { expected: usize, actual: usize },
915 #[error("Constructor must only contain vector or scalar arguments")]
916 InvalidVectorComposeComponent,
917 #[error("Splat is defined only on scalar values")]
918 SplatScalarOnly,
919 #[error("Can only swizzle vector constants")]
920 SwizzleVectorOnly,
921 #[error("swizzle component not present in source expression")]
922 SwizzleOutOfBounds,
923 #[error("Type is not constructible")]
924 TypeNotConstructible,
925 #[error("Subexpression(s) are not constant")]
926 SubexpressionsAreNotConstant,
927 #[error("Not implemented as constant expression: {0}")]
928 NotImplemented(String),
929 #[error("{0} operation overflowed")]
930 Overflow(String),
931 #[error(
932 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
933 )]
934 AutomaticConversionLossy {
935 value: String,
936 to_type: &'static str,
937 },
938 #[error("Division by zero")]
939 DivisionByZero,
940 #[error("Remainder by zero")]
941 RemainderByZero,
942 #[error("RHS of shift operation is greater than or equal to 32")]
943 ShiftedMoreThan32Bits,
944 #[error(transparent)]
945 Literal(#[from] crate::valid::LiteralError),
946 #[error("Can't use pipeline-overridable constants in const-expressions")]
947 Override,
948 #[error("Unexpected runtime-expression")]
949 RuntimeExpr,
950 #[error("Unexpected override-expression")]
951 OverrideExpr,
952 #[error("Expected boolean expression for condition argument of `select`, got something else")]
953 SelectScalarConditionNotABool,
954 #[error(
955 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
956 reject,
957 accept
958 )]
959 SelectVecRejectAcceptSizeMismatch {
960 reject: crate::VectorSize,
961 accept: crate::VectorSize,
962 },
963 #[error("Expected boolean vector for condition arg., got something else")]
964 SelectConditionNotAVecBool,
965 #[error(
966 "Expected same number of vector components between condition, accept, and reject args., got something else",
967 )]
968 SelectConditionVecSizeMismatch,
969 #[error(
970 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
971 )]
972 SelectAcceptRejectTypeMismatch,
973 #[error("Cooperative operations can't be constant")]
974 CooperativeOperation,
975}
976
977impl<'a> ConstantEvaluator<'a> {
978 pub fn for_wgsl_module(
983 module: &'a mut crate::Module,
984 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
985 layouter: &'a mut crate::proc::Layouter,
986 in_override_ctx: bool,
987 ) -> Self {
988 Self::for_module(
989 Behavior::Wgsl(if in_override_ctx {
990 WgslRestrictions::Override
991 } else {
992 WgslRestrictions::Const(None)
993 }),
994 module,
995 global_expression_kind_tracker,
996 layouter,
997 )
998 }
999
1000 pub fn for_glsl_module(
1005 module: &'a mut crate::Module,
1006 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1007 layouter: &'a mut crate::proc::Layouter,
1008 ) -> Self {
1009 Self::for_module(
1010 Behavior::Glsl(GlslRestrictions::Const),
1011 module,
1012 global_expression_kind_tracker,
1013 layouter,
1014 )
1015 }
1016
1017 fn for_module(
1018 behavior: Behavior<'a>,
1019 module: &'a mut crate::Module,
1020 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1021 layouter: &'a mut crate::proc::Layouter,
1022 ) -> Self {
1023 Self {
1024 behavior,
1025 types: &mut module.types,
1026 constants: &module.constants,
1027 overrides: &module.overrides,
1028 expressions: &mut module.global_expressions,
1029 expression_kind_tracker: global_expression_kind_tracker,
1030 layouter,
1031 }
1032 }
1033
1034 pub fn for_wgsl_function(
1039 module: &'a mut crate::Module,
1040 expressions: &'a mut Arena<Expression>,
1041 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1042 layouter: &'a mut crate::proc::Layouter,
1043 emitter: &'a mut super::Emitter,
1044 block: &'a mut crate::Block,
1045 is_const: bool,
1046 ) -> Self {
1047 let local_data = FunctionLocalData {
1048 global_expressions: &module.global_expressions,
1049 emitter,
1050 block,
1051 };
1052 Self {
1053 behavior: Behavior::Wgsl(if is_const {
1054 WgslRestrictions::Const(Some(local_data))
1055 } else {
1056 WgslRestrictions::Runtime(local_data)
1057 }),
1058 types: &mut module.types,
1059 constants: &module.constants,
1060 overrides: &module.overrides,
1061 expressions,
1062 expression_kind_tracker: local_expression_kind_tracker,
1063 layouter,
1064 }
1065 }
1066
1067 pub fn for_glsl_function(
1072 module: &'a mut crate::Module,
1073 expressions: &'a mut Arena<Expression>,
1074 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1075 layouter: &'a mut crate::proc::Layouter,
1076 emitter: &'a mut super::Emitter,
1077 block: &'a mut crate::Block,
1078 ) -> Self {
1079 Self {
1080 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1081 global_expressions: &module.global_expressions,
1082 emitter,
1083 block,
1084 })),
1085 types: &mut module.types,
1086 constants: &module.constants,
1087 overrides: &module.overrides,
1088 expressions,
1089 expression_kind_tracker: local_expression_kind_tracker,
1090 layouter,
1091 }
1092 }
1093
1094 pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1095 crate::proc::GlobalCtx {
1096 types: self.types,
1097 constants: self.constants,
1098 overrides: self.overrides,
1099 global_expressions: match self.function_local_data() {
1100 Some(data) => data.global_expressions,
1101 None => self.expressions,
1102 },
1103 }
1104 }
1105
1106 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1107 if !self.expression_kind_tracker.is_const(expr) {
1108 log::debug!("check: SubexpressionsAreNotConstant");
1109 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1110 }
1111 Ok(())
1112 }
1113
1114 fn check_and_get(
1115 &mut self,
1116 expr: Handle<Expression>,
1117 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1118 match self.expressions[expr] {
1119 Expression::Constant(c) => {
1120 if let Some(function_local_data) = self.function_local_data() {
1123 self.copy_from(
1125 self.constants[c].init,
1126 function_local_data.global_expressions,
1127 )
1128 } else {
1129 Ok(self.constants[c].init)
1131 }
1132 }
1133 _ => {
1134 self.check(expr)?;
1135 Ok(expr)
1136 }
1137 }
1138 }
1139
1140 pub fn try_eval_and_append(
1164 &mut self,
1165 expr: Expression,
1166 span: Span,
1167 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1168 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1169 ExpressionKind::Const => {
1170 let eval_result = self.try_eval_and_append_impl(&expr, span);
1171 if self.behavior.has_runtime_restrictions()
1176 && matches!(
1177 eval_result,
1178 Err(ConstantEvaluatorError::NotImplemented(_)
1179 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1180 )
1181 {
1182 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1183 } else {
1184 eval_result
1185 }
1186 }
1187 ExpressionKind::Override => match self.behavior {
1188 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1189 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1190 }
1191 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1192 Err(ConstantEvaluatorError::OverrideExpr)
1193 }
1194
1195 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1197 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1198 }
1199 Behavior::Glsl(GlslRestrictions::Const) => {
1200 Err(ConstantEvaluatorError::OverrideExpr)
1201 }
1202 },
1203 ExpressionKind::Runtime => {
1204 if self.behavior.has_runtime_restrictions() {
1205 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1206 } else {
1207 Err(ConstantEvaluatorError::RuntimeExpr)
1208 }
1209 }
1210 }
1211 }
1212
1213 const fn is_global_arena(&self) -> bool {
1215 matches!(
1216 self.behavior,
1217 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1218 | Behavior::Glsl(GlslRestrictions::Const)
1219 )
1220 }
1221
1222 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1223 match self.behavior {
1224 Behavior::Wgsl(
1225 WgslRestrictions::Runtime(ref function_local_data)
1226 | WgslRestrictions::Const(Some(ref function_local_data)),
1227 )
1228 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1229 Some(function_local_data)
1230 }
1231 _ => None,
1232 }
1233 }
1234
/// Try to constant-evaluate `expr` and append the result to the evaluator's
/// expression arena, returning the handle of the evaluated expression.
///
/// Every operand handle is first passed through `check_and_get` (defined
/// elsewhere in this file; presumably it verifies the operand is already an
/// evaluated constant and returns the handle to use — confirm there).
///
/// Expression kinds that can never be constant (loads, local/global
/// variables, function arguments, calls, image/ray-query/subgroup/atomic/
/// cooperative operations, derivatives) are rejected with a dedicated
/// [`ConstantEvaluatorError`] variant.
fn try_eval_and_append_impl(
    &mut self,
    expr: &Expression,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    log::trace!("try_eval_and_append: {expr:?}");
    match *expr {
        // In the modes reported by `is_global_arena`, "see through" the
        // constant and hand back its initializer handle directly instead of
        // appending a new `Constant` expression.
        Expression::Constant(c) if self.is_global_arena() => {
            Ok(self.constants[c].init)
        }
        Expression::Override(_) => Err(ConstantEvaluatorError::Override),
        // Already-constant leaves are registered unchanged.
        Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
            self.register_evaluated_expr(expr.clone(), span)
        }
        // Composites are kept symbolic: only their operands are checked.
        Expression::Compose { ty, ref components } => {
            let components = components
                .iter()
                .map(|component| self.check_and_get(*component))
                .collect::<Result<Vec<_>, _>>()?;
            self.register_evaluated_expr(Expression::Compose { ty, components }, span)
        }
        Expression::Splat { size, value } => {
            let value = self.check_and_get(value)?;
            self.register_evaluated_expr(Expression::Splat { size, value }, span)
        }
        Expression::AccessIndex { base, index } => {
            let base = self.check_and_get(base)?;

            self.access(base, index as usize, span)
        }
        Expression::Access { base, index } => {
            let base = self.check_and_get(base)?;
            let index = self.check_and_get(index)?;

            // A dynamic index is only usable here if it folds to a `u32`
            // constant; any failure is reported as a bad index type.
            let index_val: u32 = self
                .to_ctx()
                .get_const_val_from(index, self.expressions)
                .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
            self.access(base, index_val as usize, span)
        }
        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            let vector = self.check_and_get(vector)?;

            self.swizzle(size, span, vector, pattern)
        }
        Expression::Unary { expr, op } => {
            let expr = self.check_and_get(expr)?;

            self.unary_op(op, expr, span)
        }
        Expression::Binary { left, right, op } => {
            let left = self.check_and_get(left)?;
            let right = self.check_and_get(right)?;

            self.binary_op(op, left, right, span)
        }
        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            let arg = self.check_and_get(arg)?;
            let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

            self.math(arg, arg1, arg2, arg3, fun, span)
        }
        Expression::As {
            convert,
            expr,
            kind,
        } => {
            let expr = self.check_and_get(expr)?;

            // `convert: None` is a bitcast, which is not const-evaluated yet.
            match convert {
                Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                None => Err(ConstantEvaluatorError::NotImplemented(
                    "bitcast built-in function".into(),
                )),
            }
        }
        Expression::Select {
            reject,
            accept,
            condition,
        } => {
            let mut arg = |expr| self.check_and_get(expr);

            let reject = arg(reject)?;
            let accept = arg(accept)?;
            let condition = arg(condition)?;

            self.select(reject, accept, condition, span)
        }
        Expression::Relational { fun, argument } => {
            let argument = self.check_and_get(argument)?;
            self.relational(fun, argument, span)
        }
        // Only the GLSL front end may fold `arrayLength`; WGSL rejects it.
        Expression::ArrayLength(expr) => match self.behavior {
            Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
            Behavior::Glsl(_) => {
                let expr = self.check_and_get(expr)?;
                self.array_length(expr, span)
            }
        },
        // Everything below depends on runtime state and can never be constant.
        Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
        Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
        Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
        Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
        Expression::WorkGroupUniformLoadResult { .. } => {
            Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
        }
        Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
        Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
        Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
        Expression::ImageSample { .. }
        | Expression::ImageLoad { .. }
        | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
        Expression::RayQueryProceedResult
        | Expression::RayQueryGetIntersection { .. }
        | Expression::RayQueryVertexPositions { .. } => {
            Err(ConstantEvaluatorError::RayQueryExpression)
        }
        Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
        Expression::SubgroupOperationResult { .. } => {
            Err(ConstantEvaluatorError::SubgroupExpression)
        }
        Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
            Err(ConstantEvaluatorError::CooperativeOperation)
        }
    }
}
1376
1377 fn splat(
1390 &mut self,
1391 value: Handle<Expression>,
1392 size: crate::VectorSize,
1393 span: Span,
1394 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1395 match self.expressions[value] {
1396 Expression::Literal(literal) => {
1397 let scalar = literal.scalar();
1398 let ty = self.types.insert(
1399 Type {
1400 name: None,
1401 inner: TypeInner::Vector { size, scalar },
1402 },
1403 span,
1404 );
1405 let expr = Expression::Compose {
1406 ty,
1407 components: vec![value; size as usize],
1408 };
1409 self.register_evaluated_expr(expr, span)
1410 }
1411 Expression::ZeroValue(ty) => {
1412 let inner = match self.types[ty].inner {
1413 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1414 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1415 };
1416 let res_ty = self.types.insert(Type { name: None, inner }, span);
1417 let expr = Expression::ZeroValue(res_ty);
1418 self.register_evaluated_expr(expr, span)
1419 }
1420 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1421 }
1422 }
1423
/// Constant-evaluate a swizzle of the vector constant `src_constant`.
///
/// `size` is the number of result lanes and `pattern[..size]` selects, for
/// each result lane, which source component to read.
///
/// * `ZeroValue` sources yield a `ZeroValue` of the resized vector type.
/// * `Splat` sources yield a `Splat` of the same value with the new size.
/// * `Compose` sources are flattened and the selected components are
///   gathered into a new `Compose` of the resized vector type.
///
/// Errors: `SwizzleVectorOnly` if the source is not a vector constant of one
/// of those forms; `SwizzleOutOfBounds` if a selected component index is
/// past the end of a flattened `Compose`.
fn swizzle(
    &mut self,
    size: crate::VectorSize,
    span: Span,
    src_constant: Handle<Expression>,
    pattern: [crate::SwizzleComponent; 4],
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    // Build the result type: same scalar as the source vector, new lane count.
    let mut get_dst_ty = |ty| match self.types[ty].inner {
        TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector { size, scalar },
            },
            span,
        )),
        _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
    };

    match self.expressions[src_constant] {
        Expression::ZeroValue(ty) => {
            let dst_ty = get_dst_ty(ty)?;
            let expr = Expression::ZeroValue(dst_ty);
            self.register_evaluated_expr(expr, span)
        }
        // Every lane of a splat holds the same value, so `pattern` is not
        // consulted. NOTE(review): unlike the `Compose` arm, out-of-range
        // pattern components are not rejected here — presumably validated
        // earlier; confirm.
        Expression::Splat { value, .. } => {
            let expr = Expression::Splat { size, value };
            self.register_evaluated_expr(expr, span)
        }
        Expression::Compose { ty, ref components } => {
            let dst_ty = get_dst_ty(ty)?;

            // `src_constant` is only a placeholder filler; the first `len`
            // slots are overwritten with the flattened components.
            let mut flattened = [src_constant; 4];
            let len =
                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                    .zip(flattened.iter_mut())
                    .map(|(component, elt)| *elt = component)
                    .count();
            let flattened = &flattened[..len];

            let swizzled_components = pattern[..size as usize]
                .iter()
                .map(|&sc| {
                    let sc = sc as usize;
                    if let Some(elt) = flattened.get(sc) {
                        Ok(*elt)
                    } else {
                        Err(ConstantEvaluatorError::SwizzleOutOfBounds)
                    }
                })
                .collect::<Result<Vec<Handle<Expression>>, _>>()?;
            let expr = Expression::Compose {
                ty: dst_ty,
                components: swizzled_components,
            };
            self.register_evaluated_expr(expr, span)
        }
        _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
    }
}
1483
1484 fn math(
1485 &mut self,
1486 arg: Handle<Expression>,
1487 arg1: Option<Handle<Expression>>,
1488 arg2: Option<Handle<Expression>>,
1489 arg3: Option<Handle<Expression>>,
1490 fun: crate::MathFunction,
1491 span: Span,
1492 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1493 let expected = fun.argument_count();
1494 let given = Some(arg)
1495 .into_iter()
1496 .chain(arg1)
1497 .chain(arg2)
1498 .chain(arg3)
1499 .count();
1500 if expected != given {
1501 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1502 fun, expected, given,
1503 ));
1504 }
1505
1506 match fun {
1508 crate::MathFunction::Abs => {
1510 component_wise_scalar(self, span, [arg], |args| match args {
1511 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1512 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1513 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1514 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1515 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1516 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1518 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1519 })
1520 }
1521 crate::MathFunction::Min => {
1522 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1523 Ok([e1.min(e2)])
1524 })
1525 }
1526 crate::MathFunction::Max => {
1527 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1528 Ok([e1.max(e2)])
1529 })
1530 }
1531 crate::MathFunction::Clamp => {
1532 component_wise_scalar!(
1533 self,
1534 span,
1535 [arg, arg1.unwrap(), arg2.unwrap()],
1536 |e, low, high| {
1537 if low > high {
1538 Err(ConstantEvaluatorError::InvalidClamp)
1539 } else {
1540 Ok([e.clamp(low, high)])
1541 }
1542 }
1543 )
1544 }
1545 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1546 Float::F16([e]) => Ok(Float::F16(
1547 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1548 )),
1549 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1550 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1551 }),
1552
1553 crate::MathFunction::Cos => {
1555 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1556 }
1557 crate::MathFunction::Cosh => {
1558 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1559 }
1560 crate::MathFunction::Sin => {
1561 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1562 }
1563 crate::MathFunction::Sinh => {
1564 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1565 }
1566 crate::MathFunction::Tan => {
1567 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1568 }
1569 crate::MathFunction::Tanh => {
1570 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1571 }
1572 crate::MathFunction::Acos => {
1573 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1574 }
1575 crate::MathFunction::Asin => {
1576 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1577 }
1578 crate::MathFunction::Atan => {
1579 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1580 }
1581 crate::MathFunction::Atan2 => {
1582 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1583 Ok([y.atan2(x)])
1584 })
1585 }
1586 crate::MathFunction::Asinh => {
1587 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1588 }
1589 crate::MathFunction::Acosh => {
1590 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1591 }
1592 crate::MathFunction::Atanh => {
1593 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1594 }
1595 crate::MathFunction::Radians => {
1596 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1597 }
1598 crate::MathFunction::Degrees => {
1599 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1600 }
1601
1602 crate::MathFunction::Ceil => {
1604 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1605 }
1606 crate::MathFunction::Floor => {
1607 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1608 }
1609 crate::MathFunction::Round => {
1610 component_wise_float(self, span, [arg], |e| match e {
1611 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1612 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1613 Float::F16([e]) => {
1614 fn round_ties_even(x: f64) -> f64 {
1622 let i = x as i64;
1623 let f = (x - i as f64).abs();
1624 if f == 0.5 {
1625 if i & 1 == 1 {
1626 (x.abs() + 0.5).copysign(x)
1628 } else {
1629 (x.abs() - 0.5).copysign(x)
1630 }
1631 } else {
1632 x.round()
1633 }
1634 }
1635
1636 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1637 }
1638 })
1639 }
1640 crate::MathFunction::Fract => {
1641 component_wise_float!(self, span, [arg], |e| {
1642 Ok([e - e.floor()])
1645 })
1646 }
1647 crate::MathFunction::Trunc => {
1648 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1649 }
1650
1651 crate::MathFunction::Exp => {
1653 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1654 }
1655 crate::MathFunction::Exp2 => {
1656 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1657 }
1658 crate::MathFunction::Log => {
1659 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1660 }
1661 crate::MathFunction::Log2 => {
1662 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1663 }
1664 crate::MathFunction::Pow => {
1665 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1666 Ok([e1.powf(e2)])
1667 })
1668 }
1669
1670 crate::MathFunction::Sign => {
1672 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1673 }
1674 crate::MathFunction::Fma => {
1675 component_wise_float!(
1676 self,
1677 span,
1678 [arg, arg1.unwrap(), arg2.unwrap()],
1679 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1680 )
1681 }
1682 crate::MathFunction::Step => {
1683 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1684 Float::Abstract([edge, x]) => {
1685 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1686 }
1687 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1688 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1689 f16::one()
1690 } else {
1691 f16::zero()
1692 }])),
1693 })
1694 }
1695 crate::MathFunction::Sqrt => {
1696 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1697 }
1698 crate::MathFunction::InverseSqrt => {
1699 component_wise_float(self, span, [arg], |e| match e {
1700 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1701 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1702 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1703 })
1704 }
1705
1706 crate::MathFunction::CountTrailingZeros => {
1708 component_wise_concrete_int!(self, span, [arg], |e| {
1709 #[allow(clippy::useless_conversion)]
1710 Ok([e
1711 .trailing_zeros()
1712 .try_into()
1713 .expect("bit count overflowed 32 bits, somehow!?")])
1714 })
1715 }
1716 crate::MathFunction::CountLeadingZeros => {
1717 component_wise_concrete_int!(self, span, [arg], |e| {
1718 #[allow(clippy::useless_conversion)]
1719 Ok([e
1720 .leading_zeros()
1721 .try_into()
1722 .expect("bit count overflowed 32 bits, somehow!?")])
1723 })
1724 }
1725 crate::MathFunction::CountOneBits => {
1726 component_wise_concrete_int!(self, span, [arg], |e| {
1727 #[allow(clippy::useless_conversion)]
1728 Ok([e
1729 .count_ones()
1730 .try_into()
1731 .expect("bit count overflowed 32 bits, somehow!?")])
1732 })
1733 }
1734 crate::MathFunction::ReverseBits => {
1735 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1736 }
1737 crate::MathFunction::FirstTrailingBit => {
1738 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1739 }
1740 crate::MathFunction::FirstLeadingBit => {
1741 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1742 }
1743
1744 crate::MathFunction::Dot4I8Packed => {
1746 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1747 }
1748 crate::MathFunction::Dot4U8Packed => {
1749 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1750 }
1751 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1752 crate::MathFunction::Dot => {
1753 let e1 = self.extract_vec(arg, false)?;
1755 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1756 if e1.len() != e2.len() {
1757 return Err(ConstantEvaluatorError::InvalidMathArg);
1758 }
1759
1760 fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1761 where
1762 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1763 {
1764 a.iter()
1765 .zip(b.iter())
1766 .map(|(&aa, bb)| aa.checked_mul(bb))
1767 .try_fold(P::zero(), |acc, x| {
1768 if let Some(x) = x {
1769 acc.checked_add(&x)
1770 } else {
1771 None
1772 }
1773 })
1774 .ok_or(ConstantEvaluatorError::Overflow(
1775 "in dot built-in".to_string(),
1776 ))
1777 }
1778
1779 let result = match_literal_vector!(match (e1, e2) => Literal {
1780 Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1781 Integer => |e1, e2| { int_dot(e1, e2)? },
1782 })?;
1783 self.register_evaluated_expr(Expression::Literal(result), span)
1784 }
1785 crate::MathFunction::Length => {
1786 let e1 = self.extract_vec(arg, true)?;
1788
1789 fn float_length<F>(e: &[F]) -> F
1790 where
1791 F: core::ops::Mul<F>,
1792 F: num_traits::Float + iter::Sum,
1793 {
1794 e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1795 }
1796
1797 let result = match_literal_vector!(match e1 => Literal {
1798 Float => |e1| { float_length(e1) },
1799 })?;
1800 self.register_evaluated_expr(Expression::Literal(result), span)
1801 }
1802 crate::MathFunction::Distance => {
1803 let e1 = self.extract_vec(arg, true)?;
1805 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1806 if e1.len() != e2.len() {
1807 return Err(ConstantEvaluatorError::InvalidMathArg);
1808 }
1809
1810 fn float_distance<F>(a: &[F], b: &[F]) -> F
1811 where
1812 F: core::ops::Mul<F>,
1813 F: num_traits::Float + iter::Sum + core::ops::Sub,
1814 {
1815 a.iter()
1816 .zip(b.iter())
1817 .map(|(&aa, &bb)| aa - bb)
1818 .map(|ei| ei * ei)
1819 .sum::<F>()
1820 .sqrt()
1821 }
1822 let result = match_literal_vector!(match (e1, e2) => Literal {
1823 Float => |e1, e2| { float_distance(e1, e2) },
1824 })?;
1825 self.register_evaluated_expr(Expression::Literal(result), span)
1826 }
1827 crate::MathFunction::Normalize => {
1828 let e1 = self.extract_vec(arg, true)?;
1830
1831 fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1832 where
1833 F: core::ops::Mul<F>,
1834 F: num_traits::Float + iter::Sum,
1835 {
1836 let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1837 e.iter().map(|&ei| ei / len).collect()
1838 }
1839
1840 let result = match_literal_vector!(match e1 => LiteralVector {
1841 Float => |e1| { float_normalize(e1) },
1842 })?;
1843 result.register_as_evaluated_expr(self, span)
1844 }
1845
1846 crate::MathFunction::Modf
1848 | crate::MathFunction::Frexp
1849 | crate::MathFunction::Ldexp
1850 | crate::MathFunction::Outer
1851 | crate::MathFunction::FaceForward
1852 | crate::MathFunction::Reflect
1853 | crate::MathFunction::Refract
1854 | crate::MathFunction::Mix
1855 | crate::MathFunction::SmoothStep
1856 | crate::MathFunction::Inverse
1857 | crate::MathFunction::Transpose
1858 | crate::MathFunction::Determinant
1859 | crate::MathFunction::QuantizeToF16
1860 | crate::MathFunction::ExtractBits
1861 | crate::MathFunction::InsertBits
1862 | crate::MathFunction::Pack4x8snorm
1863 | crate::MathFunction::Pack4x8unorm
1864 | crate::MathFunction::Pack2x16snorm
1865 | crate::MathFunction::Pack2x16unorm
1866 | crate::MathFunction::Pack2x16float
1867 | crate::MathFunction::Pack4xI8
1868 | crate::MathFunction::Pack4xU8
1869 | crate::MathFunction::Pack4xI8Clamp
1870 | crate::MathFunction::Pack4xU8Clamp
1871 | crate::MathFunction::Unpack4x8snorm
1872 | crate::MathFunction::Unpack4x8unorm
1873 | crate::MathFunction::Unpack2x16snorm
1874 | crate::MathFunction::Unpack2x16unorm
1875 | crate::MathFunction::Unpack2x16float
1876 | crate::MathFunction::Unpack4xI8
1877 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1878 format!("{fun:?} built-in function"),
1879 )),
1880 }
1881 }
1882
1883 fn packed_dot_product(
1885 &mut self,
1886 a: Handle<Expression>,
1887 b: Handle<Expression>,
1888 span: Span,
1889 signed: bool,
1890 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1891 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1892 return Err(ConstantEvaluatorError::InvalidMathArg);
1893 };
1894 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1895 return Err(ConstantEvaluatorError::InvalidMathArg);
1896 };
1897
1898 let result = if signed {
1899 Literal::I32(
1900 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1901 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1902 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1903 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1904 )
1905 } else {
1906 Literal::U32(
1907 (a & 0xFF) * (b & 0xFF)
1908 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1909 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1910 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1911 )
1912 };
1913
1914 self.register_evaluated_expr(Expression::Literal(result), span)
1915 }
1916
1917 fn cross_product(
1919 &mut self,
1920 a: Handle<Expression>,
1921 b: Handle<Expression>,
1922 span: Span,
1923 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1924 use Literal as Li;
1925
1926 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1927 let (b, _) = self.extract_vec_with_size::<3>(b)?;
1928
1929 let product = match (a, b) {
1930 (
1931 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1932 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1933 ) => {
1934 let p = cross_product(
1939 [a0 as f64, a1 as f64, a2 as f64],
1940 [b0 as f64, b1 as f64, b2 as f64],
1941 );
1942 [
1943 Li::AbstractFloat(p[0]),
1944 Li::AbstractFloat(p[1]),
1945 Li::AbstractFloat(p[2]),
1946 ]
1947 }
1948 (
1949 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1950 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1951 ) => {
1952 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1953 [
1954 Li::AbstractFloat(p[0]),
1955 Li::AbstractFloat(p[1]),
1956 Li::AbstractFloat(p[2]),
1957 ]
1958 }
1959 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1960 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1961 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1962 }
1963 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1964 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1965 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1966 }
1967 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1968 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1969 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1970 }
1971 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1972 };
1973
1974 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1975 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1976 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1977
1978 self.register_evaluated_expr(
1979 Expression::Compose {
1980 ty,
1981 components: vec![p0, p1, p2],
1982 },
1983 span,
1984 )
1985 }
1986
1987 fn extract_vec_with_size<const N: usize>(
1995 &mut self,
1996 expr: Handle<Expression>,
1997 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1998 let span = self.expressions.get_span(expr);
1999 let expr = self.eval_zero_value_and_splat(expr, span)?;
2000 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2001 return Err(ConstantEvaluatorError::InvalidMathArg);
2002 };
2003
2004 let mut value = [Literal::Bool(false); N];
2005 for (component, elt) in
2006 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2007 .zip(value.iter_mut())
2008 {
2009 let Expression::Literal(literal) = self.expressions[component] else {
2010 return Err(ConstantEvaluatorError::InvalidMathArg);
2011 };
2012 *elt = literal;
2013 }
2014
2015 Ok((value, ty))
2016 }
2017
2018 fn extract_vec(
2026 &mut self,
2027 expr: Handle<Expression>,
2028 allow_single: bool,
2029 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2030 let span = self.expressions.get_span(expr);
2031 let expr = self.eval_zero_value_and_splat(expr, span)?;
2032
2033 match self.expressions[expr] {
2034 Expression::Literal(literal) if allow_single => {
2035 Ok(LiteralVector::from_literal(literal))
2036 }
2037 Expression::Compose { ty, ref components } => {
2038 let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2039 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2040 .map(|expr| match self.expressions[expr] {
2041 Expression::Literal(l) => Ok(l),
2042 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2043 })
2044 .collect::<Result<_, ConstantEvaluatorError>>()?;
2045 LiteralVector::from_literal_vec(components)
2046 }
2047 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2048 }
2049 }
2050
2051 fn array_length(
2052 &mut self,
2053 array: Handle<Expression>,
2054 span: Span,
2055 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2056 match self.expressions[array] {
2057 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2058 match self.types[ty].inner {
2059 TypeInner::Array { size, .. } => match size {
2060 ArraySize::Constant(len) => {
2061 let expr = Expression::Literal(Literal::U32(len.get()));
2062 self.register_evaluated_expr(expr, span)
2063 }
2064 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2065 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2066 },
2067 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2068 }
2069 }
2070 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2071 }
2072 }
2073
2074 fn access(
2075 &mut self,
2076 base: Handle<Expression>,
2077 index: usize,
2078 span: Span,
2079 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2080 match self.expressions[base] {
2081 Expression::ZeroValue(ty) => {
2082 let ty_inner = &self.types[ty].inner;
2083 let components = ty_inner
2084 .components()
2085 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2086
2087 if index >= components as usize {
2088 Err(ConstantEvaluatorError::InvalidAccessBase)
2089 } else {
2090 let ty_res = ty_inner
2091 .component_type(index)
2092 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2093 let ty = match ty_res {
2094 crate::proc::TypeResolution::Handle(ty) => ty,
2095 crate::proc::TypeResolution::Value(inner) => {
2096 self.types.insert(Type { name: None, inner }, span)
2097 }
2098 };
2099 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2100 }
2101 }
2102 Expression::Splat { size, value } => {
2103 if index >= size as usize {
2104 Err(ConstantEvaluatorError::InvalidAccessBase)
2105 } else {
2106 Ok(value)
2107 }
2108 }
2109 Expression::Compose { ty, ref components } => {
2110 let _ = self.types[ty]
2111 .inner
2112 .components()
2113 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2114
2115 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2116 .nth(index)
2117 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2118 }
2119 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2120 }
2121 }
2122
2123 fn eval_zero_value_and_splat(
2130 &mut self,
2131 mut expr: Handle<Expression>,
2132 span: Span,
2133 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2134 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2137 let components = components
2138 .clone()
2139 .iter()
2140 .map(|component| self.eval_zero_value_and_splat(*component, span))
2141 .collect::<Result<_, _>>()?;
2142 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2143 }
2144
2145 if let Expression::Splat { size, value } = self.expressions[expr] {
2149 expr = self.splat(value, size, span)?;
2150 }
2151 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2152 expr = self.eval_zero_value_impl(ty, span)?;
2153 }
2154 Ok(expr)
2155 }
2156
2157 fn eval_zero_value(
2163 &mut self,
2164 expr: Handle<Expression>,
2165 span: Span,
2166 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2167 match self.expressions[expr] {
2168 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2169 _ => Ok(expr),
2170 }
2171 }
2172
2173 fn eval_zero_value_impl(
2179 &mut self,
2180 ty: Handle<Type>,
2181 span: Span,
2182 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2183 match self.types[ty].inner {
2184 TypeInner::Scalar(scalar) => {
2185 let expr = Expression::Literal(
2186 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2187 );
2188 self.register_evaluated_expr(expr, span)
2189 }
2190 TypeInner::Vector { size, scalar } => {
2191 let scalar_ty = self.types.insert(
2192 Type {
2193 name: None,
2194 inner: TypeInner::Scalar(scalar),
2195 },
2196 span,
2197 );
2198 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2199 let expr = Expression::Compose {
2200 ty,
2201 components: vec![el; size as usize],
2202 };
2203 self.register_evaluated_expr(expr, span)
2204 }
2205 TypeInner::Matrix {
2206 columns,
2207 rows,
2208 scalar,
2209 } => {
2210 let vec_ty = self.types.insert(
2211 Type {
2212 name: None,
2213 inner: TypeInner::Vector { size: rows, scalar },
2214 },
2215 span,
2216 );
2217 let el = self.eval_zero_value_impl(vec_ty, span)?;
2218 let expr = Expression::Compose {
2219 ty,
2220 components: vec![el; columns as usize],
2221 };
2222 self.register_evaluated_expr(expr, span)
2223 }
2224 TypeInner::Array {
2225 base,
2226 size: ArraySize::Constant(size),
2227 ..
2228 } => {
2229 let el = self.eval_zero_value_impl(base, span)?;
2230 let expr = Expression::Compose {
2231 ty,
2232 components: vec![el; size.get() as usize],
2233 };
2234 self.register_evaluated_expr(expr, span)
2235 }
2236 TypeInner::Struct { ref members, .. } => {
2237 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2238 let mut components = Vec::with_capacity(members.len());
2239 for ty in types {
2240 components.push(self.eval_zero_value_impl(ty, span)?);
2241 }
2242 let expr = Expression::Compose { ty, components };
2243 self.register_evaluated_expr(expr, span)
2244 }
2245 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2246 }
2247 }
2248
2249 pub fn cast(
2253 &mut self,
2254 expr: Handle<Expression>,
2255 target: crate::Scalar,
2256 span: Span,
2257 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2258 use crate::Scalar as Sc;
2259
2260 let expr = self.eval_zero_value(expr, span)?;
2261
2262 let make_error = || -> Result<_, ConstantEvaluatorError> {
2263 let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2264
2265 #[cfg(feature = "wgsl-in")]
2266 let to = target.to_wgsl_for_diagnostics();
2267
2268 #[cfg(not(feature = "wgsl-in"))]
2269 let to = format!("{target:?}");
2270
2271 Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2272 };
2273
2274 use crate::proc::type_methods::IntFloatLimits;
2275
2276 let expr = match self.expressions[expr] {
2277 Expression::Literal(literal) => {
2278 let literal = match target {
2279 Sc::I32 => Literal::I32(match literal {
2280 Literal::I32(v) => v,
2281 Literal::U32(v) => v as i32,
2282 Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2283 Literal::F16(v) => f16::to_i32(&v).unwrap(), Literal::Bool(v) => v as i32,
2285 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2286 return make_error();
2287 }
2288 Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2289 Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2290 }),
2291 Sc::U32 => Literal::U32(match literal {
2292 Literal::I32(v) => v as u32,
2293 Literal::U32(v) => v,
2294 Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2295 Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2297 Literal::Bool(v) => v as u32,
2298 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2299 return make_error();
2300 }
2301 Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2302 Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2303 }),
2304 Sc::I64 => Literal::I64(match literal {
2305 Literal::I32(v) => v as i64,
2306 Literal::U32(v) => v as i64,
2307 Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2308 Literal::Bool(v) => v as i64,
2309 Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2310 Literal::I64(v) => v,
2311 Literal::U64(v) => v as i64,
2312 Literal::F16(v) => f16::to_i64(&v).unwrap(), Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2314 Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2315 }),
2316 Sc::U64 => Literal::U64(match literal {
2317 Literal::I32(v) => v as u64,
2318 Literal::U32(v) => v as u64,
2319 Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2320 Literal::Bool(v) => v as u64,
2321 Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2322 Literal::I64(v) => v as u64,
2323 Literal::U64(v) => v,
2324 Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2326 Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2327 Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2328 }),
2329 Sc::F16 => Literal::F16(match literal {
2330 Literal::F16(v) => v,
2331 Literal::F32(v) => f16::from_f32(v),
2332 Literal::F64(v) => f16::from_f64(v),
2333 Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2334 Literal::I64(v) => f16::from_i64(v).unwrap(),
2335 Literal::U64(v) => f16::from_u64(v).unwrap(),
2336 Literal::I32(v) => f16::from_i32(v).unwrap(),
2337 Literal::U32(v) => f16::from_u32(v).unwrap(),
2338 Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2339 Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2340 }),
2341 Sc::F32 => Literal::F32(match literal {
2342 Literal::I32(v) => v as f32,
2343 Literal::U32(v) => v as f32,
2344 Literal::F32(v) => v,
2345 Literal::Bool(v) => v as u32 as f32,
2346 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2347 return make_error();
2348 }
2349 Literal::F16(v) => f16::to_f32(v),
2350 Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2351 Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2352 }),
2353 Sc::F64 => Literal::F64(match literal {
2354 Literal::I32(v) => v as f64,
2355 Literal::U32(v) => v as f64,
2356 Literal::F16(v) => f16::to_f64(v),
2357 Literal::F32(v) => v as f64,
2358 Literal::F64(v) => v,
2359 Literal::Bool(v) => v as u32 as f64,
2360 Literal::I64(_) | Literal::U64(_) => return make_error(),
2361 Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2362 Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2363 }),
2364 Sc::BOOL => Literal::Bool(match literal {
2365 Literal::I32(v) => v != 0,
2366 Literal::U32(v) => v != 0,
2367 Literal::F32(v) => v != 0.0,
2368 Literal::F16(v) => v != f16::zero(),
2369 Literal::Bool(v) => v,
2370 Literal::AbstractInt(v) => v != 0,
2371 Literal::AbstractFloat(v) => v != 0.0,
2372 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2373 return make_error();
2374 }
2375 }),
2376 Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2377 Literal::AbstractInt(v) => {
2378 v as f64
2383 }
2384 Literal::AbstractFloat(v) => v,
2385 _ => return make_error(),
2386 }),
2387 Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2388 Literal::AbstractInt(v) => v,
2389 _ => return make_error(),
2390 }),
2391 _ => {
2392 log::debug!("Constant evaluator refused to convert value to {target:?}");
2393 return make_error();
2394 }
2395 };
2396 Expression::Literal(literal)
2397 }
2398 Expression::Compose {
2399 ty,
2400 components: ref src_components,
2401 } => {
2402 let ty_inner = match self.types[ty].inner {
2403 TypeInner::Vector { size, .. } => TypeInner::Vector {
2404 size,
2405 scalar: target,
2406 },
2407 TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2408 columns,
2409 rows,
2410 scalar: target,
2411 },
2412 _ => return make_error(),
2413 };
2414
2415 let mut components = src_components.clone();
2416 for component in &mut components {
2417 *component = self.cast(*component, target, span)?;
2418 }
2419
2420 let ty = self.types.insert(
2421 Type {
2422 name: None,
2423 inner: ty_inner,
2424 },
2425 span,
2426 );
2427
2428 Expression::Compose { ty, components }
2429 }
2430 Expression::Splat { size, value } => {
2431 let value_span = self.expressions.get_span(value);
2432 let cast_value = self.cast(value, target, value_span)?;
2433 Expression::Splat {
2434 size,
2435 value: cast_value,
2436 }
2437 }
2438 _ => return make_error(),
2439 };
2440
2441 self.register_evaluated_expr(expr, span)
2442 }
2443
    /// Cast `expr` to the scalar type `target`, descending into array `Compose`s.
    ///
    /// `expr` is first resolved with `check_and_get`. If it is not a `Compose`
    /// whose type is an array, this simply defers to [`Self::cast`]. Otherwise
    /// every element is converted by a recursive call (so arrays of arrays are
    /// handled too), a new array type is interned whose base is the type of the
    /// first converted element and whose stride comes from the layouter, and a
    /// new `Compose` of that array type is registered.
    pub fn cast_array(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let expr = self.check_and_get(expr)?;

        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
            return self.cast(expr, target, span);
        };

        let TypeInner::Array {
            base: _,
            size,
            stride: _,
        } = self.types[ty].inner
        else {
            return self.cast(expr, target, span);
        };

        let mut components = components.clone();
        for component in &mut components {
            *component = self.cast_array(*component, target, span)?;
        }

        // The new element type is taken from the first converted component.
        // NOTE(review): the `unwrap` assumes the array `Compose` has at least
        // one component — confirm callers never produce an empty one.
        let first = components.first().unwrap();
        let new_base = match self.resolve_type(*first)? {
            crate::proc::TypeResolution::Handle(ty) => ty,
            crate::proc::TypeResolution::Value(inner) => {
                self.types.insert(Type { name: None, inner }, span)
            }
        };
        // Bring the layouter up to date so it covers `new_base` (which may have
        // just been inserted) before its stride is looked up. The layouter is
        // moved out and back, presumably so `update` can take `self.to_ctx()`
        // without conflicting with a mutable borrow of `self.layouter`.
        // Panics if the layout update fails.
        let mut layouter = core::mem::take(self.layouter);
        layouter.update(self.to_ctx()).unwrap();
        *self.layouter = layouter;

        let new_base_stride = self.layouter[new_base].to_stride();
        let new_array_ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Array {
                    base: new_base,
                    size,
                    stride: new_base_stride,
                },
            },
            span,
        );

        let compose = Expression::Compose {
            ty: new_array_ty,
            components,
        };
        self.register_evaluated_expr(compose, span)
    }
2512
2513 fn unary_op(
2514 &mut self,
2515 op: UnaryOperator,
2516 expr: Handle<Expression>,
2517 span: Span,
2518 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2519 let expr = self.eval_zero_value_and_splat(expr, span)?;
2520
2521 let expr = match self.expressions[expr] {
2522 Expression::Literal(value) => Expression::Literal(match op {
2523 UnaryOperator::Negate => match value {
2524 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2525 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2526 Literal::F32(v) => Literal::F32(-v),
2527 Literal::F16(v) => Literal::F16(-v),
2528 Literal::F64(v) => Literal::F64(-v),
2529 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2530 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2531 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2532 },
2533 UnaryOperator::LogicalNot => match value {
2534 Literal::Bool(v) => Literal::Bool(!v),
2535 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2536 },
2537 UnaryOperator::BitwiseNot => match value {
2538 Literal::I32(v) => Literal::I32(!v),
2539 Literal::I64(v) => Literal::I64(!v),
2540 Literal::U32(v) => Literal::U32(!v),
2541 Literal::U64(v) => Literal::U64(!v),
2542 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2543 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2544 },
2545 }),
2546 Expression::Compose {
2547 ty,
2548 components: ref src_components,
2549 } => {
2550 match self.types[ty].inner {
2551 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2552 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2553 }
2554
2555 let mut components = src_components.clone();
2556 for component in &mut components {
2557 *component = self.unary_op(op, *component, span)?;
2558 }
2559
2560 Expression::Compose { ty, components }
2561 }
2562 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2563 };
2564
2565 self.register_evaluated_expr(expr, span)
2566 }
2567
2568 fn binary_op(
2569 &mut self,
2570 op: BinaryOperator,
2571 left: Handle<Expression>,
2572 right: Handle<Expression>,
2573 span: Span,
2574 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2575 let left = self.eval_zero_value_and_splat(left, span)?;
2576 let right = self.eval_zero_value_and_splat(right, span)?;
2577
2578 let expr = match (&self.expressions[left], &self.expressions[right]) {
2583 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2584 let literal = match op {
2585 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2586 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2587 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2588 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2589 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2590 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2591
2592 _ => match (left_value, right_value) {
2593 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2594 BinaryOperator::Add => a.wrapping_add(b),
2595 BinaryOperator::Subtract => a.wrapping_sub(b),
2596 BinaryOperator::Multiply => a.wrapping_mul(b),
2597 BinaryOperator::Divide => {
2598 if b == 0 {
2599 return Err(ConstantEvaluatorError::DivisionByZero);
2600 } else {
2601 a.wrapping_div(b)
2602 }
2603 }
2604 BinaryOperator::Modulo => {
2605 if b == 0 {
2606 return Err(ConstantEvaluatorError::RemainderByZero);
2607 } else {
2608 a.wrapping_rem(b)
2609 }
2610 }
2611 BinaryOperator::And => a & b,
2612 BinaryOperator::ExclusiveOr => a ^ b,
2613 BinaryOperator::InclusiveOr => a | b,
2614 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2615 }),
2616 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2617 BinaryOperator::ShiftLeft => {
2618 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2619 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2620 }
2621 a.checked_shl(b)
2622 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2623 }
2624 BinaryOperator::ShiftRight => a
2625 .checked_shr(b)
2626 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2627 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2628 }),
2629 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2630 BinaryOperator::Add => a.wrapping_add(b),
2631 BinaryOperator::Subtract => a.wrapping_sub(b),
2632 BinaryOperator::Multiply => a.wrapping_mul(b),
2633 BinaryOperator::Divide => a
2634 .checked_div(b)
2635 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2636 BinaryOperator::Modulo => a
2637 .checked_rem(b)
2638 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2639 BinaryOperator::And => a & b,
2640 BinaryOperator::ExclusiveOr => a ^ b,
2641 BinaryOperator::InclusiveOr => a | b,
2642 BinaryOperator::ShiftLeft => a
2643 .checked_mul(
2644 1u32.checked_shl(b)
2645 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2646 )
2647 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2648 BinaryOperator::ShiftRight => a
2649 .checked_shr(b)
2650 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2651 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2652 }),
2653 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2654 BinaryOperator::Add => a + b,
2655 BinaryOperator::Subtract => a - b,
2656 BinaryOperator::Multiply => a * b,
2657 BinaryOperator::Divide => a / b,
2658 BinaryOperator::Modulo => a % b,
2659 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2660 }),
2661 (Literal::AbstractInt(a), Literal::U32(b)) => {
2662 Literal::AbstractInt(match op {
2663 BinaryOperator::ShiftLeft => {
2664 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2665 return Err(ConstantEvaluatorError::Overflow(
2666 "<<".to_string(),
2667 ));
2668 }
2669 a.checked_shl(b).unwrap_or(0)
2670 }
2671 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2672 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2673 })
2674 }
2675 (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2676 BinaryOperator::Add => a + b,
2677 BinaryOperator::Subtract => a - b,
2678 BinaryOperator::Multiply => a * b,
2679 BinaryOperator::Divide => a / b,
2680 BinaryOperator::Modulo => a % b,
2681 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2682 }),
2683 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2684 Literal::AbstractInt(match op {
2685 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2686 ConstantEvaluatorError::Overflow("addition".into())
2687 })?,
2688 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2689 ConstantEvaluatorError::Overflow("subtraction".into())
2690 })?,
2691 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2692 ConstantEvaluatorError::Overflow("multiplication".into())
2693 })?,
2694 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2695 if b == 0 {
2696 ConstantEvaluatorError::DivisionByZero
2697 } else {
2698 ConstantEvaluatorError::Overflow("division".into())
2699 }
2700 })?,
2701 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2702 if b == 0 {
2703 ConstantEvaluatorError::RemainderByZero
2704 } else {
2705 ConstantEvaluatorError::Overflow("remainder".into())
2706 }
2707 })?,
2708 BinaryOperator::And => a & b,
2709 BinaryOperator::ExclusiveOr => a ^ b,
2710 BinaryOperator::InclusiveOr => a | b,
2711 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2712 })
2713 }
2714 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2715 Literal::AbstractFloat(match op {
2716 BinaryOperator::Add => a + b,
2717 BinaryOperator::Subtract => a - b,
2718 BinaryOperator::Multiply => a * b,
2719 BinaryOperator::Divide => a / b,
2720 BinaryOperator::Modulo => a % b,
2721 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2722 })
2723 }
2724 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2725 BinaryOperator::LogicalAnd => a && b,
2726 BinaryOperator::LogicalOr => a || b,
2727 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2728 }),
2729 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2730 },
2731 };
2732 Expression::Literal(literal)
2733 }
2734 (
2735 &Expression::Compose {
2736 components: ref src_components,
2737 ty,
2738 },
2739 &Expression::Literal(_),
2740 ) => {
2741 let mut components = src_components.clone();
2742 for component in &mut components {
2743 *component = self.binary_op(op, *component, right, span)?;
2744 }
2745 Expression::Compose { ty, components }
2746 }
2747 (
2748 &Expression::Literal(_),
2749 &Expression::Compose {
2750 components: ref src_components,
2751 ty,
2752 },
2753 ) => {
2754 let mut components = src_components.clone();
2755 for component in &mut components {
2756 *component = self.binary_op(op, left, *component, span)?;
2757 }
2758 Expression::Compose { ty, components }
2759 }
2760 (
2761 &Expression::Compose {
2762 components: ref left_components,
2763 ty: left_ty,
2764 },
2765 &Expression::Compose {
2766 components: ref right_components,
2767 ty: right_ty,
2768 },
2769 ) => {
2770 let left_flattened = crate::proc::flatten_compose(
2774 left_ty,
2775 left_components,
2776 self.expressions,
2777 self.types,
2778 );
2779 let right_flattened = crate::proc::flatten_compose(
2780 right_ty,
2781 right_components,
2782 self.expressions,
2783 self.types,
2784 );
2785
2786 let mut flattened = Vec::with_capacity(left_components.len());
2789 flattened.extend(left_flattened.zip(right_flattened));
2790
2791 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2792 (
2793 &TypeInner::Vector {
2794 size: left_size, ..
2795 },
2796 &TypeInner::Vector {
2797 size: right_size, ..
2798 },
2799 ) if left_size == right_size => {
2800 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2801 }
2802 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2803 }
2804 }
2805 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2806 };
2807
2808 self.register_evaluated_expr(expr, span)
2809 }
2810
2811 fn binary_op_vector(
2812 &mut self,
2813 op: BinaryOperator,
2814 size: crate::VectorSize,
2815 components: &[(Handle<Expression>, Handle<Expression>)],
2816 left_ty: Handle<Type>,
2817 span: Span,
2818 ) -> Result<Expression, ConstantEvaluatorError> {
2819 let ty = match op {
2820 BinaryOperator::Equal
2822 | BinaryOperator::NotEqual
2823 | BinaryOperator::Less
2824 | BinaryOperator::LessEqual
2825 | BinaryOperator::Greater
2826 | BinaryOperator::GreaterEqual => self.types.insert(
2827 Type {
2828 name: None,
2829 inner: TypeInner::Vector {
2830 size,
2831 scalar: crate::Scalar::BOOL,
2832 },
2833 },
2834 span,
2835 ),
2836
2837 BinaryOperator::Add
2840 | BinaryOperator::Subtract
2841 | BinaryOperator::Multiply
2842 | BinaryOperator::Divide
2843 | BinaryOperator::Modulo
2844 | BinaryOperator::And
2845 | BinaryOperator::ExclusiveOr
2846 | BinaryOperator::InclusiveOr
2847 | BinaryOperator::ShiftLeft
2848 | BinaryOperator::ShiftRight => left_ty,
2849
2850 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2851 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2853 }
2854 };
2855
2856 let components = components
2857 .iter()
2858 .map(|&(left, right)| self.binary_op(op, left, right, span))
2859 .collect::<Result<Vec<_>, _>>()?;
2860
2861 Ok(Expression::Compose { ty, components })
2862 }
2863
2864 fn relational(
2865 &mut self,
2866 fun: RelationalFunction,
2867 arg: Handle<Expression>,
2868 span: Span,
2869 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2870 let arg = self.eval_zero_value_and_splat(arg, span)?;
2871 match fun {
2872 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2873 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2874 Expression::Compose { ty, ref components }
2875 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2876 {
2877 let components =
2878 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2879 .map(|component| match self.expressions[component] {
2880 Expression::Literal(Literal::Bool(val)) => Ok(val),
2881 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2882 })
2883 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2884 let result = match fun {
2885 RelationalFunction::All => components.iter().all(|c| *c),
2886 RelationalFunction::Any => components.iter().any(|c| *c),
2887 _ => unreachable!(),
2888 };
2889 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2890 }
2891 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2892 },
2893 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2894 "{fun:?} built-in function"
2895 ))),
2896 }
2897 }
2898
2899 fn copy_from(
2907 &mut self,
2908 expr: Handle<Expression>,
2909 expressions: &Arena<Expression>,
2910 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2911 let span = expressions.get_span(expr);
2912 match expressions[expr] {
2913 ref expr @ (Expression::Literal(_)
2914 | Expression::Constant(_)
2915 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2916 Expression::Compose { ty, ref components } => {
2917 let mut components = components.clone();
2918 for component in &mut components {
2919 *component = self.copy_from(*component, expressions)?;
2920 }
2921 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2922 }
2923 Expression::Splat { size, value } => {
2924 let value = self.copy_from(value, expressions)?;
2925 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2926 }
2927 _ => {
2928 log::debug!("copy_from: SubexpressionsAreNotConstant");
2929 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2930 }
2931 }
2932 }
2933
2934 fn vector_compose_flattened_size(
2936 &self,
2937 components: &[Handle<Expression>],
2938 ) -> Result<usize, ConstantEvaluatorError> {
2939 components
2940 .iter()
2941 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2942 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2943 TypeInner::Scalar(_) => 1,
2944 TypeInner::Vector { size, .. } => size as usize,
2948 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2949 };
2950 Ok(acc + size)
2951 })
2952 }
2953
2954 fn register_evaluated_expr(
2955 &mut self,
2956 expr: Expression,
2957 span: Span,
2958 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2959 if let Expression::Literal(literal) = expr {
2964 crate::valid::check_literal_value(literal)?;
2965 }
2966
2967 if let Expression::Compose { ty, ref components } = expr {
2971 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2972 let expected = size as usize;
2973 let actual = self.vector_compose_flattened_size(components)?;
2974 if expected != actual {
2975 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2976 expected,
2977 actual,
2978 });
2979 }
2980 }
2981 }
2982
2983 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2984 }
2985
    /// Append `expr` (with `span`) to `self.expressions` and record
    /// `expr_type` for the new handle in `expression_kind_tracker`.
    ///
    /// In function-local contexts (WGSL runtime, WGSL const with function-local
    /// data, or GLSL runtime), if the emitter is currently running and
    /// `expr.needs_pre_emit()` is true, the emitter is finished before the
    /// append (its output is added to the function's `block`) and restarted
    /// afterwards. Presumably this keeps the new expression out of any `Emit`
    /// range — confirm against the emitter's semantics.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range, append, then reopen it.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
3016
3017 fn resolve_type(
3018 &self,
3019 expr: Handle<Expression>,
3020 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3021 use crate::proc::TypeResolution as Tr;
3022 use crate::Expression as Ex;
3023 let resolution = match self.expressions[expr] {
3024 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3025 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3026 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3027 Ex::Splat { size, value } => {
3028 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3029 return Err(ConstantEvaluatorError::SplatScalarOnly);
3030 };
3031 Tr::Value(TypeInner::Vector { scalar, size })
3032 }
3033 _ => {
3034 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3035 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3036 }
3037 };
3038
3039 Ok(resolution)
3040 }
3041
    /// Constant-evaluate `select(reject, accept, condition)`.
    ///
    /// All three operands are first passed through `eval_zero_value_and_splat`.
    ///
    /// - If `reject` and `accept` are both literals, `condition` must be a
    ///   `bool` literal (else `SelectScalarConditionNotABool`).
    /// - If both are `Compose`s, their vector sizes must match. `condition`
    ///   may then be a single `bool` literal (used for every lane) or a
    ///   `Compose` whose type is a bool vector of the same size. The result is
    ///   a new `Compose` with `reject`'s type.
    /// - Any other combination yields `SelectAcceptRejectTypeMismatch`.
    ///
    /// In every case `accept` is cast to `reject`'s scalar type before the
    /// choice is made. A failing cast is therefore reported even for lanes
    /// where `reject` is selected.
    fn select(
        &mut self,
        reject: Handle<Expression>,
        accept: Handle<Expression>,
        condition: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

        let reject = arg(reject)?;
        let accept = arg(accept)?;
        let condition = arg(condition)?;

        // Choose one lane: `accept` (after casting to `reject_scalar`) when
        // `condition` holds, otherwise `reject` unchanged.
        let select_single_component =
            |this: &mut Self, reject_scalar, reject, accept, condition| {
                let accept = this.cast(accept, reject_scalar, span)?;
                if condition {
                    Ok(accept)
                } else {
                    Ok(reject)
                }
            };

        match (&self.expressions[reject], &self.expressions[accept]) {
            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
                let reject_scalar = reject_lit.scalar();
                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
                else {
                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
                };
                select_single_component(self, reject_scalar, reject, accept, condition)
            }
            (
                &Expression::Compose {
                    ty: reject_ty,
                    components: ref reject_components,
                },
                &Expression::Compose {
                    ty: accept_ty,
                    components: ref accept_components,
                },
            ) => {
                // Vector size and scalar of `ty`.
                // NOTE(review): the two `unwrap`s panic unless
                // `vector_size_and_scalar` reports a vector size for `ty`
                // (e.g. a matrix or array `Compose` would not) — presumably
                // excluded by earlier validation; confirm.
                let ty_deets = |ty| {
                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                    (size.unwrap(), scalar)
                };

                let expected_vec_size = {
                    let [(reject_vec_size, _), (accept_vec_size, _)] =
                        [reject_ty, accept_ty].map(ty_deets);

                    if reject_vec_size != accept_vec_size {
                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                            reject: reject_vec_size,
                            accept: accept_vec_size,
                        });
                    }
                    reject_vec_size
                };

                // One `bool` per lane: a scalar condition is broadcast.
                let condition_components = match self.expressions[condition] {
                    Expression::Literal(Literal::Bool(condition)) => {
                        vec![condition; (expected_vec_size as u8).into()]
                    }
                    Expression::Compose {
                        ty: condition_ty,
                        components: ref condition_components,
                    } => {
                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                        if condition_scalar.kind != ScalarKind::Bool {
                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                        }
                        if condition_vec_size != expected_vec_size {
                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                        }
                        // NOTE(review): assumes every component is a `bool`
                        // literal (not a nested `Compose`/`Splat`) — confirm.
                        condition_components
                            .iter()
                            .copied()
                            .map(|component| match &self.expressions[component] {
                                &Expression::Literal(Literal::Bool(condition)) => condition,
                                _ => unreachable!(),
                            })
                            .collect()
                    }

                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
                };

                // Zip reject/accept/condition lane-by-lane; the result keeps
                // `reject`'s type.
                let evaluated = Expression::Compose {
                    ty: reject_ty,
                    components: reject_components
                        .clone()
                        .into_iter()
                        .zip(accept_components.clone().into_iter())
                        .zip(condition_components.into_iter())
                        .map(|((reject, accept), condition)| {
                            // NOTE(review): assumes every `reject` component is
                            // a literal — confirm.
                            let reject_scalar = match &self.expressions[reject] {
                                &Expression::Literal(lit) => lit.scalar(),
                                _ => unreachable!(),
                            };
                            select_single_component(self, reject_scalar, reject, accept, condition)
                        })
                        .collect::<Result<_, _>>()?,
                };
                self.register_evaluated_expr(evaluated, span)
            }
            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
        }
    }
3151}
3152
3153fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3154 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3158 match e {
3159 idx @ 0..=31 => idx,
3160 32 => u32::MAX,
3161 _ => unreachable!(),
3162 }
3163 };
3164 match concrete_int {
3165 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3166 ConcreteInt::I32([e]) => {
3167 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3168 }
3169 }
3170}
3171
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected index) pairs for signed inputs.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // Every single-bit value reports its own bit position.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    // (input, expected index) pairs for unsigned inputs.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3232
3233fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3234 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3238 match e {
3239 idx @ 0..=31 => 31 - idx,
3240 32 => u32::MAX,
3241 _ => unreachable!(),
3242 }
3243 };
3244 match concrete_int {
3245 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3246 let rtl_bit_index = if e.is_negative() {
3247 e.leading_ones()
3248 } else {
3249 e.leading_zeros()
3250 };
3251 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3252 }]),
3253 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3254 }
3255}
3256
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index) pairs for signed inputs.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; bit 31 is the sign bit, so stop at 30.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // -(2^idx) is ones down to bit idx, so its highest zero bit is idx - 1.
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected index) pairs for unsigned inputs.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3320
3321trait TryFromAbstract<T>: Sized {
3323 fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3345}
3346
3347impl TryFromAbstract<i64> for i32 {
3348 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3349 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3350 value: format!("{value:?}"),
3351 to_type: "i32",
3352 })
3353 }
3354}
3355
3356impl TryFromAbstract<i64> for u32 {
3357 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3358 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3359 value: format!("{value:?}"),
3360 to_type: "u32",
3361 })
3362 }
3363}
3364
3365impl TryFromAbstract<i64> for u64 {
3366 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3367 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3368 value: format!("{value:?}"),
3369 to_type: "u64",
3370 })
3371 }
3372}
3373
3374impl TryFromAbstract<i64> for i64 {
3375 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3376 Ok(value)
3377 }
3378}
3379
3380impl TryFromAbstract<i64> for f32 {
3381 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3382 let f = value as f32;
3383 Ok(f)
3387 }
3388}
3389
3390impl TryFromAbstract<f64> for f32 {
3391 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3392 let f = value as f32;
3393 if f.is_infinite() {
3394 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3395 value: format!("{value:?}"),
3396 to_type: "f32",
3397 });
3398 }
3399 Ok(f)
3400 }
3401}
3402
3403impl TryFromAbstract<i64> for f64 {
3404 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3405 let f = value as f64;
3406 Ok(f)
3410 }
3411}
3412
3413impl TryFromAbstract<f64> for f64 {
3414 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3415 Ok(value)
3416 }
3417}
3418
3419impl TryFromAbstract<f64> for i32 {
3420 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3421 Ok(value as i32)
3434 }
3435}
3436
3437impl TryFromAbstract<f64> for u32 {
3438 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3439 Ok(value as u32)
3442 }
3443}
3444
3445impl TryFromAbstract<f64> for i64 {
3446 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3447 use crate::proc::type_methods::IntFloatLimits;
3450 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3451 }
3452}
3453
3454impl TryFromAbstract<f64> for u64 {
3455 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3456 use crate::proc::type_methods::IntFloatLimits;
3459 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3460 }
3461}
3462
3463impl TryFromAbstract<f64> for f16 {
3464 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3465 let f = f16::from_f64(value);
3466 if f.is_infinite() {
3467 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3468 value: format!("{value:?}"),
3469 to_type: "f16",
3470 });
3471 }
3472 Ok(f)
3473 }
3474}
3475
3476impl TryFromAbstract<i64> for f16 {
3477 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3478 let f = f16::from_i64(value);
3479 if f.is_none() {
3480 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3481 value: format!("{value:?}"),
3482 to_type: "f16",
3483 });
3484 }
3485 Ok(f.unwrap())
3486 }
3487}
3488
/// Computes the 3D cross product `a × b` component-wise.
///
/// Generic over any `Copy` scalar supporting `*` and `-`, so the same code
/// serves every numeric literal kind the evaluator folds.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
3501
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Negation and bitwise-not of an `i32` constant fold to literals, and
    /// bitwise-not of a `vec2<i32>` constant folds to a `Compose` of literals.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Converting the non-zero `i32` constant `4` to `bool` folds to `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexing a constant `mat2x3<f32>` yields its column as a `Compose`,
    /// and indexing that column again yields a single `f32` literal.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        // Column 0 holds 0.0, 1.0, 2.0; column 1 holds 3.0, 4.0, 5.0.
        let vec1_components: Vec<_> = (0..3)
            .map(|i| {
                global_expressions
                    .append(Expression::Literal(Literal::F32(i as f32)), Default::default())
            })
            .collect();
        let vec2_components: Vec<_> = (3..6)
            .map(|i| {
                global_expressions
                    .append(Expression::Literal(Literal::F32(i as f32)), Default::default())
            })
            .collect();

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `vec2<i32>` composed from two references to the same constant can be
    /// negated, producing a `Compose` of exactly two `-4` literals.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        // Check the component count explicitly: `iter().all(..)` alone would
        // accept an empty or truncated component list.
        match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                assert_eq!(ty, vec2_i32_ty);
                assert_eq!(components.len(), 2);
                for &component in components {
                    assert_eq!(
                        global_expressions[component],
                        Expression::Literal(Literal::I32(-4))
                    );
                }
            }
            ref other => panic!("Expected vector, got {other:?}"),
        }
    }

    /// A `Splat` of an `i32` constant can be negated, producing a `Compose`
    /// of exactly two `-4` literals.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        // Check the component count explicitly: `iter().all(..)` alone would
        // accept an empty or truncated component list.
        match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                assert_eq!(ty, vec2_i32_ty);
                assert_eq!(components.len(), 2);
                for &component in components {
                    assert_eq!(
                        global_expressions[component],
                        Expression::Literal(Literal::I32(-4))
                    );
                }
            }
            ref other => panic!("Expected vector, got {other:?}"),
        }
    }

    /// Adding a `Splat` of an `f32` `ZeroValue` to a `Splat` of `5.0` yields a
    /// `Compose` of exactly two `5.0` literals.
    #[test]
    fn splat_of_zero_value() {
        let mut types = UniqueArena::new();
        let constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::F32),
            },
            Default::default(),
        );

        let vec2_f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let five =
            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
        let five_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: five,
            },
            Default::default(),
        );
        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
        let zero_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: zero,
            },
            Default::default(),
        );

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_add = solver
            .try_eval_and_append(
                Expression::Binary {
                    op: crate::BinaryOperator::Add,
                    left: zero_splat,
                    right: five_splat,
                },
                Default::default(),
            )
            .unwrap();

        // Check the component count explicitly: `iter().all(..)` alone would
        // accept an empty or truncated component list.
        match global_expressions[solved_add] {
            Expression::Compose { ty, ref components } => {
                assert_eq!(ty, vec2_f32_ty);
                assert_eq!(components.len(), 2);
                for &component in components {
                    assert_eq!(
                        global_expressions[component],
                        Expression::Literal(Literal::F32(5.0))
                    );
                }
            }
            ref other => panic!("Expected vector, got {other:?}"),
        }
    }
}