1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Runs a set of `macro_rules!` rules with a literal `$` token as input.
///
/// A `macro_rules!` definition emitted from inside another macro cannot
/// write `$` for its own metavariables, because the outer macro would try
/// to interpret it. This helper turns the supplied rules into a throwaway
/// macro, `__with_dollar_sign`, and immediately invokes it with a single
/// `$` token. The rules can therefore bind that token (for example as
/// `$d:tt`) and write `$d` wherever the generated macro needs a `$`.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper from the caller's rules, then feed it a bare `$`.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
33
/// Generates a component-wise extractor for a family of scalar literal kinds.
///
/// An invocation
/// `gen_component_wise_extractor! { ident -> Target, literals: [...], scalar_kinds: [...] }`
/// defines the following items:
///
/// - An enum `Target<const N: usize>` with one variant per
///   `LiteralName => Variant: ty` entry in `literals`. Each variant holds
///   `[ty; N]`, one value per argument.
/// - `impl From<Target<1>> for Expression`, which produces the matching
///   [`Literal`] expression.
/// - A function `ident(eval, span, exprs, handler)`. It gathers the `N`
///   expressions in `exprs` into a `Target<N>`, passes that to `handler`,
///   and registers the resulting `Target<M>` as a new expression:
///   - If the first argument is a literal, every argument must be a literal
///     of the same listed kind.
///   - If the first argument is a `Compose` of a vector type whose scalar
///     kind is listed in `scalar_kinds`, every argument must be a `Compose`
///     of the same type. `ident` is then applied recursively to each
///     component position, and the results are re-composed into a vector
///     of the first argument's type.
///   - Anything else yields [`ConstantEvaluatorError::InvalidMathArg`].
/// - A macro `ident!` with the same name, so callers can write one closure
///   body that is reused for every `Target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-valued extractor result converts back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Passes an argument through `eval_zero_value_and_splat` and borrows
            // the resulting expression.
            // NOTE(review): presumably this rewrites `ZeroValue`/`Splat` into
            // literals or `Compose`, so only those forms need matching below — confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar arguments: every argument must be a literal of the same kind.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        // One value was collected per argument, so the
                        // `ArrayVec` is full and `into_inner` succeeds.
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector arguments: flatten each argument into its scalar
                // components, then evaluate every component position separately.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // The remaining arguments must be `Compose`s of the
                            // same type as the first one.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                // Take component `idx` of every argument. A missing
                                // component is reported as an invalid argument.
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                // Recurse on the scalar components of this position.
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` is bound to a literal `$` so the generated macro can declare its
        // own metavariables (see `with_dollar_sign`).
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// `Scalar<N>` / `component_wise_scalar`: component-wise evaluation over the
// numeric literal kinds listed below, for scalars and for vectors of the listed
// scalar kinds. `F64` and `Bool` literals are not listed, so they are rejected
// with `ConstantEvaluatorError::InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// `Float<N>` / `component_wise_float`: component-wise evaluation over the
// abstract, 32-bit and 16-bit float literal kinds (no `F64` arm).
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// `ConcreteInt<N>` / `component_wise_concrete_int`: component-wise evaluation
// over 32-bit signed and unsigned integer literals only.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// `Signed<N>` / `component_wise_signed`: component-wise evaluation over signed
// literal kinds (abstract and concrete floats, abstract int, and `I32`; `I64`
// is not listed).
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// A short vector of literal components that all share one literal kind.
///
/// Each variant is named after the corresponding [`Literal`] variant and holds
/// up to `crate::VectorSize::MAX` values. A scalar is represented as a
/// one-component vector.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288 fn len(&self) -> usize {
289 match *self {
290 LiteralVector::F64(ref v) => v.len(),
291 LiteralVector::F32(ref v) => v.len(),
292 LiteralVector::F16(ref v) => v.len(),
293 LiteralVector::U32(ref v) => v.len(),
294 LiteralVector::I32(ref v) => v.len(),
295 LiteralVector::U64(ref v) => v.len(),
296 LiteralVector::I64(ref v) => v.len(),
297 LiteralVector::Bool(ref v) => v.len(),
298 LiteralVector::AbstractInt(ref v) => v.len(),
299 LiteralVector::AbstractFloat(ref v) => v.len(),
300 }
301 }
302
303 fn from_literal(literal: Literal) -> Self {
305 match literal {
306 Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307 Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308 Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309 Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310 Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311 Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312 Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313 Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314 Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315 Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316 }
317 }
318
319 fn from_literal_vec(
324 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325 ) -> Result<Self, ConstantEvaluatorError> {
326 assert!(!components.is_empty());
327 Ok(match components[0] {
328 Literal::I32(_) => Self::I32(
329 components
330 .iter()
331 .map(|l| match l {
332 &Literal::I32(v) => Ok(v),
333 _ => Err(ConstantEvaluatorError::InvalidMathArg),
335 })
336 .collect::<Result<_, _>>()?,
337 ),
338 Literal::U32(_) => Self::U32(
339 components
340 .iter()
341 .map(|l| match l {
342 &Literal::U32(v) => Ok(v),
343 _ => Err(ConstantEvaluatorError::InvalidMathArg),
344 })
345 .collect::<Result<_, _>>()?,
346 ),
347 Literal::I64(_) => Self::I64(
348 components
349 .iter()
350 .map(|l| match l {
351 &Literal::I64(v) => Ok(v),
352 _ => Err(ConstantEvaluatorError::InvalidMathArg),
353 })
354 .collect::<Result<_, _>>()?,
355 ),
356 Literal::U64(_) => Self::U64(
357 components
358 .iter()
359 .map(|l| match l {
360 &Literal::U64(v) => Ok(v),
361 _ => Err(ConstantEvaluatorError::InvalidMathArg),
362 })
363 .collect::<Result<_, _>>()?,
364 ),
365 Literal::F32(_) => Self::F32(
366 components
367 .iter()
368 .map(|l| match l {
369 &Literal::F32(v) => Ok(v),
370 _ => Err(ConstantEvaluatorError::InvalidMathArg),
371 })
372 .collect::<Result<_, _>>()?,
373 ),
374 Literal::F64(_) => Self::F64(
375 components
376 .iter()
377 .map(|l| match l {
378 &Literal::F64(v) => Ok(v),
379 _ => Err(ConstantEvaluatorError::InvalidMathArg),
380 })
381 .collect::<Result<_, _>>()?,
382 ),
383 Literal::Bool(_) => Self::Bool(
384 components
385 .iter()
386 .map(|l| match l {
387 &Literal::Bool(v) => Ok(v),
388 _ => Err(ConstantEvaluatorError::InvalidMathArg),
389 })
390 .collect::<Result<_, _>>()?,
391 ),
392 Literal::AbstractInt(_) => Self::AbstractInt(
393 components
394 .iter()
395 .map(|l| match l {
396 &Literal::AbstractInt(v) => Ok(v),
397 _ => Err(ConstantEvaluatorError::InvalidMathArg),
398 })
399 .collect::<Result<_, _>>()?,
400 ),
401 Literal::AbstractFloat(_) => Self::AbstractFloat(
402 components
403 .iter()
404 .map(|l| match l {
405 &Literal::AbstractFloat(v) => Ok(v),
406 _ => Err(ConstantEvaluatorError::InvalidMathArg),
407 })
408 .collect::<Result<_, _>>()?,
409 ),
410 Literal::F16(_) => Self::F16(
411 components
412 .iter()
413 .map(|l| match l {
414 &Literal::F16(v) => Ok(v),
415 _ => Err(ConstantEvaluatorError::InvalidMathArg),
416 })
417 .collect::<Result<_, _>>()?,
418 ),
419 })
420 }
421
422 #[allow(dead_code)]
423 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425 match *self {
426 LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427 LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428 LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429 LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430 LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431 LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432 LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433 LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434 LiteralVector::AbstractInt(ref v) => {
435 v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436 }
437 LiteralVector::AbstractFloat(ref v) => {
438 v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439 }
440 }
441 }
442
443 #[allow(dead_code)]
444 fn register_as_evaluated_expr(
446 &self,
447 eval: &mut ConstantEvaluator<'_>,
448 span: Span,
449 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450 let lit_vec = self.to_literal_vec();
451 assert!(!lit_vec.is_empty());
452 let expr = if lit_vec.len() == 1 {
453 Expression::Literal(lit_vec[0])
454 } else {
455 Expression::Compose {
456 ty: eval.types.insert(
457 Type {
458 name: None,
459 inner: TypeInner::Vector {
460 size: match lit_vec.len() {
461 2 => crate::VectorSize::Bi,
462 3 => crate::VectorSize::Tri,
463 4 => crate::VectorSize::Quad,
464 _ => unreachable!(),
465 },
466 scalar: lit_vec[0].scalar(),
467 },
468 },
469 Span::UNDEFINED,
470 ),
471 components: lit_vec
472 .iter()
473 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474 .collect::<Result<_, _>>()?,
475 }
476 };
477 eval.register_evaluated_expr(expr, span)
478 }
479}
480
/// Matches one or more [`LiteralVector`]s that share a component kind and
/// wraps the chosen arm's result in a variant of an output enum.
///
/// Syntax:
///
/// ```text
/// match_literal_vector!(match lit_vecs => Out {
///     Kind => |a, b| { body },
///     Kind => |a| -> Ret { body },
/// })
/// ```
///
/// - `Kind` is a `LiteralVector` variant name. It may also be one of two
///   shorthands that repeat the same arm for each listed variant:
///   - `Integer`: `U32`, `I32`, `U64`, `I64`, `AbstractInt`.
///   - `Float`: `F16`, `F32`, `F64`, `AbstractFloat`.
/// - The closure-like parameters bind (with `ref`) the contents of each
///   matched vector. With several parameters the generated pattern is a
///   tuple, so `lit_vecs` must be a tuple with one element per parameter.
/// - A matching arm evaluates to `Ok(Out::Kind(body))`, or to
///   `Ok(Out::Ret(body))` when `-> Ret` is given.
/// - If no arm matches, the result is
///   `Err(ConstantEvaluatorError::InvalidMathArg)`.
///
/// The `@inner*` rules are implementation details:
/// - They walk the list of kinds and the list of arm bodies in lockstep,
///   accumulating `{ Kind; vars; ret; body }` entries on the left of `<>`.
/// - `@inner_finish` emits the final `match`.
macro_rules! match_literal_vector {
    // Public entry point: split arms into a list of kinds and a list of bodies.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start with an empty accumulator (left of `<>`).
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next kind off the list and dispatch on it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` shorthand: expand to five kinds, duplicating the arm body for each.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body } $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` shorthand: expand to four kinds, duplicating the arm body for each.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body }, { $($var),+ ; $($ret)? ; $body } $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete kind with more kinds remaining: move one body into the accumulator.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete kind: accumulate its body and emit the match.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit one arm per accumulated entry; a single binding yields a
    // parenthesized (non-tuple) pattern, hence `unused_parens`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Without `-> Ret`, the output variant is named after the input kind.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // With `-> Ret`, the output variant is `Ret`.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
/// Which source language's constant-evaluation rules apply, and in which
/// context (module scope, override context, or function body) evaluation
/// happens.
#[derive(Debug)]
enum Behavior<'a> {
    Wgsl(WgslRestrictions<'a>),
    Glsl(GlslRestrictions<'a>),
}
662
663impl Behavior<'_> {
664 const fn has_runtime_restrictions(&self) -> bool {
666 matches!(
667 self,
668 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670 )
671 }
672}
673
/// Evaluates constant expressions and registers the results in an
/// expression arena.
///
/// Construct one with [`ConstantEvaluator::for_wgsl_module`],
/// [`ConstantEvaluator::for_glsl_module`],
/// [`ConstantEvaluator::for_wgsl_function`] or
/// [`ConstantEvaluator::for_glsl_function`].
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's rules to follow, and whether evaluation happens at
    /// module scope or inside a function.
    behavior: Behavior<'a>,

    /// The module's type arena. New types, such as vector types for evaluated
    /// results, may be inserted into it.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena that evaluated expressions are added to. This is the module's
    /// global expressions at module scope, or the function's local expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the [`ExpressionKind`] of each expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    // NOTE(review): type layout calculator for the module; not used in this
    // part of the file — presumably needed for layout-dependent evaluation, confirm.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// Restrictions on constant evaluation for WGSL input.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are accepted. Holds `Some` when created by
    /// `for_wgsl_function` with `is_const` set, and `None` when created by
    /// `for_wgsl_module` outside an override context.
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope in an override context: override-expressions are
    /// accepted as well as const-expressions.
    Override,
    /// Inside a function body: override- and run-time expressions are
    /// accepted and appended unevaluated.
    Runtime(FunctionLocalData<'a>),
}
731
/// Restrictions on constant evaluation for GLSL input.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope (created by `for_glsl_module`): only const-expressions.
    Const,
    /// Inside a function body (created by `for_glsl_function`): run-time
    /// expressions are accepted and appended unevaluated.
    Runtime(FunctionLocalData<'a>),
}
741
/// State the evaluator needs when it works on a function's local expression
/// arena instead of the module's global one.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are read
    /// from here and copied into the local arena.
    global_expressions: &'a Arena<Expression>,
    // NOTE(review): presumably the emitter and block that expressions appended
    // in a function context are emitted into — confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    block: &'a mut crate::Block,
}
749
/// How constant an expression is.
///
/// The variants are declared from most to least constant, so the derived
/// `Ord` makes [`Ord::max`] of the operands' kinds the kind of the
/// expression as a whole.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression.
    Const,
    /// Depends on at least one pipeline-overridable constant.
    Override,
    /// Anything that is neither const nor override.
    Runtime,
}
756
/// Records the [`ExpressionKind`] of each expression in an arena, indexed by
/// the expression's handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763 pub const fn new() -> Self {
764 Self {
765 inner: HandleVec::new(),
766 }
767 }
768
769 pub fn force_non_const(&mut self, value: Handle<Expression>) {
771 self.inner[value] = ExpressionKind::Runtime;
772 }
773
774 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775 self.inner.insert(value, expr_type);
776 }
777
778 pub fn is_const(&self, h: Handle<Expression>) -> bool {
779 matches!(self.type_of(h), ExpressionKind::Const)
780 }
781
782 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783 matches!(
784 self.type_of(h),
785 ExpressionKind::Const | ExpressionKind::Override
786 )
787 }
788
789 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790 self.inner[value]
791 }
792
793 pub fn from_arena(arena: &Arena<Expression>) -> Self {
794 let mut tracker = Self {
795 inner: HandleVec::with_capacity(arena.len()),
796 };
797 for (handle, expr) in arena.iter() {
798 tracker
799 .inner
800 .insert(handle, tracker.type_of_with_expr(expr));
801 }
802 tracker
803 }
804
805 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806 match *expr {
807 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808 ExpressionKind::Const
809 }
810 Expression::Override(_) => ExpressionKind::Override,
811 Expression::Compose { ref components, .. } => {
812 let mut expr_type = ExpressionKind::Const;
813 for component in components {
814 expr_type = expr_type.max(self.type_of(*component))
815 }
816 expr_type
817 }
818 Expression::Splat { value, .. } => self.type_of(value),
819 Expression::AccessIndex { base, .. } => self.type_of(base),
820 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821 Expression::Swizzle { vector, .. } => self.type_of(vector),
822 Expression::Unary { expr, .. } => self.type_of(expr),
823 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824 Expression::Math {
825 arg,
826 arg1,
827 arg2,
828 arg3,
829 ..
830 } => self
831 .type_of(arg)
832 .max(
833 arg1.map(|arg| self.type_of(arg))
834 .unwrap_or(ExpressionKind::Const),
835 )
836 .max(
837 arg2.map(|arg| self.type_of(arg))
838 .unwrap_or(ExpressionKind::Const),
839 )
840 .max(
841 arg3.map(|arg| self.type_of(arg))
842 .unwrap_or(ExpressionKind::Const),
843 ),
844 Expression::As { expr, .. } => self.type_of(expr),
845 Expression::Select {
846 condition,
847 accept,
848 reject,
849 } => self
850 .type_of(condition)
851 .max(self.type_of(accept))
852 .max(self.type_of(reject)),
853 Expression::Relational { argument, .. } => self.type_of(argument),
854 Expression::ArrayLength(expr) => self.type_of(expr),
855 _ => ExpressionKind::Runtime,
856 }
857 }
858}
859
860#[derive(Clone, Debug, thiserror::Error)]
861#[cfg_attr(test, derive(PartialEq))]
862pub enum ConstantEvaluatorError {
863 #[error("Constants cannot access function arguments")]
864 FunctionArg,
865 #[error("Constants cannot access global variables")]
866 GlobalVariable,
867 #[error("Constants cannot access local variables")]
868 LocalVariable,
869 #[error("Cannot get the array length of a non array type")]
870 InvalidArrayLengthArg,
871 #[error("Constants cannot get the array length of a dynamically sized array")]
872 ArrayLengthDynamic,
873 #[error("Cannot call arrayLength on array sized by override-expression")]
874 ArrayLengthOverridden,
875 #[error("Constants cannot call functions")]
876 Call,
877 #[error("Constants don't support workGroupUniformLoad")]
878 WorkGroupUniformLoadResult,
879 #[error("Constants don't support atomic functions")]
880 Atomic,
881 #[error("Constants don't support derivative functions")]
882 Derivative,
883 #[error("Constants don't support load expressions")]
884 Load,
885 #[error("Constants don't support image expressions")]
886 ImageExpression,
887 #[error("Constants don't support ray query expressions")]
888 RayQueryExpression,
889 #[error("Constants don't support subgroup expressions")]
890 SubgroupExpression,
891 #[error("Cannot access the type")]
892 InvalidAccessBase,
893 #[error("Cannot access at the index")]
894 InvalidAccessIndex,
895 #[error("Cannot access with index of type")]
896 InvalidAccessIndexTy,
897 #[error("Constants don't support array length expressions")]
898 ArrayLength,
899 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
900 InvalidCastArg { from: String, to: String },
901 #[error("Cannot apply the unary op to the argument")]
902 InvalidUnaryOpArg,
903 #[error("Cannot apply the binary op to the arguments")]
904 InvalidBinaryOpArgs,
905 #[error("Cannot apply math function to type")]
906 InvalidMathArg,
907 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
908 InvalidMathArgCount(crate::MathFunction, usize, usize),
909 #[error("Cannot apply relational function to type")]
910 InvalidRelationalArg(RelationalFunction),
911 #[error("value of `low` is greater than `high` for clamp built-in function")]
912 InvalidClamp,
913 #[error("Constructor expects {expected} components, found {actual}")]
914 InvalidVectorComposeLength { expected: usize, actual: usize },
915 #[error("Constructor must only contain vector or scalar arguments")]
916 InvalidVectorComposeComponent,
917 #[error("Splat is defined only on scalar values")]
918 SplatScalarOnly,
919 #[error("Can only swizzle vector constants")]
920 SwizzleVectorOnly,
921 #[error("swizzle component not present in source expression")]
922 SwizzleOutOfBounds,
923 #[error("Type is not constructible")]
924 TypeNotConstructible,
925 #[error("Subexpression(s) are not constant")]
926 SubexpressionsAreNotConstant,
927 #[error("Not implemented as constant expression: {0}")]
928 NotImplemented(String),
929 #[error("{0} operation overflowed")]
930 Overflow(String),
931 #[error(
932 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
933 )]
934 AutomaticConversionLossy {
935 value: String,
936 to_type: &'static str,
937 },
938 #[error("Division by zero")]
939 DivisionByZero,
940 #[error("Remainder by zero")]
941 RemainderByZero,
942 #[error("RHS of shift operation is greater than or equal to 32")]
943 ShiftedMoreThan32Bits,
944 #[error(transparent)]
945 Literal(#[from] crate::valid::LiteralError),
946 #[error("Can't use pipeline-overridable constants in const-expressions")]
947 Override,
948 #[error("Unexpected runtime-expression")]
949 RuntimeExpr,
950 #[error("Unexpected override-expression")]
951 OverrideExpr,
952 #[error("Expected boolean expression for condition argument of `select`, got something else")]
953 SelectScalarConditionNotABool,
954 #[error(
955 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
956 reject,
957 accept
958 )]
959 SelectVecRejectAcceptSizeMismatch {
960 reject: crate::VectorSize,
961 accept: crate::VectorSize,
962 },
963 #[error("Expected boolean vector for condition arg., got something else")]
964 SelectConditionNotAVecBool,
965 #[error(
966 "Expected same number of vector components between condition, accept, and reject args., got something else",
967 )]
968 SelectConditionVecSizeMismatch,
969 #[error(
970 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
971 )]
972 SelectAcceptRejectTypeMismatch,
973}
974
975impl<'a> ConstantEvaluator<'a> {
976 pub fn for_wgsl_module(
981 module: &'a mut crate::Module,
982 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
983 layouter: &'a mut crate::proc::Layouter,
984 in_override_ctx: bool,
985 ) -> Self {
986 Self::for_module(
987 Behavior::Wgsl(if in_override_ctx {
988 WgslRestrictions::Override
989 } else {
990 WgslRestrictions::Const(None)
991 }),
992 module,
993 global_expression_kind_tracker,
994 layouter,
995 )
996 }
997
998 pub fn for_glsl_module(
1003 module: &'a mut crate::Module,
1004 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1005 layouter: &'a mut crate::proc::Layouter,
1006 ) -> Self {
1007 Self::for_module(
1008 Behavior::Glsl(GlslRestrictions::Const),
1009 module,
1010 global_expression_kind_tracker,
1011 layouter,
1012 )
1013 }
1014
1015 fn for_module(
1016 behavior: Behavior<'a>,
1017 module: &'a mut crate::Module,
1018 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1019 layouter: &'a mut crate::proc::Layouter,
1020 ) -> Self {
1021 Self {
1022 behavior,
1023 types: &mut module.types,
1024 constants: &module.constants,
1025 overrides: &module.overrides,
1026 expressions: &mut module.global_expressions,
1027 expression_kind_tracker: global_expression_kind_tracker,
1028 layouter,
1029 }
1030 }
1031
1032 pub fn for_wgsl_function(
1037 module: &'a mut crate::Module,
1038 expressions: &'a mut Arena<Expression>,
1039 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1040 layouter: &'a mut crate::proc::Layouter,
1041 emitter: &'a mut super::Emitter,
1042 block: &'a mut crate::Block,
1043 is_const: bool,
1044 ) -> Self {
1045 let local_data = FunctionLocalData {
1046 global_expressions: &module.global_expressions,
1047 emitter,
1048 block,
1049 };
1050 Self {
1051 behavior: Behavior::Wgsl(if is_const {
1052 WgslRestrictions::Const(Some(local_data))
1053 } else {
1054 WgslRestrictions::Runtime(local_data)
1055 }),
1056 types: &mut module.types,
1057 constants: &module.constants,
1058 overrides: &module.overrides,
1059 expressions,
1060 expression_kind_tracker: local_expression_kind_tracker,
1061 layouter,
1062 }
1063 }
1064
1065 pub fn for_glsl_function(
1070 module: &'a mut crate::Module,
1071 expressions: &'a mut Arena<Expression>,
1072 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1073 layouter: &'a mut crate::proc::Layouter,
1074 emitter: &'a mut super::Emitter,
1075 block: &'a mut crate::Block,
1076 ) -> Self {
1077 Self {
1078 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1079 global_expressions: &module.global_expressions,
1080 emitter,
1081 block,
1082 })),
1083 types: &mut module.types,
1084 constants: &module.constants,
1085 overrides: &module.overrides,
1086 expressions,
1087 expression_kind_tracker: local_expression_kind_tracker,
1088 layouter,
1089 }
1090 }
1091
1092 pub fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1093 crate::proc::GlobalCtx {
1094 types: self.types,
1095 constants: self.constants,
1096 overrides: self.overrides,
1097 global_expressions: match self.function_local_data() {
1098 Some(data) => data.global_expressions,
1099 None => self.expressions,
1100 },
1101 }
1102 }
1103
1104 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1105 if !self.expression_kind_tracker.is_const(expr) {
1106 log::debug!("check: SubexpressionsAreNotConstant");
1107 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1108 }
1109 Ok(())
1110 }
1111
1112 fn check_and_get(
1113 &mut self,
1114 expr: Handle<Expression>,
1115 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1116 match self.expressions[expr] {
1117 Expression::Constant(c) => {
1118 if let Some(function_local_data) = self.function_local_data() {
1121 self.copy_from(
1123 self.constants[c].init,
1124 function_local_data.global_expressions,
1125 )
1126 } else {
1127 Ok(self.constants[c].init)
1129 }
1130 }
1131 _ => {
1132 self.check(expr)?;
1133 Ok(expr)
1134 }
1135 }
1136 }
1137
1138 pub fn try_eval_and_append(
1162 &mut self,
1163 expr: Expression,
1164 span: Span,
1165 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1166 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1167 ExpressionKind::Const => {
1168 let eval_result = self.try_eval_and_append_impl(&expr, span);
1169 if self.behavior.has_runtime_restrictions()
1174 && matches!(
1175 eval_result,
1176 Err(ConstantEvaluatorError::NotImplemented(_)
1177 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1178 )
1179 {
1180 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1181 } else {
1182 eval_result
1183 }
1184 }
1185 ExpressionKind::Override => match self.behavior {
1186 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1187 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1188 }
1189 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1190 Err(ConstantEvaluatorError::OverrideExpr)
1191 }
1192 Behavior::Glsl(_) => {
1193 unreachable!()
1194 }
1195 },
1196 ExpressionKind::Runtime => {
1197 if self.behavior.has_runtime_restrictions() {
1198 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1199 } else {
1200 Err(ConstantEvaluatorError::RuntimeExpr)
1201 }
1202 }
1203 }
1204 }
1205
1206 const fn is_global_arena(&self) -> bool {
1208 matches!(
1209 self.behavior,
1210 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1211 | Behavior::Glsl(GlslRestrictions::Const)
1212 )
1213 }
1214
1215 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1216 match self.behavior {
1217 Behavior::Wgsl(
1218 WgslRestrictions::Runtime(ref function_local_data)
1219 | WgslRestrictions::Const(Some(ref function_local_data)),
1220 )
1221 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1222 Some(function_local_data)
1223 }
1224 _ => None,
1225 }
1226 }
1227
1228 fn try_eval_and_append_impl(
1229 &mut self,
1230 expr: &Expression,
1231 span: Span,
1232 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1233 log::trace!("try_eval_and_append: {expr:?}");
1234 match *expr {
1235 Expression::Constant(c) if self.is_global_arena() => {
1236 Ok(self.constants[c].init)
1239 }
1240 Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1241 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1242 self.register_evaluated_expr(expr.clone(), span)
1243 }
1244 Expression::Compose { ty, ref components } => {
1245 let components = components
1246 .iter()
1247 .map(|component| self.check_and_get(*component))
1248 .collect::<Result<Vec<_>, _>>()?;
1249 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1250 }
1251 Expression::Splat { size, value } => {
1252 let value = self.check_and_get(value)?;
1253 self.register_evaluated_expr(Expression::Splat { size, value }, span)
1254 }
1255 Expression::AccessIndex { base, index } => {
1256 let base = self.check_and_get(base)?;
1257
1258 self.access(base, index as usize, span)
1259 }
1260 Expression::Access { base, index } => {
1261 let base = self.check_and_get(base)?;
1262 let index = self.check_and_get(index)?;
1263
1264 self.access(base, self.constant_index(index)?, span)
1265 }
1266 Expression::Swizzle {
1267 size,
1268 vector,
1269 pattern,
1270 } => {
1271 let vector = self.check_and_get(vector)?;
1272
1273 self.swizzle(size, span, vector, pattern)
1274 }
1275 Expression::Unary { expr, op } => {
1276 let expr = self.check_and_get(expr)?;
1277
1278 self.unary_op(op, expr, span)
1279 }
1280 Expression::Binary { left, right, op } => {
1281 let left = self.check_and_get(left)?;
1282 let right = self.check_and_get(right)?;
1283
1284 self.binary_op(op, left, right, span)
1285 }
1286 Expression::Math {
1287 fun,
1288 arg,
1289 arg1,
1290 arg2,
1291 arg3,
1292 } => {
1293 let arg = self.check_and_get(arg)?;
1294 let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1295 let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1296 let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1297
1298 self.math(arg, arg1, arg2, arg3, fun, span)
1299 }
1300 Expression::As {
1301 convert,
1302 expr,
1303 kind,
1304 } => {
1305 let expr = self.check_and_get(expr)?;
1306
1307 match convert {
1308 Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1309 None => Err(ConstantEvaluatorError::NotImplemented(
1310 "bitcast built-in function".into(),
1311 )),
1312 }
1313 }
1314 Expression::Select {
1315 reject,
1316 accept,
1317 condition,
1318 } => {
1319 let mut arg = |expr| self.check_and_get(expr);
1320
1321 let reject = arg(reject)?;
1322 let accept = arg(accept)?;
1323 let condition = arg(condition)?;
1324
1325 self.select(reject, accept, condition, span)
1326 }
1327 Expression::Relational { fun, argument } => {
1328 let argument = self.check_and_get(argument)?;
1329 self.relational(fun, argument, span)
1330 }
1331 Expression::ArrayLength(expr) => match self.behavior {
1332 Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1333 Behavior::Glsl(_) => {
1334 let expr = self.check_and_get(expr)?;
1335 self.array_length(expr, span)
1336 }
1337 },
1338 Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1339 Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1340 Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1341 Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1342 Expression::WorkGroupUniformLoadResult { .. } => {
1343 Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1344 }
1345 Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1346 Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1347 Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1348 Expression::ImageSample { .. }
1349 | Expression::ImageLoad { .. }
1350 | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1351 Expression::RayQueryProceedResult
1352 | Expression::RayQueryGetIntersection { .. }
1353 | Expression::RayQueryVertexPositions { .. } => {
1354 Err(ConstantEvaluatorError::RayQueryExpression)
1355 }
1356 Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1357 Expression::SubgroupOperationResult { .. } => {
1358 Err(ConstantEvaluatorError::SubgroupExpression)
1359 }
1360 }
1361 }
1362
1363 fn splat(
1376 &mut self,
1377 value: Handle<Expression>,
1378 size: crate::VectorSize,
1379 span: Span,
1380 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1381 match self.expressions[value] {
1382 Expression::Literal(literal) => {
1383 let scalar = literal.scalar();
1384 let ty = self.types.insert(
1385 Type {
1386 name: None,
1387 inner: TypeInner::Vector { size, scalar },
1388 },
1389 span,
1390 );
1391 let expr = Expression::Compose {
1392 ty,
1393 components: vec![value; size as usize],
1394 };
1395 self.register_evaluated_expr(expr, span)
1396 }
1397 Expression::ZeroValue(ty) => {
1398 let inner = match self.types[ty].inner {
1399 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1400 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1401 };
1402 let res_ty = self.types.insert(Type { name: None, inner }, span);
1403 let expr = Expression::ZeroValue(res_ty);
1404 self.register_evaluated_expr(expr, span)
1405 }
1406 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1407 }
1408 }
1409
1410 fn swizzle(
1411 &mut self,
1412 size: crate::VectorSize,
1413 span: Span,
1414 src_constant: Handle<Expression>,
1415 pattern: [crate::SwizzleComponent; 4],
1416 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1417 let mut get_dst_ty = |ty| match self.types[ty].inner {
1418 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1419 Type {
1420 name: None,
1421 inner: TypeInner::Vector { size, scalar },
1422 },
1423 span,
1424 )),
1425 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1426 };
1427
1428 match self.expressions[src_constant] {
1429 Expression::ZeroValue(ty) => {
1430 let dst_ty = get_dst_ty(ty)?;
1431 let expr = Expression::ZeroValue(dst_ty);
1432 self.register_evaluated_expr(expr, span)
1433 }
1434 Expression::Splat { value, .. } => {
1435 let expr = Expression::Splat { size, value };
1436 self.register_evaluated_expr(expr, span)
1437 }
1438 Expression::Compose { ty, ref components } => {
1439 let dst_ty = get_dst_ty(ty)?;
1440
1441 let mut flattened = [src_constant; 4]; let len =
1443 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1444 .zip(flattened.iter_mut())
1445 .map(|(component, elt)| *elt = component)
1446 .count();
1447 let flattened = &flattened[..len];
1448
1449 let swizzled_components = pattern[..size as usize]
1450 .iter()
1451 .map(|&sc| {
1452 let sc = sc as usize;
1453 if let Some(elt) = flattened.get(sc) {
1454 Ok(*elt)
1455 } else {
1456 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1457 }
1458 })
1459 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1460 let expr = Expression::Compose {
1461 ty: dst_ty,
1462 components: swizzled_components,
1463 };
1464 self.register_evaluated_expr(expr, span)
1465 }
1466 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1467 }
1468 }
1469
1470 fn math(
1471 &mut self,
1472 arg: Handle<Expression>,
1473 arg1: Option<Handle<Expression>>,
1474 arg2: Option<Handle<Expression>>,
1475 arg3: Option<Handle<Expression>>,
1476 fun: crate::MathFunction,
1477 span: Span,
1478 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1479 let expected = fun.argument_count();
1480 let given = Some(arg)
1481 .into_iter()
1482 .chain(arg1)
1483 .chain(arg2)
1484 .chain(arg3)
1485 .count();
1486 if expected != given {
1487 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1488 fun, expected, given,
1489 ));
1490 }
1491
1492 match fun {
1494 crate::MathFunction::Abs => {
1496 component_wise_scalar(self, span, [arg], |args| match args {
1497 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1498 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1499 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1500 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1501 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1502 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1504 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1505 })
1506 }
1507 crate::MathFunction::Min => {
1508 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1509 Ok([e1.min(e2)])
1510 })
1511 }
1512 crate::MathFunction::Max => {
1513 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1514 Ok([e1.max(e2)])
1515 })
1516 }
1517 crate::MathFunction::Clamp => {
1518 component_wise_scalar!(
1519 self,
1520 span,
1521 [arg, arg1.unwrap(), arg2.unwrap()],
1522 |e, low, high| {
1523 if low > high {
1524 Err(ConstantEvaluatorError::InvalidClamp)
1525 } else {
1526 Ok([e.clamp(low, high)])
1527 }
1528 }
1529 )
1530 }
1531 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1532 Float::F16([e]) => Ok(Float::F16(
1533 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1534 )),
1535 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1536 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1537 }),
1538
1539 crate::MathFunction::Cos => {
1541 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1542 }
1543 crate::MathFunction::Cosh => {
1544 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1545 }
1546 crate::MathFunction::Sin => {
1547 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1548 }
1549 crate::MathFunction::Sinh => {
1550 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1551 }
1552 crate::MathFunction::Tan => {
1553 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1554 }
1555 crate::MathFunction::Tanh => {
1556 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1557 }
1558 crate::MathFunction::Acos => {
1559 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1560 }
1561 crate::MathFunction::Asin => {
1562 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1563 }
1564 crate::MathFunction::Atan => {
1565 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1566 }
1567 crate::MathFunction::Atan2 => {
1568 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1569 Ok([y.atan2(x)])
1570 })
1571 }
1572 crate::MathFunction::Asinh => {
1573 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1574 }
1575 crate::MathFunction::Acosh => {
1576 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1577 }
1578 crate::MathFunction::Atanh => {
1579 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1580 }
1581 crate::MathFunction::Radians => {
1582 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1583 }
1584 crate::MathFunction::Degrees => {
1585 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1586 }
1587
1588 crate::MathFunction::Ceil => {
1590 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1591 }
1592 crate::MathFunction::Floor => {
1593 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1594 }
1595 crate::MathFunction::Round => {
1596 component_wise_float(self, span, [arg], |e| match e {
1597 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1598 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1599 Float::F16([e]) => {
1600 fn round_ties_even(x: f64) -> f64 {
1608 let i = x as i64;
1609 let f = (x - i as f64).abs();
1610 if f == 0.5 {
1611 if i & 1 == 1 {
1612 (x.abs() + 0.5).copysign(x)
1614 } else {
1615 (x.abs() - 0.5).copysign(x)
1616 }
1617 } else {
1618 x.round()
1619 }
1620 }
1621
1622 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1623 }
1624 })
1625 }
1626 crate::MathFunction::Fract => {
1627 component_wise_float!(self, span, [arg], |e| {
1628 Ok([e - e.floor()])
1631 })
1632 }
1633 crate::MathFunction::Trunc => {
1634 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1635 }
1636
1637 crate::MathFunction::Exp => {
1639 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1640 }
1641 crate::MathFunction::Exp2 => {
1642 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1643 }
1644 crate::MathFunction::Log => {
1645 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1646 }
1647 crate::MathFunction::Log2 => {
1648 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1649 }
1650 crate::MathFunction::Pow => {
1651 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1652 Ok([e1.powf(e2)])
1653 })
1654 }
1655
1656 crate::MathFunction::Sign => {
1658 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1659 }
1660 crate::MathFunction::Fma => {
1661 component_wise_float!(
1662 self,
1663 span,
1664 [arg, arg1.unwrap(), arg2.unwrap()],
1665 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1666 )
1667 }
1668 crate::MathFunction::Step => {
1669 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1670 Float::Abstract([edge, x]) => {
1671 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1672 }
1673 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1674 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1675 f16::one()
1676 } else {
1677 f16::zero()
1678 }])),
1679 })
1680 }
1681 crate::MathFunction::Sqrt => {
1682 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1683 }
1684 crate::MathFunction::InverseSqrt => {
1685 component_wise_float(self, span, [arg], |e| match e {
1686 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1687 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1688 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1689 })
1690 }
1691
1692 crate::MathFunction::CountTrailingZeros => {
1694 component_wise_concrete_int!(self, span, [arg], |e| {
1695 #[allow(clippy::useless_conversion)]
1696 Ok([e
1697 .trailing_zeros()
1698 .try_into()
1699 .expect("bit count overflowed 32 bits, somehow!?")])
1700 })
1701 }
1702 crate::MathFunction::CountLeadingZeros => {
1703 component_wise_concrete_int!(self, span, [arg], |e| {
1704 #[allow(clippy::useless_conversion)]
1705 Ok([e
1706 .leading_zeros()
1707 .try_into()
1708 .expect("bit count overflowed 32 bits, somehow!?")])
1709 })
1710 }
1711 crate::MathFunction::CountOneBits => {
1712 component_wise_concrete_int!(self, span, [arg], |e| {
1713 #[allow(clippy::useless_conversion)]
1714 Ok([e
1715 .count_ones()
1716 .try_into()
1717 .expect("bit count overflowed 32 bits, somehow!?")])
1718 })
1719 }
1720 crate::MathFunction::ReverseBits => {
1721 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1722 }
1723 crate::MathFunction::FirstTrailingBit => {
1724 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1725 }
1726 crate::MathFunction::FirstLeadingBit => {
1727 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1728 }
1729
1730 crate::MathFunction::Dot4I8Packed => {
1732 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1733 }
1734 crate::MathFunction::Dot4U8Packed => {
1735 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1736 }
1737 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1738 crate::MathFunction::Dot => {
1739 let e1 = self.extract_vec(arg, false)?;
1741 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1742 if e1.len() != e2.len() {
1743 return Err(ConstantEvaluatorError::InvalidMathArg);
1744 }
1745
1746 fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1747 where
1748 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1749 {
1750 a.iter()
1751 .zip(b.iter())
1752 .map(|(&aa, bb)| aa.checked_mul(bb))
1753 .try_fold(P::zero(), |acc, x| {
1754 if let Some(x) = x {
1755 acc.checked_add(&x)
1756 } else {
1757 None
1758 }
1759 })
1760 .ok_or(ConstantEvaluatorError::Overflow(
1761 "in dot built-in".to_string(),
1762 ))
1763 }
1764
1765 let result = match_literal_vector!(match (e1, e2) => Literal {
1766 Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1767 Integer => |e1, e2| { int_dot(e1, e2)? },
1768 })?;
1769 self.register_evaluated_expr(Expression::Literal(result), span)
1770 }
1771 crate::MathFunction::Length => {
1772 let e1 = self.extract_vec(arg, true)?;
1774
1775 fn float_length<F>(e: &[F]) -> F
1776 where
1777 F: core::ops::Mul<F>,
1778 F: num_traits::Float + iter::Sum,
1779 {
1780 e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1781 }
1782
1783 let result = match_literal_vector!(match e1 => Literal {
1784 Float => |e1| { float_length(e1) },
1785 })?;
1786 self.register_evaluated_expr(Expression::Literal(result), span)
1787 }
1788 crate::MathFunction::Distance => {
1789 let e1 = self.extract_vec(arg, true)?;
1791 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1792 if e1.len() != e2.len() {
1793 return Err(ConstantEvaluatorError::InvalidMathArg);
1794 }
1795
1796 fn float_distance<F>(a: &[F], b: &[F]) -> F
1797 where
1798 F: core::ops::Mul<F>,
1799 F: num_traits::Float + iter::Sum + core::ops::Sub,
1800 {
1801 a.iter()
1802 .zip(b.iter())
1803 .map(|(&aa, &bb)| aa - bb)
1804 .map(|ei| ei * ei)
1805 .sum::<F>()
1806 .sqrt()
1807 }
1808 let result = match_literal_vector!(match (e1, e2) => Literal {
1809 Float => |e1, e2| { float_distance(e1, e2) },
1810 })?;
1811 self.register_evaluated_expr(Expression::Literal(result), span)
1812 }
1813 crate::MathFunction::Normalize => {
1814 let e1 = self.extract_vec(arg, true)?;
1816
1817 fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1818 where
1819 F: core::ops::Mul<F>,
1820 F: num_traits::Float + iter::Sum,
1821 {
1822 let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1823 e.iter().map(|&ei| ei / len).collect()
1824 }
1825
1826 let result = match_literal_vector!(match e1 => LiteralVector {
1827 Float => |e1| { float_normalize(e1) },
1828 })?;
1829 result.register_as_evaluated_expr(self, span)
1830 }
1831
1832 crate::MathFunction::Modf
1834 | crate::MathFunction::Frexp
1835 | crate::MathFunction::Ldexp
1836 | crate::MathFunction::Outer
1837 | crate::MathFunction::FaceForward
1838 | crate::MathFunction::Reflect
1839 | crate::MathFunction::Refract
1840 | crate::MathFunction::Mix
1841 | crate::MathFunction::SmoothStep
1842 | crate::MathFunction::Inverse
1843 | crate::MathFunction::Transpose
1844 | crate::MathFunction::Determinant
1845 | crate::MathFunction::QuantizeToF16
1846 | crate::MathFunction::ExtractBits
1847 | crate::MathFunction::InsertBits
1848 | crate::MathFunction::Pack4x8snorm
1849 | crate::MathFunction::Pack4x8unorm
1850 | crate::MathFunction::Pack2x16snorm
1851 | crate::MathFunction::Pack2x16unorm
1852 | crate::MathFunction::Pack2x16float
1853 | crate::MathFunction::Pack4xI8
1854 | crate::MathFunction::Pack4xU8
1855 | crate::MathFunction::Pack4xI8Clamp
1856 | crate::MathFunction::Pack4xU8Clamp
1857 | crate::MathFunction::Unpack4x8snorm
1858 | crate::MathFunction::Unpack4x8unorm
1859 | crate::MathFunction::Unpack2x16snorm
1860 | crate::MathFunction::Unpack2x16unorm
1861 | crate::MathFunction::Unpack2x16float
1862 | crate::MathFunction::Unpack4xI8
1863 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1864 format!("{fun:?} built-in function"),
1865 )),
1866 }
1867 }
1868
1869 fn packed_dot_product(
1871 &mut self,
1872 a: Handle<Expression>,
1873 b: Handle<Expression>,
1874 span: Span,
1875 signed: bool,
1876 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1877 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1878 return Err(ConstantEvaluatorError::InvalidMathArg);
1879 };
1880 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1881 return Err(ConstantEvaluatorError::InvalidMathArg);
1882 };
1883
1884 let result = if signed {
1885 Literal::I32(
1886 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1887 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1888 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1889 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1890 )
1891 } else {
1892 Literal::U32(
1893 (a & 0xFF) * (b & 0xFF)
1894 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1895 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1896 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1897 )
1898 };
1899
1900 self.register_evaluated_expr(Expression::Literal(result), span)
1901 }
1902
1903 fn cross_product(
1905 &mut self,
1906 a: Handle<Expression>,
1907 b: Handle<Expression>,
1908 span: Span,
1909 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1910 use Literal as Li;
1911
1912 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1913 let (b, _) = self.extract_vec_with_size::<3>(b)?;
1914
1915 let product = match (a, b) {
1916 (
1917 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1918 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1919 ) => {
1920 let p = cross_product(
1925 [a0 as f64, a1 as f64, a2 as f64],
1926 [b0 as f64, b1 as f64, b2 as f64],
1927 );
1928 [
1929 Li::AbstractFloat(p[0]),
1930 Li::AbstractFloat(p[1]),
1931 Li::AbstractFloat(p[2]),
1932 ]
1933 }
1934 (
1935 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1936 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1937 ) => {
1938 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1939 [
1940 Li::AbstractFloat(p[0]),
1941 Li::AbstractFloat(p[1]),
1942 Li::AbstractFloat(p[2]),
1943 ]
1944 }
1945 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1946 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1947 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1948 }
1949 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1950 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1951 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1952 }
1953 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1954 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1955 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1956 }
1957 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1958 };
1959
1960 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1961 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1962 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1963
1964 self.register_evaluated_expr(
1965 Expression::Compose {
1966 ty,
1967 components: vec![p0, p1, p2],
1968 },
1969 span,
1970 )
1971 }
1972
1973 fn extract_vec_with_size<const N: usize>(
1981 &mut self,
1982 expr: Handle<Expression>,
1983 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1984 let span = self.expressions.get_span(expr);
1985 let expr = self.eval_zero_value_and_splat(expr, span)?;
1986 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1987 return Err(ConstantEvaluatorError::InvalidMathArg);
1988 };
1989
1990 let mut value = [Literal::Bool(false); N];
1991 for (component, elt) in
1992 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1993 .zip(value.iter_mut())
1994 {
1995 let Expression::Literal(literal) = self.expressions[component] else {
1996 return Err(ConstantEvaluatorError::InvalidMathArg);
1997 };
1998 *elt = literal;
1999 }
2000
2001 Ok((value, ty))
2002 }
2003
2004 fn extract_vec(
2012 &mut self,
2013 expr: Handle<Expression>,
2014 allow_single: bool,
2015 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2016 let span = self.expressions.get_span(expr);
2017 let expr = self.eval_zero_value_and_splat(expr, span)?;
2018
2019 match self.expressions[expr] {
2020 Expression::Literal(literal) if allow_single => {
2021 Ok(LiteralVector::from_literal(literal))
2022 }
2023 Expression::Compose { ty, ref components } => {
2024 let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2025 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2026 .map(|expr| match self.expressions[expr] {
2027 Expression::Literal(l) => Ok(l),
2028 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2029 })
2030 .collect::<Result<_, ConstantEvaluatorError>>()?;
2031 LiteralVector::from_literal_vec(components)
2032 }
2033 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2034 }
2035 }
2036
2037 fn array_length(
2038 &mut self,
2039 array: Handle<Expression>,
2040 span: Span,
2041 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2042 match self.expressions[array] {
2043 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2044 match self.types[ty].inner {
2045 TypeInner::Array { size, .. } => match size {
2046 ArraySize::Constant(len) => {
2047 let expr = Expression::Literal(Literal::U32(len.get()));
2048 self.register_evaluated_expr(expr, span)
2049 }
2050 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2051 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2052 },
2053 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2054 }
2055 }
2056 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2057 }
2058 }
2059
2060 fn access(
2061 &mut self,
2062 base: Handle<Expression>,
2063 index: usize,
2064 span: Span,
2065 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2066 match self.expressions[base] {
2067 Expression::ZeroValue(ty) => {
2068 let ty_inner = &self.types[ty].inner;
2069 let components = ty_inner
2070 .components()
2071 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2072
2073 if index >= components as usize {
2074 Err(ConstantEvaluatorError::InvalidAccessBase)
2075 } else {
2076 let ty_res = ty_inner
2077 .component_type(index)
2078 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2079 let ty = match ty_res {
2080 crate::proc::TypeResolution::Handle(ty) => ty,
2081 crate::proc::TypeResolution::Value(inner) => {
2082 self.types.insert(Type { name: None, inner }, span)
2083 }
2084 };
2085 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2086 }
2087 }
2088 Expression::Splat { size, value } => {
2089 if index >= size as usize {
2090 Err(ConstantEvaluatorError::InvalidAccessBase)
2091 } else {
2092 Ok(value)
2093 }
2094 }
2095 Expression::Compose { ty, ref components } => {
2096 let _ = self.types[ty]
2097 .inner
2098 .components()
2099 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2100
2101 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2102 .nth(index)
2103 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2104 }
2105 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2106 }
2107 }
2108
2109 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
2110 match self.expressions[expr] {
2111 Expression::ZeroValue(ty)
2112 if matches!(
2113 self.types[ty].inner,
2114 TypeInner::Scalar(crate::Scalar {
2115 kind: ScalarKind::Uint,
2116 ..
2117 })
2118 ) =>
2119 {
2120 Ok(0)
2121 }
2122 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
2123 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
2124 }
2125 }
2126
2127 fn eval_zero_value_and_splat(
2134 &mut self,
2135 mut expr: Handle<Expression>,
2136 span: Span,
2137 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2138 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2141 let components = components
2142 .clone()
2143 .iter()
2144 .map(|component| self.eval_zero_value_and_splat(*component, span))
2145 .collect::<Result<_, _>>()?;
2146 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2147 }
2148
2149 if let Expression::Splat { size, value } = self.expressions[expr] {
2153 expr = self.splat(value, size, span)?;
2154 }
2155 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2156 expr = self.eval_zero_value_impl(ty, span)?;
2157 }
2158 Ok(expr)
2159 }
2160
2161 fn eval_zero_value(
2167 &mut self,
2168 expr: Handle<Expression>,
2169 span: Span,
2170 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2171 match self.expressions[expr] {
2172 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2173 _ => Ok(expr),
2174 }
2175 }
2176
2177 fn eval_zero_value_impl(
2183 &mut self,
2184 ty: Handle<Type>,
2185 span: Span,
2186 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2187 match self.types[ty].inner {
2188 TypeInner::Scalar(scalar) => {
2189 let expr = Expression::Literal(
2190 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2191 );
2192 self.register_evaluated_expr(expr, span)
2193 }
2194 TypeInner::Vector { size, scalar } => {
2195 let scalar_ty = self.types.insert(
2196 Type {
2197 name: None,
2198 inner: TypeInner::Scalar(scalar),
2199 },
2200 span,
2201 );
2202 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2203 let expr = Expression::Compose {
2204 ty,
2205 components: vec![el; size as usize],
2206 };
2207 self.register_evaluated_expr(expr, span)
2208 }
2209 TypeInner::Matrix {
2210 columns,
2211 rows,
2212 scalar,
2213 } => {
2214 let vec_ty = self.types.insert(
2215 Type {
2216 name: None,
2217 inner: TypeInner::Vector { size: rows, scalar },
2218 },
2219 span,
2220 );
2221 let el = self.eval_zero_value_impl(vec_ty, span)?;
2222 let expr = Expression::Compose {
2223 ty,
2224 components: vec![el; columns as usize],
2225 };
2226 self.register_evaluated_expr(expr, span)
2227 }
2228 TypeInner::Array {
2229 base,
2230 size: ArraySize::Constant(size),
2231 ..
2232 } => {
2233 let el = self.eval_zero_value_impl(base, span)?;
2234 let expr = Expression::Compose {
2235 ty,
2236 components: vec![el; size.get() as usize],
2237 };
2238 self.register_evaluated_expr(expr, span)
2239 }
2240 TypeInner::Struct { ref members, .. } => {
2241 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2242 let mut components = Vec::with_capacity(members.len());
2243 for ty in types {
2244 components.push(self.eval_zero_value_impl(ty, span)?);
2245 }
2246 let expr = Expression::Compose { ty, components };
2247 self.register_evaluated_expr(expr, span)
2248 }
2249 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2250 }
2251 }
2252
2253 pub fn cast(
2257 &mut self,
2258 expr: Handle<Expression>,
2259 target: crate::Scalar,
2260 span: Span,
2261 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2262 use crate::Scalar as Sc;
2263
2264 let expr = self.eval_zero_value(expr, span)?;
2265
2266 let make_error = || -> Result<_, ConstantEvaluatorError> {
2267 let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2268
2269 #[cfg(feature = "wgsl-in")]
2270 let to = target.to_wgsl_for_diagnostics();
2271
2272 #[cfg(not(feature = "wgsl-in"))]
2273 let to = format!("{target:?}");
2274
2275 Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2276 };
2277
2278 use crate::proc::type_methods::IntFloatLimits;
2279
2280 let expr = match self.expressions[expr] {
2281 Expression::Literal(literal) => {
2282 let literal = match target {
2283 Sc::I32 => Literal::I32(match literal {
2284 Literal::I32(v) => v,
2285 Literal::U32(v) => v as i32,
2286 Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2287 Literal::F16(v) => f16::to_i32(&v).unwrap(), Literal::Bool(v) => v as i32,
2289 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2290 return make_error();
2291 }
2292 Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2293 Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2294 }),
2295 Sc::U32 => Literal::U32(match literal {
2296 Literal::I32(v) => v as u32,
2297 Literal::U32(v) => v,
2298 Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2299 Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2301 Literal::Bool(v) => v as u32,
2302 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2303 return make_error();
2304 }
2305 Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2306 Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2307 }),
2308 Sc::I64 => Literal::I64(match literal {
2309 Literal::I32(v) => v as i64,
2310 Literal::U32(v) => v as i64,
2311 Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2312 Literal::Bool(v) => v as i64,
2313 Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2314 Literal::I64(v) => v,
2315 Literal::U64(v) => v as i64,
2316 Literal::F16(v) => f16::to_i64(&v).unwrap(), Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2318 Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2319 }),
2320 Sc::U64 => Literal::U64(match literal {
2321 Literal::I32(v) => v as u64,
2322 Literal::U32(v) => v as u64,
2323 Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2324 Literal::Bool(v) => v as u64,
2325 Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2326 Literal::I64(v) => v as u64,
2327 Literal::U64(v) => v,
2328 Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2330 Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2331 Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2332 }),
2333 Sc::F16 => Literal::F16(match literal {
2334 Literal::F16(v) => v,
2335 Literal::F32(v) => f16::from_f32(v),
2336 Literal::F64(v) => f16::from_f64(v),
2337 Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2338 Literal::I64(v) => f16::from_i64(v).unwrap(),
2339 Literal::U64(v) => f16::from_u64(v).unwrap(),
2340 Literal::I32(v) => f16::from_i32(v).unwrap(),
2341 Literal::U32(v) => f16::from_u32(v).unwrap(),
2342 Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2343 Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2344 }),
2345 Sc::F32 => Literal::F32(match literal {
2346 Literal::I32(v) => v as f32,
2347 Literal::U32(v) => v as f32,
2348 Literal::F32(v) => v,
2349 Literal::Bool(v) => v as u32 as f32,
2350 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2351 return make_error();
2352 }
2353 Literal::F16(v) => f16::to_f32(v),
2354 Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2355 Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2356 }),
2357 Sc::F64 => Literal::F64(match literal {
2358 Literal::I32(v) => v as f64,
2359 Literal::U32(v) => v as f64,
2360 Literal::F16(v) => f16::to_f64(v),
2361 Literal::F32(v) => v as f64,
2362 Literal::F64(v) => v,
2363 Literal::Bool(v) => v as u32 as f64,
2364 Literal::I64(_) | Literal::U64(_) => return make_error(),
2365 Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2366 Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2367 }),
2368 Sc::BOOL => Literal::Bool(match literal {
2369 Literal::I32(v) => v != 0,
2370 Literal::U32(v) => v != 0,
2371 Literal::F32(v) => v != 0.0,
2372 Literal::F16(v) => v != f16::zero(),
2373 Literal::Bool(v) => v,
2374 Literal::AbstractInt(v) => v != 0,
2375 Literal::AbstractFloat(v) => v != 0.0,
2376 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2377 return make_error();
2378 }
2379 }),
2380 Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2381 Literal::AbstractInt(v) => {
2382 v as f64
2387 }
2388 Literal::AbstractFloat(v) => v,
2389 _ => return make_error(),
2390 }),
2391 Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2392 Literal::AbstractInt(v) => v,
2393 _ => return make_error(),
2394 }),
2395 _ => {
2396 log::debug!("Constant evaluator refused to convert value to {target:?}");
2397 return make_error();
2398 }
2399 };
2400 Expression::Literal(literal)
2401 }
2402 Expression::Compose {
2403 ty,
2404 components: ref src_components,
2405 } => {
2406 let ty_inner = match self.types[ty].inner {
2407 TypeInner::Vector { size, .. } => TypeInner::Vector {
2408 size,
2409 scalar: target,
2410 },
2411 TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2412 columns,
2413 rows,
2414 scalar: target,
2415 },
2416 _ => return make_error(),
2417 };
2418
2419 let mut components = src_components.clone();
2420 for component in &mut components {
2421 *component = self.cast(*component, target, span)?;
2422 }
2423
2424 let ty = self.types.insert(
2425 Type {
2426 name: None,
2427 inner: ty_inner,
2428 },
2429 span,
2430 );
2431
2432 Expression::Compose { ty, components }
2433 }
2434 Expression::Splat { size, value } => {
2435 let value_span = self.expressions.get_span(value);
2436 let cast_value = self.cast(value, target, value_span)?;
2437 Expression::Splat {
2438 size,
2439 value: cast_value,
2440 }
2441 }
2442 _ => return make_error(),
2443 };
2444
2445 self.register_evaluated_expr(expr, span)
2446 }
2447
2448 pub fn cast_array(
2461 &mut self,
2462 expr: Handle<Expression>,
2463 target: crate::Scalar,
2464 span: Span,
2465 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2466 let expr = self.check_and_get(expr)?;
2467
2468 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2469 return self.cast(expr, target, span);
2470 };
2471
2472 let TypeInner::Array {
2473 base: _,
2474 size,
2475 stride: _,
2476 } = self.types[ty].inner
2477 else {
2478 return self.cast(expr, target, span);
2479 };
2480
2481 let mut components = components.clone();
2482 for component in &mut components {
2483 *component = self.cast_array(*component, target, span)?;
2484 }
2485
2486 let first = components.first().unwrap();
2487 let new_base = match self.resolve_type(*first)? {
2488 crate::proc::TypeResolution::Handle(ty) => ty,
2489 crate::proc::TypeResolution::Value(inner) => {
2490 self.types.insert(Type { name: None, inner }, span)
2491 }
2492 };
2493 let mut layouter = core::mem::take(self.layouter);
2494 layouter.update(self.to_ctx()).unwrap();
2495 *self.layouter = layouter;
2496
2497 let new_base_stride = self.layouter[new_base].to_stride();
2498 let new_array_ty = self.types.insert(
2499 Type {
2500 name: None,
2501 inner: TypeInner::Array {
2502 base: new_base,
2503 size,
2504 stride: new_base_stride,
2505 },
2506 },
2507 span,
2508 );
2509
2510 let compose = Expression::Compose {
2511 ty: new_array_ty,
2512 components,
2513 };
2514 self.register_evaluated_expr(compose, span)
2515 }
2516
2517 fn unary_op(
2518 &mut self,
2519 op: UnaryOperator,
2520 expr: Handle<Expression>,
2521 span: Span,
2522 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2523 let expr = self.eval_zero_value_and_splat(expr, span)?;
2524
2525 let expr = match self.expressions[expr] {
2526 Expression::Literal(value) => Expression::Literal(match op {
2527 UnaryOperator::Negate => match value {
2528 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2529 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2530 Literal::F32(v) => Literal::F32(-v),
2531 Literal::F16(v) => Literal::F16(-v),
2532 Literal::F64(v) => Literal::F64(-v),
2533 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2534 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2535 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2536 },
2537 UnaryOperator::LogicalNot => match value {
2538 Literal::Bool(v) => Literal::Bool(!v),
2539 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2540 },
2541 UnaryOperator::BitwiseNot => match value {
2542 Literal::I32(v) => Literal::I32(!v),
2543 Literal::I64(v) => Literal::I64(!v),
2544 Literal::U32(v) => Literal::U32(!v),
2545 Literal::U64(v) => Literal::U64(!v),
2546 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2547 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2548 },
2549 }),
2550 Expression::Compose {
2551 ty,
2552 components: ref src_components,
2553 } => {
2554 match self.types[ty].inner {
2555 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2556 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2557 }
2558
2559 let mut components = src_components.clone();
2560 for component in &mut components {
2561 *component = self.unary_op(op, *component, span)?;
2562 }
2563
2564 Expression::Compose { ty, components }
2565 }
2566 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2567 };
2568
2569 self.register_evaluated_expr(expr, span)
2570 }
2571
2572 fn binary_op(
2573 &mut self,
2574 op: BinaryOperator,
2575 left: Handle<Expression>,
2576 right: Handle<Expression>,
2577 span: Span,
2578 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2579 let left = self.eval_zero_value_and_splat(left, span)?;
2580 let right = self.eval_zero_value_and_splat(right, span)?;
2581
2582 let expr = match (&self.expressions[left], &self.expressions[right]) {
2583 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2584 let literal = match op {
2585 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2586 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2587 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2588 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2589 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2590 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2591
2592 _ => match (left_value, right_value) {
2593 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2594 BinaryOperator::Add => a.wrapping_add(b),
2595 BinaryOperator::Subtract => a.wrapping_sub(b),
2596 BinaryOperator::Multiply => a.wrapping_mul(b),
2597 BinaryOperator::Divide => {
2598 if b == 0 {
2599 return Err(ConstantEvaluatorError::DivisionByZero);
2600 } else {
2601 a.wrapping_div(b)
2602 }
2603 }
2604 BinaryOperator::Modulo => {
2605 if b == 0 {
2606 return Err(ConstantEvaluatorError::RemainderByZero);
2607 } else {
2608 a.wrapping_rem(b)
2609 }
2610 }
2611 BinaryOperator::And => a & b,
2612 BinaryOperator::ExclusiveOr => a ^ b,
2613 BinaryOperator::InclusiveOr => a | b,
2614 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2615 }),
2616 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2617 BinaryOperator::ShiftLeft => {
2618 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2619 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2620 }
2621 a.checked_shl(b)
2622 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2623 }
2624 BinaryOperator::ShiftRight => a
2625 .checked_shr(b)
2626 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2627 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2628 }),
2629 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2630 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2631 ConstantEvaluatorError::Overflow("addition".into())
2632 })?,
2633 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2634 ConstantEvaluatorError::Overflow("subtraction".into())
2635 })?,
2636 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2637 ConstantEvaluatorError::Overflow("multiplication".into())
2638 })?,
2639 BinaryOperator::Divide => a
2640 .checked_div(b)
2641 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2642 BinaryOperator::Modulo => a
2643 .checked_rem(b)
2644 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2645 BinaryOperator::And => a & b,
2646 BinaryOperator::ExclusiveOr => a ^ b,
2647 BinaryOperator::InclusiveOr => a | b,
2648 BinaryOperator::ShiftLeft => a
2649 .checked_mul(
2650 1u32.checked_shl(b)
2651 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2652 )
2653 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2654 BinaryOperator::ShiftRight => a
2655 .checked_shr(b)
2656 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2657 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2658 }),
2659 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2660 BinaryOperator::Add => a + b,
2661 BinaryOperator::Subtract => a - b,
2662 BinaryOperator::Multiply => a * b,
2663 BinaryOperator::Divide => a / b,
2664 BinaryOperator::Modulo => a % b,
2665 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2666 }),
2667 (Literal::AbstractInt(a), Literal::U32(b)) => {
2668 Literal::AbstractInt(match op {
2669 BinaryOperator::ShiftLeft => {
2670 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2671 return Err(ConstantEvaluatorError::Overflow(
2672 "<<".to_string(),
2673 ));
2674 }
2675 a.checked_shl(b).unwrap_or(0)
2676 }
2677 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2678 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2679 })
2680 }
2681 (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2682 BinaryOperator::Add => a + b,
2683 BinaryOperator::Subtract => a - b,
2684 BinaryOperator::Multiply => a * b,
2685 BinaryOperator::Divide => a / b,
2686 BinaryOperator::Modulo => a % b,
2687 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2688 }),
2689 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2690 Literal::AbstractInt(match op {
2691 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2692 ConstantEvaluatorError::Overflow("addition".into())
2693 })?,
2694 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2695 ConstantEvaluatorError::Overflow("subtraction".into())
2696 })?,
2697 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2698 ConstantEvaluatorError::Overflow("multiplication".into())
2699 })?,
2700 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2701 if b == 0 {
2702 ConstantEvaluatorError::DivisionByZero
2703 } else {
2704 ConstantEvaluatorError::Overflow("division".into())
2705 }
2706 })?,
2707 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2708 if b == 0 {
2709 ConstantEvaluatorError::RemainderByZero
2710 } else {
2711 ConstantEvaluatorError::Overflow("remainder".into())
2712 }
2713 })?,
2714 BinaryOperator::And => a & b,
2715 BinaryOperator::ExclusiveOr => a ^ b,
2716 BinaryOperator::InclusiveOr => a | b,
2717 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2718 })
2719 }
2720 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2721 Literal::AbstractFloat(match op {
2722 BinaryOperator::Add => a + b,
2723 BinaryOperator::Subtract => a - b,
2724 BinaryOperator::Multiply => a * b,
2725 BinaryOperator::Divide => a / b,
2726 BinaryOperator::Modulo => a % b,
2727 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2728 })
2729 }
2730 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2731 BinaryOperator::LogicalAnd => a && b,
2732 BinaryOperator::LogicalOr => a || b,
2733 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2734 }),
2735 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2736 },
2737 };
2738 Expression::Literal(literal)
2739 }
2740 (
2741 &Expression::Compose {
2742 components: ref src_components,
2743 ty,
2744 },
2745 &Expression::Literal(_),
2746 ) => {
2747 let mut components = src_components.clone();
2748 for component in &mut components {
2749 *component = self.binary_op(op, *component, right, span)?;
2750 }
2751 Expression::Compose { ty, components }
2752 }
2753 (
2754 &Expression::Literal(_),
2755 &Expression::Compose {
2756 components: ref src_components,
2757 ty,
2758 },
2759 ) => {
2760 let mut components = src_components.clone();
2761 for component in &mut components {
2762 *component = self.binary_op(op, left, *component, span)?;
2763 }
2764 Expression::Compose { ty, components }
2765 }
2766 (
2767 &Expression::Compose {
2768 components: ref left_components,
2769 ty: left_ty,
2770 },
2771 &Expression::Compose {
2772 components: ref right_components,
2773 ty: right_ty,
2774 },
2775 ) => {
2776 let left_flattened = crate::proc::flatten_compose(
2780 left_ty,
2781 left_components,
2782 self.expressions,
2783 self.types,
2784 );
2785 let right_flattened = crate::proc::flatten_compose(
2786 right_ty,
2787 right_components,
2788 self.expressions,
2789 self.types,
2790 );
2791
2792 let mut flattened = Vec::with_capacity(left_components.len());
2795 flattened.extend(left_flattened.zip(right_flattened));
2796
2797 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2798 (
2799 &TypeInner::Vector {
2800 size: left_size, ..
2801 },
2802 &TypeInner::Vector {
2803 size: right_size, ..
2804 },
2805 ) if left_size == right_size => {
2806 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2807 }
2808 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2809 }
2810 }
2811 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2812 };
2813
2814 self.register_evaluated_expr(expr, span)
2815 }
2816
2817 fn binary_op_vector(
2818 &mut self,
2819 op: BinaryOperator,
2820 size: crate::VectorSize,
2821 components: &[(Handle<Expression>, Handle<Expression>)],
2822 left_ty: Handle<Type>,
2823 span: Span,
2824 ) -> Result<Expression, ConstantEvaluatorError> {
2825 let ty = match op {
2826 BinaryOperator::Equal
2828 | BinaryOperator::NotEqual
2829 | BinaryOperator::Less
2830 | BinaryOperator::LessEqual
2831 | BinaryOperator::Greater
2832 | BinaryOperator::GreaterEqual => self.types.insert(
2833 Type {
2834 name: None,
2835 inner: TypeInner::Vector {
2836 size,
2837 scalar: crate::Scalar::BOOL,
2838 },
2839 },
2840 span,
2841 ),
2842
2843 BinaryOperator::Add
2846 | BinaryOperator::Subtract
2847 | BinaryOperator::Multiply
2848 | BinaryOperator::Divide
2849 | BinaryOperator::Modulo
2850 | BinaryOperator::And
2851 | BinaryOperator::ExclusiveOr
2852 | BinaryOperator::InclusiveOr
2853 | BinaryOperator::ShiftLeft
2854 | BinaryOperator::ShiftRight => left_ty,
2855
2856 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2857 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2859 }
2860 };
2861
2862 let components = components
2863 .iter()
2864 .map(|&(left, right)| self.binary_op(op, left, right, span))
2865 .collect::<Result<Vec<_>, _>>()?;
2866
2867 Ok(Expression::Compose { ty, components })
2868 }
2869
2870 fn relational(
2871 &mut self,
2872 fun: RelationalFunction,
2873 arg: Handle<Expression>,
2874 span: Span,
2875 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2876 let arg = self.eval_zero_value_and_splat(arg, span)?;
2877 match fun {
2878 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2879 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2880 Expression::Compose { ty, ref components }
2881 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2882 {
2883 let components =
2884 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2885 .map(|component| match self.expressions[component] {
2886 Expression::Literal(Literal::Bool(val)) => Ok(val),
2887 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2888 })
2889 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2890 let result = match fun {
2891 RelationalFunction::All => components.iter().all(|c| *c),
2892 RelationalFunction::Any => components.iter().any(|c| *c),
2893 _ => unreachable!(),
2894 };
2895 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2896 }
2897 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2898 },
2899 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2900 "{fun:?} built-in function"
2901 ))),
2902 }
2903 }
2904
2905 fn copy_from(
2913 &mut self,
2914 expr: Handle<Expression>,
2915 expressions: &Arena<Expression>,
2916 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2917 let span = expressions.get_span(expr);
2918 match expressions[expr] {
2919 ref expr @ (Expression::Literal(_)
2920 | Expression::Constant(_)
2921 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2922 Expression::Compose { ty, ref components } => {
2923 let mut components = components.clone();
2924 for component in &mut components {
2925 *component = self.copy_from(*component, expressions)?;
2926 }
2927 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2928 }
2929 Expression::Splat { size, value } => {
2930 let value = self.copy_from(value, expressions)?;
2931 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2932 }
2933 _ => {
2934 log::debug!("copy_from: SubexpressionsAreNotConstant");
2935 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2936 }
2937 }
2938 }
2939
2940 fn vector_compose_flattened_size(
2942 &self,
2943 components: &[Handle<Expression>],
2944 ) -> Result<usize, ConstantEvaluatorError> {
2945 components
2946 .iter()
2947 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2948 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2949 TypeInner::Scalar(_) => 1,
2950 TypeInner::Vector { size, .. } => size as usize,
2954 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2955 };
2956 Ok(acc + size)
2957 })
2958 }
2959
2960 fn register_evaluated_expr(
2961 &mut self,
2962 expr: Expression,
2963 span: Span,
2964 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2965 if let Expression::Literal(literal) = expr {
2970 crate::valid::check_literal_value(literal)?;
2971 }
2972
2973 if let Expression::Compose { ty, ref components } = expr {
2977 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2978 let expected = size as usize;
2979 let actual = self.vector_compose_flattened_size(components)?;
2980 if expected != actual {
2981 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2982 expected,
2983 actual,
2984 });
2985 }
2986 }
2987 }
2988
2989 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2990 }
2991
    /// Appends `expr` to the expression arena, records `expr_type` for it in
    /// the expression-kind tracker, and returns its handle.
    ///
    /// When the evaluator has function-local data (WGSL runtime context, a
    /// WGSL const context with function-local data, or GLSL runtime context)
    /// and that function's emitter is running, an expression whose
    /// `needs_pre_emit()` is true is kept out of the current emit range:
    ///
    /// 1. the pending range is flushed into the function's block;
    /// 2. the expression is appended;
    /// 3. a new range is started.
    ///
    /// NOTE(review): presumably such expressions must not be covered by an
    /// `Emit` statement — confirm against `Expression::needs_pre_emit`.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range so the new expression
                    // lands outside it, then reopen a range after it.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // No function-local emitter to coordinate with.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
3022
3023 fn resolve_type(
3024 &self,
3025 expr: Handle<Expression>,
3026 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3027 use crate::proc::TypeResolution as Tr;
3028 use crate::Expression as Ex;
3029 let resolution = match self.expressions[expr] {
3030 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3031 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3032 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3033 Ex::Splat { size, value } => {
3034 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3035 return Err(ConstantEvaluatorError::SplatScalarOnly);
3036 };
3037 Tr::Value(TypeInner::Vector { scalar, size })
3038 }
3039 _ => {
3040 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3041 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3042 }
3043 };
3044
3045 Ok(resolution)
3046 }
3047
3048 fn select(
3049 &mut self,
3050 reject: Handle<Expression>,
3051 accept: Handle<Expression>,
3052 condition: Handle<Expression>,
3053 span: Span,
3054 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3055 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3056
3057 let reject = arg(reject)?;
3058 let accept = arg(accept)?;
3059 let condition = arg(condition)?;
3060
3061 let select_single_component =
3062 |this: &mut Self, reject_scalar, reject, accept, condition| {
3063 let accept = this.cast(accept, reject_scalar, span)?;
3064 if condition {
3065 Ok(accept)
3066 } else {
3067 Ok(reject)
3068 }
3069 };
3070
3071 match (&self.expressions[reject], &self.expressions[accept]) {
3072 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3073 let reject_scalar = reject_lit.scalar();
3074 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3075 else {
3076 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3077 };
3078 select_single_component(self, reject_scalar, reject, accept, condition)
3079 }
3080 (
3081 &Expression::Compose {
3082 ty: reject_ty,
3083 components: ref reject_components,
3084 },
3085 &Expression::Compose {
3086 ty: accept_ty,
3087 components: ref accept_components,
3088 },
3089 ) => {
3090 let ty_deets = |ty| {
3091 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3092 (size.unwrap(), scalar)
3093 };
3094
3095 let expected_vec_size = {
3096 let [(reject_vec_size, _), (accept_vec_size, _)] =
3097 [reject_ty, accept_ty].map(ty_deets);
3098
3099 if reject_vec_size != accept_vec_size {
3100 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3101 reject: reject_vec_size,
3102 accept: accept_vec_size,
3103 });
3104 }
3105 reject_vec_size
3106 };
3107
3108 let condition_components = match self.expressions[condition] {
3109 Expression::Literal(Literal::Bool(condition)) => {
3110 vec![condition; (expected_vec_size as u8).into()]
3111 }
3112 Expression::Compose {
3113 ty: condition_ty,
3114 components: ref condition_components,
3115 } => {
3116 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3117 if condition_scalar.kind != ScalarKind::Bool {
3118 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3119 }
3120 if condition_vec_size != expected_vec_size {
3121 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3122 }
3123 condition_components
3124 .iter()
3125 .copied()
3126 .map(|component| match &self.expressions[component] {
3127 &Expression::Literal(Literal::Bool(condition)) => condition,
3128 _ => unreachable!(),
3129 })
3130 .collect()
3131 }
3132
3133 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3134 };
3135
3136 let evaluated = Expression::Compose {
3137 ty: reject_ty,
3138 components: reject_components
3139 .clone()
3140 .into_iter()
3141 .zip(accept_components.clone().into_iter())
3142 .zip(condition_components.into_iter())
3143 .map(|((reject, accept), condition)| {
3144 let reject_scalar = match &self.expressions[reject] {
3145 &Expression::Literal(lit) => lit.scalar(),
3146 _ => unreachable!(),
3147 };
3148 select_single_component(self, reject_scalar, reject, accept, condition)
3149 })
3150 .collect::<Result<_, _>>()?,
3151 };
3152 self.register_evaluated_expr(evaluated, span)
3153 }
3154 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3155 }
3156 }
3157}
3158
3159fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3160 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3164 match e {
3165 idx @ 0..=31 => idx,
3166 32 => u32::MAX,
3167 _ => unreachable!(),
3168 }
3169 };
3170 match concrete_int {
3171 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3172 ConcreteInt::I32([e]) => {
3173 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3174 }
3175 }
3176}
3177
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) pairs for signed inputs.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // (input, expected) pairs for unsigned inputs.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
3238
3239fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3240 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3244 match e {
3245 idx @ 0..=31 => 31 - idx,
3246 32 => u32::MAX,
3247 _ => unreachable!(),
3248 }
3249 };
3250 match concrete_int {
3251 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3252 let rtl_bit_index = if e.is_negative() {
3253 e.leading_ones()
3254 } else {
3255 e.leading_zeros()
3256 };
3257 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3258 }]),
3259 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3260 }
3261}
3262
#[test]
fn first_leading_bit_smoke() {
    // (input, expected) pairs for signed inputs.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two; bit 31 is the sign bit and is excluded.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negative powers of two: the first 0 bit sits just below the run.
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected) pairs for unsigned inputs.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
3326
/// Fallible conversion from an abstract WGSL value into a concrete type.
///
/// `AbstractInt` values are carried as `i64` and `AbstractFloat` values as
/// `f64`. The constant evaluator calls this when casting literals of those
/// types to a concrete scalar type. Implementations reject conversions they
/// consider lossy with a `ConstantEvaluatorError`.
trait TryFromAbstract<T>: Sized {
    /// Converts `value` to `Self`, or returns the reason it cannot be
    /// converted.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3352
3353impl TryFromAbstract<i64> for i32 {
3354 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3355 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3356 value: format!("{value:?}"),
3357 to_type: "i32",
3358 })
3359 }
3360}
3361
3362impl TryFromAbstract<i64> for u32 {
3363 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3364 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3365 value: format!("{value:?}"),
3366 to_type: "u32",
3367 })
3368 }
3369}
3370
3371impl TryFromAbstract<i64> for u64 {
3372 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3373 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3374 value: format!("{value:?}"),
3375 to_type: "u64",
3376 })
3377 }
3378}
3379
3380impl TryFromAbstract<i64> for i64 {
3381 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3382 Ok(value)
3383 }
3384}
3385
3386impl TryFromAbstract<i64> for f32 {
3387 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3388 let f = value as f32;
3389 Ok(f)
3393 }
3394}
3395
3396impl TryFromAbstract<f64> for f32 {
3397 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3398 let f = value as f32;
3399 if f.is_infinite() {
3400 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3401 value: format!("{value:?}"),
3402 to_type: "f32",
3403 });
3404 }
3405 Ok(f)
3406 }
3407}
3408
3409impl TryFromAbstract<i64> for f64 {
3410 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3411 let f = value as f64;
3412 Ok(f)
3416 }
3417}
3418
3419impl TryFromAbstract<f64> for f64 {
3420 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3421 Ok(value)
3422 }
3423}
3424
3425impl TryFromAbstract<f64> for i32 {
3426 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3427 Ok(value as i32)
3440 }
3441}
3442
3443impl TryFromAbstract<f64> for u32 {
3444 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3445 Ok(value as u32)
3448 }
3449}
3450
3451impl TryFromAbstract<f64> for i64 {
3452 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3453 use crate::proc::type_methods::IntFloatLimits;
3456 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3457 }
3458}
3459
3460impl TryFromAbstract<f64> for u64 {
3461 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3462 use crate::proc::type_methods::IntFloatLimits;
3465 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3466 }
3467}
3468
3469impl TryFromAbstract<f64> for f16 {
3470 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3471 let f = f16::from_f64(value);
3472 if f.is_infinite() {
3473 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3474 value: format!("{value:?}"),
3475 to_type: "f16",
3476 });
3477 }
3478 Ok(f)
3479 }
3480}
3481
3482impl TryFromAbstract<i64> for f16 {
3483 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3484 let f = f16::from_i64(value);
3485 if f.is_none() {
3486 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3487 value: format!("{value:?}"),
3488 to_type: "f16",
3489 });
3490 }
3491 Ok(f.unwrap())
3492 }
3493}
3494
/// Computes the 3-component cross product `a × b`.
///
/// Each component is evaluated as a difference of two products. For example,
/// the x component is `a.y * b.z - a.z * b.y`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
3507
#[cfg(test)]
mod tests {
    //! Unit tests for `ConstantEvaluator`.
    //!
    //! Each test:
    //! - builds its own types / constants / overrides / global-expression arenas;
    //! - runs `try_eval_and_append` with `Behavior::Wgsl(WgslRestrictions::Const(None))`;
    //! - inspects the expression that was appended to the arena.

    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Checks unary operators applied to constants:
    /// - `-4` and `!4` on an `i32` constant fold to literals;
    /// - `!` on a `vec2<i32>` constant folds to a `Compose` of `!4` and `!8`.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        // Scalar constant `h` = 4.
        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        // Scalar constant `h1` = 8.
        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        // Vector constant `vec_h` = vec2(4, 8), composed from the two
        // constants' initializer expressions.
        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        // -h
        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        // !h
        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        // !vec_h
        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        // The tracker is built from the arena after all input expressions
        // above have been appended.
        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        // The vector result must be a `Compose` of exactly two literal
        // components, `!4` and `!8`, in that order.
        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Checks that converting the `i32` constant 4 to `bool` (with
    /// `convert: Some(BOOL_WIDTH)`) folds to the literal `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Checks `AccessIndex` folding on a constant `mat2x3<f32>` whose columns
    /// are (0, 1, 2) and (3, 4, 5):
    /// - `m[1]` yields a `Compose` of (3, 4, 5);
    /// - indexing that result with `[2]` yields the literal 5.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        // First column: literals 0.0, 1.0, 2.0.
        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        // Second column: literals 3.0, 4.0, 5.0.
        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        // Matrix constant composed from the two column vectors' initializers.
        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        // m[1]: the second column.
        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        // m[1][2]: indexes the already-evaluated column.
        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// Checks two steps on a vector built from a named constant:
    /// - `vec2<i32>(h, h)`, where `h` is a constant equal to 4, is evaluated;
    /// - negating the result yields a `Compose` whose components are all the
    ///   literal -4.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        // Passes only if the result is a `vec2<i32>` `Compose` whose every
        // component is the literal -4.
        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }

    /// Checks that negating `Splat { size: Bi, value: h }`, where `h` is a
    /// constant equal to 4, yields a `vec2<i32>` `Compose` of -4 literals.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        // The negated splat is expected as a `Compose` (not a `Splat`) of
        // -4 literals with type `vec2<i32>`.
        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }

    /// Checks that `splat(ZeroValue(f32)) + splat(5.0)` folds to a
    /// `vec2<f32>` `Compose` of 5.0 literals.
    #[test]
    fn splat_of_zero_value() {
        let mut types = UniqueArena::new();
        let constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::F32),
            },
            Default::default(),
        );

        let vec2_f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let five =
            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
        let five_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: five,
            },
            Default::default(),
        );
        // A splat of a `ZeroValue` expression, rather than of a literal.
        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
        let zero_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: zero,
            },
            Default::default(),
        );

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_add = solver
            .try_eval_and_append(
                Expression::Binary {
                    op: crate::BinaryOperator::Add,
                    left: zero_splat,
                    right: five_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_add] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_f32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }
}