1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Helper for writing a `macro_rules!` definition inside another macro's
/// expansion when the inner macro needs its own `$` metavariables.
///
/// The body is wrapped in a temporary `__with_dollar_sign` macro and invoked
/// with a literal `$` token. The body's first rule binds that token (e.g. as
/// `$d:tt`), so it can write `$d ident` wherever a nested `$ident` is needed.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
33
/// Generates a family of component-wise evaluation helpers.
///
/// For the given `literals` list (`LiteralVariant => EnumVariant: rust_type`)
/// and vector `scalar_kinds`, this expands to:
///
/// - `enum $target<const N: usize>`: one variant per literal, each holding
///   `N` values of the same Rust type.
/// - `impl From<$target<1>> for Expression`: turns a one-element value back
///   into an `Expression::Literal`.
/// - `fn $ident(eval, span, exprs, handler)`: evaluates `N` expressions that
///   are either all literals of one variant or all `Compose` vectors of one
///   type. It calls `handler` on each group of matching scalars and registers
///   the result.
/// - `macro_rules! $ident`: a shorthand that applies one closure body to
///   every `$target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // `N` is the number of operands grouped together (e.g. 2 for a binary
        // math function). All operands share one literal variant.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        // `M` is the number of values `handler` produces per group; only
        // `M == 1` has an `Into<Expression>` impl generated above.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs each operand through `eval_zero_value_and_splat` and borrows
            // the resulting expression. Presumably this rewrites `ZeroValue` and
            // `Splat` into literal/compose form; see that method to confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first operand decides the shape. Every other operand must match
            // it (same literal variant, or a `Compose` of the identical type).
            // Otherwise the result is `InvalidMathArg`.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: gather all `N` literals and call `handler` once.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: flatten every operand's components, then recurse
                // once per lane with the `N` components at that index.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                // Take lane `idx` from every operand. A short operand
                                // makes this fail with `InvalidMathArg`.
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result reuses the first operand's vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // Defines a same-named `$ident!` macro. Its closure-like body is
        // duplicated into one match arm per `$target` variant, so a single
        // generic expression can serve every scalar type.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// Every scalar literal type; accepts vectors of any numeric scalar kind.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}

// Floating-point literals only (abstract, f32, f16).
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}

// 32-bit concrete integers only (no abstract or 64-bit integers).
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}

// Signed types: floats plus signed integers. `I64` is not listed here.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// A scalar or vector of literals that all share one type, stored inline.
///
/// A scalar is represented as a one-element vector. Every variant holds at
/// most `crate::VectorSize::MAX` components.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
285
286impl LiteralVector {
287 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288 fn len(&self) -> usize {
289 match *self {
290 LiteralVector::F64(ref v) => v.len(),
291 LiteralVector::F32(ref v) => v.len(),
292 LiteralVector::F16(ref v) => v.len(),
293 LiteralVector::U32(ref v) => v.len(),
294 LiteralVector::I32(ref v) => v.len(),
295 LiteralVector::U64(ref v) => v.len(),
296 LiteralVector::I64(ref v) => v.len(),
297 LiteralVector::Bool(ref v) => v.len(),
298 LiteralVector::AbstractInt(ref v) => v.len(),
299 LiteralVector::AbstractFloat(ref v) => v.len(),
300 }
301 }
302
303 fn from_literal(literal: Literal) -> Self {
305 match literal {
306 Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307 Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308 Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309 Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310 Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311 Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312 Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313 Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314 Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315 Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316 }
317 }
318
319 fn from_literal_vec(
324 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325 ) -> Result<Self, ConstantEvaluatorError> {
326 assert!(!components.is_empty());
327 Ok(match components[0] {
328 Literal::I32(_) => Self::I32(
329 components
330 .iter()
331 .map(|l| match l {
332 &Literal::I32(v) => Ok(v),
333 _ => Err(ConstantEvaluatorError::InvalidMathArg),
335 })
336 .collect::<Result<_, _>>()?,
337 ),
338 Literal::U32(_) => Self::U32(
339 components
340 .iter()
341 .map(|l| match l {
342 &Literal::U32(v) => Ok(v),
343 _ => Err(ConstantEvaluatorError::InvalidMathArg),
344 })
345 .collect::<Result<_, _>>()?,
346 ),
347 Literal::I64(_) => Self::I64(
348 components
349 .iter()
350 .map(|l| match l {
351 &Literal::I64(v) => Ok(v),
352 _ => Err(ConstantEvaluatorError::InvalidMathArg),
353 })
354 .collect::<Result<_, _>>()?,
355 ),
356 Literal::U64(_) => Self::U64(
357 components
358 .iter()
359 .map(|l| match l {
360 &Literal::U64(v) => Ok(v),
361 _ => Err(ConstantEvaluatorError::InvalidMathArg),
362 })
363 .collect::<Result<_, _>>()?,
364 ),
365 Literal::F32(_) => Self::F32(
366 components
367 .iter()
368 .map(|l| match l {
369 &Literal::F32(v) => Ok(v),
370 _ => Err(ConstantEvaluatorError::InvalidMathArg),
371 })
372 .collect::<Result<_, _>>()?,
373 ),
374 Literal::F64(_) => Self::F64(
375 components
376 .iter()
377 .map(|l| match l {
378 &Literal::F64(v) => Ok(v),
379 _ => Err(ConstantEvaluatorError::InvalidMathArg),
380 })
381 .collect::<Result<_, _>>()?,
382 ),
383 Literal::Bool(_) => Self::Bool(
384 components
385 .iter()
386 .map(|l| match l {
387 &Literal::Bool(v) => Ok(v),
388 _ => Err(ConstantEvaluatorError::InvalidMathArg),
389 })
390 .collect::<Result<_, _>>()?,
391 ),
392 Literal::AbstractInt(_) => Self::AbstractInt(
393 components
394 .iter()
395 .map(|l| match l {
396 &Literal::AbstractInt(v) => Ok(v),
397 _ => Err(ConstantEvaluatorError::InvalidMathArg),
398 })
399 .collect::<Result<_, _>>()?,
400 ),
401 Literal::AbstractFloat(_) => Self::AbstractFloat(
402 components
403 .iter()
404 .map(|l| match l {
405 &Literal::AbstractFloat(v) => Ok(v),
406 _ => Err(ConstantEvaluatorError::InvalidMathArg),
407 })
408 .collect::<Result<_, _>>()?,
409 ),
410 Literal::F16(_) => Self::F16(
411 components
412 .iter()
413 .map(|l| match l {
414 &Literal::F16(v) => Ok(v),
415 _ => Err(ConstantEvaluatorError::InvalidMathArg),
416 })
417 .collect::<Result<_, _>>()?,
418 ),
419 })
420 }
421
422 #[allow(dead_code)]
423 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425 match *self {
426 LiteralVector::F64(ref v) => v.iter().map(|e| Literal::F64(*e)).collect(),
427 LiteralVector::F32(ref v) => v.iter().map(|e| Literal::F32(*e)).collect(),
428 LiteralVector::F16(ref v) => v.iter().map(|e| Literal::F16(*e)).collect(),
429 LiteralVector::U32(ref v) => v.iter().map(|e| Literal::U32(*e)).collect(),
430 LiteralVector::I32(ref v) => v.iter().map(|e| Literal::I32(*e)).collect(),
431 LiteralVector::U64(ref v) => v.iter().map(|e| Literal::U64(*e)).collect(),
432 LiteralVector::I64(ref v) => v.iter().map(|e| Literal::I64(*e)).collect(),
433 LiteralVector::Bool(ref v) => v.iter().map(|e| Literal::Bool(*e)).collect(),
434 LiteralVector::AbstractInt(ref v) => {
435 v.iter().map(|e| Literal::AbstractInt(*e)).collect()
436 }
437 LiteralVector::AbstractFloat(ref v) => {
438 v.iter().map(|e| Literal::AbstractFloat(*e)).collect()
439 }
440 }
441 }
442
443 #[allow(dead_code)]
444 fn register_as_evaluated_expr(
446 &self,
447 eval: &mut ConstantEvaluator<'_>,
448 span: Span,
449 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450 let lit_vec = self.to_literal_vec();
451 assert!(!lit_vec.is_empty());
452 let expr = if lit_vec.len() == 1 {
453 Expression::Literal(lit_vec[0])
454 } else {
455 Expression::Compose {
456 ty: eval.types.insert(
457 Type {
458 name: None,
459 inner: TypeInner::Vector {
460 size: match lit_vec.len() {
461 2 => crate::VectorSize::Bi,
462 3 => crate::VectorSize::Tri,
463 4 => crate::VectorSize::Quad,
464 _ => unreachable!(),
465 },
466 scalar: lit_vec[0].scalar(),
467 },
468 },
469 Span::UNDEFINED,
470 ),
471 components: lit_vec
472 .iter()
473 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474 .collect::<Result<_, _>>()?,
475 }
476 };
477 eval.register_evaluated_expr(expr, span)
478 }
479}
480
/// Pattern-matches one or more `LiteralVector`s by variant and wraps the
/// result of the chosen arm in `$out`.
///
/// Invocation shape:
///
/// ```ignore
/// match_literal_vector!(match (a, b) => Out {
///     Float => |x, y| { body },
///     U32 => |x, y| -> Bool { body },
/// })
/// ```
///
/// - `Integer` expands to `U32, I32, U64, I64, AbstractInt` and `Float` to
///   `F16, F32, F64, AbstractFloat`. Each expansion repeats the arm's closure.
/// - Each arm expands to `(LiteralVector::Ty(ref x), ...) => Ok(Out::Ty(body))`,
///   or `Ok(Out::Ret(body))` when `-> Ret` is given. Every operand must
///   therefore have the same variant `Ty`.
/// - A fallback arm returns `Err(ConstantEvaluatorError::InvalidMathArg)`.
///
/// The `@inner*` rules are internal. They walk the list of type names while
/// accumulating finished `{ Ty ; vars ; ret ; body }` arms on the left of `<>`.
/// Rule order matters, because the generic `@inner $lit_vec:expr` rule is tried
/// before the `Integer`/`Float` rules.
macro_rules! match_literal_vector {
    // Public entry point.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Starts with an empty accumulator (`[]`) left of `<>`.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Moves the next type name to the front so the rules below can dispatch on it.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer`: replaced by five concrete names, with the arm duplicated five times.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float`: replaced by four concrete names, with the arm duplicated four times.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // A concrete type name with more names to follow: finish this arm and continue.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // The last concrete type name: finish the final arm and emit the `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emits the `match`. `unused_parens` is allowed for the single-operand
    // case, where the tuple pattern degenerates to `(pattern)`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // With no `-> Ret`, the result is wrapped in the input's variant name.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // With `-> Ret`, the result is wrapped in `$out::Ret`.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
656
/// The source language being processed, plus the context restrictions that
/// come with it.
#[derive(Debug)]
enum Behavior<'a> {
    Wgsl(WgslRestrictions<'a>),
    Glsl(GlslRestrictions<'a>),
}
662
663impl Behavior<'_> {
664 const fn has_runtime_restrictions(&self) -> bool {
666 matches!(
667 self,
668 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
669 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
670 )
671 }
672}
673
/// Evaluates expressions at compile time where the current context allows it,
/// appending results to `expressions`.
///
/// What may be evaluated, and which arena receives the results, depends on
/// `behavior`. That is chosen by the `for_wgsl_module`, `for_glsl_module`,
/// `for_wgsl_function` and `for_glsl_function` constructors.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Source language and evaluation context. Function contexts also carry
    /// their `FunctionLocalData`.
    behavior: Behavior<'a>,

    /// The module's type arena. New types may be inserted here, e.g. vector
    /// types created by `LiteralVector::register_as_evaluated_expr`.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants. `Expression::Constant` is resolved to the
    /// constant's `init` expression.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena evaluated expressions are appended to: the module's
    /// `global_expressions` in module contexts, or the function's own
    /// expression arena in function contexts.
    expressions: &'a mut Arena<Expression>,

    /// Records the [`ExpressionKind`] of every handle in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Type layout information. NOTE(review): not used within this excerpt;
    /// presumably needed for size/offset queries elsewhere in the file.
    layouter: &'a mut crate::proc::Layouter,
}
718
/// The evaluation context for WGSL input.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// A const-expression context. `None` at module scope (`for_wgsl_module`);
    /// `Some` inside a function body (`for_wgsl_function` with `is_const`).
    Const(Option<FunctionLocalData<'a>>),
    /// A module-scope context where override-expressions are allowed.
    Override,
    /// A non-const function context. Expressions that cannot be evaluated
    /// may be appended as runtime expressions.
    Runtime(FunctionLocalData<'a>),
}
731
/// The evaluation context for GLSL input.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope (`for_glsl_module`): only constant expressions.
    Const,
    /// Inside a function body (`for_glsl_function`).
    Runtime(FunctionLocalData<'a>),
}
741
/// State needed when the evaluator works on a function's expression arena
/// rather than the module's.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are copied
    /// out of it into the function arena (see `check_and_get`).
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used to emit runtime expressions into `block`;
    /// the use is outside this excerpt.
    emitter: &'a mut super::Emitter,
    /// The function block being built. Usage is outside this excerpt.
    block: &'a mut crate::Block,
}
749
/// How "constant" an expression is.
///
/// Variants are ordered `Const < Override < Runtime` (derived `Ord`). The kind
/// of a compound expression is the `max` of its operands' kinds.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    Const,
    Override,
    Runtime,
}
756
/// Maps each expression handle in an arena to its [`ExpressionKind`].
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
761
762impl ExpressionKindTracker {
763 pub const fn new() -> Self {
764 Self {
765 inner: HandleVec::new(),
766 }
767 }
768
769 pub fn force_non_const(&mut self, value: Handle<Expression>) {
771 self.inner[value] = ExpressionKind::Runtime;
772 }
773
774 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
775 self.inner.insert(value, expr_type);
776 }
777
778 pub fn is_const(&self, h: Handle<Expression>) -> bool {
779 matches!(self.type_of(h), ExpressionKind::Const)
780 }
781
782 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
783 matches!(
784 self.type_of(h),
785 ExpressionKind::Const | ExpressionKind::Override
786 )
787 }
788
789 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
790 self.inner[value]
791 }
792
793 pub fn from_arena(arena: &Arena<Expression>) -> Self {
794 let mut tracker = Self {
795 inner: HandleVec::with_capacity(arena.len()),
796 };
797 for (handle, expr) in arena.iter() {
798 tracker
799 .inner
800 .insert(handle, tracker.type_of_with_expr(expr));
801 }
802 tracker
803 }
804
805 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
806 match *expr {
807 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
808 ExpressionKind::Const
809 }
810 Expression::Override(_) => ExpressionKind::Override,
811 Expression::Compose { ref components, .. } => {
812 let mut expr_type = ExpressionKind::Const;
813 for component in components {
814 expr_type = expr_type.max(self.type_of(*component))
815 }
816 expr_type
817 }
818 Expression::Splat { value, .. } => self.type_of(value),
819 Expression::AccessIndex { base, .. } => self.type_of(base),
820 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
821 Expression::Swizzle { vector, .. } => self.type_of(vector),
822 Expression::Unary { expr, .. } => self.type_of(expr),
823 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
824 Expression::Math {
825 arg,
826 arg1,
827 arg2,
828 arg3,
829 ..
830 } => self
831 .type_of(arg)
832 .max(
833 arg1.map(|arg| self.type_of(arg))
834 .unwrap_or(ExpressionKind::Const),
835 )
836 .max(
837 arg2.map(|arg| self.type_of(arg))
838 .unwrap_or(ExpressionKind::Const),
839 )
840 .max(
841 arg3.map(|arg| self.type_of(arg))
842 .unwrap_or(ExpressionKind::Const),
843 ),
844 Expression::As { expr, .. } => self.type_of(expr),
845 Expression::Select {
846 condition,
847 accept,
848 reject,
849 } => self
850 .type_of(condition)
851 .max(self.type_of(accept))
852 .max(self.type_of(reject)),
853 Expression::Relational { argument, .. } => self.type_of(argument),
854 Expression::ArrayLength(expr) => self.type_of(expr),
855 _ => ExpressionKind::Runtime,
856 }
857 }
858}
859
860#[derive(Clone, Debug, thiserror::Error)]
861#[cfg_attr(test, derive(PartialEq))]
862pub enum ConstantEvaluatorError {
863 #[error("Constants cannot access function arguments")]
864 FunctionArg,
865 #[error("Constants cannot access global variables")]
866 GlobalVariable,
867 #[error("Constants cannot access local variables")]
868 LocalVariable,
869 #[error("Cannot get the array length of a non array type")]
870 InvalidArrayLengthArg,
871 #[error("Constants cannot get the array length of a dynamically sized array")]
872 ArrayLengthDynamic,
873 #[error("Cannot call arrayLength on array sized by override-expression")]
874 ArrayLengthOverridden,
875 #[error("Constants cannot call functions")]
876 Call,
877 #[error("Constants don't support workGroupUniformLoad")]
878 WorkGroupUniformLoadResult,
879 #[error("Constants don't support atomic functions")]
880 Atomic,
881 #[error("Constants don't support derivative functions")]
882 Derivative,
883 #[error("Constants don't support load expressions")]
884 Load,
885 #[error("Constants don't support image expressions")]
886 ImageExpression,
887 #[error("Constants don't support ray query expressions")]
888 RayQueryExpression,
889 #[error("Constants don't support subgroup expressions")]
890 SubgroupExpression,
891 #[error("Cannot access the type")]
892 InvalidAccessBase,
893 #[error("Cannot access at the index")]
894 InvalidAccessIndex,
895 #[error("Cannot access with index of type")]
896 InvalidAccessIndexTy,
897 #[error("Constants don't support array length expressions")]
898 ArrayLength,
899 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
900 InvalidCastArg { from: String, to: String },
901 #[error("Cannot apply the unary op to the argument")]
902 InvalidUnaryOpArg,
903 #[error("Cannot apply the binary op to the arguments")]
904 InvalidBinaryOpArgs,
905 #[error("Cannot apply math function to type")]
906 InvalidMathArg,
907 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
908 InvalidMathArgCount(crate::MathFunction, usize, usize),
909 #[error("Cannot apply relational function to type")]
910 InvalidRelationalArg(RelationalFunction),
911 #[error("value of `low` is greater than `high` for clamp built-in function")]
912 InvalidClamp,
913 #[error("Constructor expects {expected} components, found {actual}")]
914 InvalidVectorComposeLength { expected: usize, actual: usize },
915 #[error("Constructor must only contain vector or scalar arguments")]
916 InvalidVectorComposeComponent,
917 #[error("Splat is defined only on scalar values")]
918 SplatScalarOnly,
919 #[error("Can only swizzle vector constants")]
920 SwizzleVectorOnly,
921 #[error("swizzle component not present in source expression")]
922 SwizzleOutOfBounds,
923 #[error("Type is not constructible")]
924 TypeNotConstructible,
925 #[error("Subexpression(s) are not constant")]
926 SubexpressionsAreNotConstant,
927 #[error("Not implemented as constant expression: {0}")]
928 NotImplemented(String),
929 #[error("{0} operation overflowed")]
930 Overflow(String),
931 #[error(
932 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
933 )]
934 AutomaticConversionLossy {
935 value: String,
936 to_type: &'static str,
937 },
938 #[error("Division by zero")]
939 DivisionByZero,
940 #[error("Remainder by zero")]
941 RemainderByZero,
942 #[error("RHS of shift operation is greater than or equal to 32")]
943 ShiftedMoreThan32Bits,
944 #[error(transparent)]
945 Literal(#[from] crate::valid::LiteralError),
946 #[error("Can't use pipeline-overridable constants in const-expressions")]
947 Override,
948 #[error("Unexpected runtime-expression")]
949 RuntimeExpr,
950 #[error("Unexpected override-expression")]
951 OverrideExpr,
952 #[error("Expected boolean expression for condition argument of `select`, got something else")]
953 SelectScalarConditionNotABool,
954 #[error(
955 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
956 reject,
957 accept
958 )]
959 SelectVecRejectAcceptSizeMismatch {
960 reject: crate::VectorSize,
961 accept: crate::VectorSize,
962 },
963 #[error("Expected boolean vector for condition arg., got something else")]
964 SelectConditionNotAVecBool,
965 #[error(
966 "Expected same number of vector components between condition, accept, and reject args., got something else",
967 )]
968 SelectConditionVecSizeMismatch,
969 #[error(
970 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
971 )]
972 SelectAcceptRejectTypeMismatch,
973 #[error("Cooperative operations can't be constant")]
974 CooperativeOperation,
975}
976
977impl<'a> ConstantEvaluator<'a> {
978 pub const fn for_wgsl_module(
983 module: &'a mut crate::Module,
984 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
985 layouter: &'a mut crate::proc::Layouter,
986 in_override_ctx: bool,
987 ) -> Self {
988 Self::for_module(
989 Behavior::Wgsl(if in_override_ctx {
990 WgslRestrictions::Override
991 } else {
992 WgslRestrictions::Const(None)
993 }),
994 module,
995 global_expression_kind_tracker,
996 layouter,
997 )
998 }
999
1000 pub const fn for_glsl_module(
1005 module: &'a mut crate::Module,
1006 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1007 layouter: &'a mut crate::proc::Layouter,
1008 ) -> Self {
1009 Self::for_module(
1010 Behavior::Glsl(GlslRestrictions::Const),
1011 module,
1012 global_expression_kind_tracker,
1013 layouter,
1014 )
1015 }
1016
1017 const fn for_module(
1018 behavior: Behavior<'a>,
1019 module: &'a mut crate::Module,
1020 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
1021 layouter: &'a mut crate::proc::Layouter,
1022 ) -> Self {
1023 Self {
1024 behavior,
1025 types: &mut module.types,
1026 constants: &module.constants,
1027 overrides: &module.overrides,
1028 expressions: &mut module.global_expressions,
1029 expression_kind_tracker: global_expression_kind_tracker,
1030 layouter,
1031 }
1032 }
1033
1034 pub const fn for_wgsl_function(
1039 module: &'a mut crate::Module,
1040 expressions: &'a mut Arena<Expression>,
1041 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1042 layouter: &'a mut crate::proc::Layouter,
1043 emitter: &'a mut super::Emitter,
1044 block: &'a mut crate::Block,
1045 is_const: bool,
1046 ) -> Self {
1047 let local_data = FunctionLocalData {
1048 global_expressions: &module.global_expressions,
1049 emitter,
1050 block,
1051 };
1052 Self {
1053 behavior: Behavior::Wgsl(if is_const {
1054 WgslRestrictions::Const(Some(local_data))
1055 } else {
1056 WgslRestrictions::Runtime(local_data)
1057 }),
1058 types: &mut module.types,
1059 constants: &module.constants,
1060 overrides: &module.overrides,
1061 expressions,
1062 expression_kind_tracker: local_expression_kind_tracker,
1063 layouter,
1064 }
1065 }
1066
1067 pub const fn for_glsl_function(
1072 module: &'a mut crate::Module,
1073 expressions: &'a mut Arena<Expression>,
1074 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1075 layouter: &'a mut crate::proc::Layouter,
1076 emitter: &'a mut super::Emitter,
1077 block: &'a mut crate::Block,
1078 ) -> Self {
1079 Self {
1080 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1081 global_expressions: &module.global_expressions,
1082 emitter,
1083 block,
1084 })),
1085 types: &mut module.types,
1086 constants: &module.constants,
1087 overrides: &module.overrides,
1088 expressions,
1089 expression_kind_tracker: local_expression_kind_tracker,
1090 layouter,
1091 }
1092 }
1093
1094 pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1095 crate::proc::GlobalCtx {
1096 types: self.types,
1097 constants: self.constants,
1098 overrides: self.overrides,
1099 global_expressions: match self.function_local_data() {
1100 Some(data) => data.global_expressions,
1101 None => self.expressions,
1102 },
1103 }
1104 }
1105
1106 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1107 if !self.expression_kind_tracker.is_const(expr) {
1108 log::debug!("check: SubexpressionsAreNotConstant");
1109 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1110 }
1111 Ok(())
1112 }
1113
1114 fn check_and_get(
1115 &mut self,
1116 expr: Handle<Expression>,
1117 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1118 match self.expressions[expr] {
1119 Expression::Constant(c) => {
1120 if let Some(function_local_data) = self.function_local_data() {
1123 self.copy_from(
1125 self.constants[c].init,
1126 function_local_data.global_expressions,
1127 )
1128 } else {
1129 Ok(self.constants[c].init)
1131 }
1132 }
1133 _ => {
1134 self.check(expr)?;
1135 Ok(expr)
1136 }
1137 }
1138 }
1139
1140 pub fn try_eval_and_append(
1164 &mut self,
1165 expr: Expression,
1166 span: Span,
1167 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1168 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1169 ExpressionKind::Const => {
1170 let eval_result = self.try_eval_and_append_impl(&expr, span);
1171 if self.behavior.has_runtime_restrictions()
1176 && matches!(
1177 eval_result,
1178 Err(ConstantEvaluatorError::NotImplemented(_)
1179 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1180 )
1181 {
1182 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1183 } else {
1184 eval_result
1185 }
1186 }
1187 ExpressionKind::Override => match self.behavior {
1188 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1189 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1190 }
1191 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1192 Err(ConstantEvaluatorError::OverrideExpr)
1193 }
1194
1195 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1197 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1198 }
1199 Behavior::Glsl(GlslRestrictions::Const) => {
1200 Err(ConstantEvaluatorError::OverrideExpr)
1201 }
1202 },
1203 ExpressionKind::Runtime => {
1204 if self.behavior.has_runtime_restrictions() {
1205 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1206 } else {
1207 Err(ConstantEvaluatorError::RuntimeExpr)
1208 }
1209 }
1210 }
1211 }
1212
1213 const fn is_global_arena(&self) -> bool {
1215 matches!(
1216 self.behavior,
1217 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1218 | Behavior::Glsl(GlslRestrictions::Const)
1219 )
1220 }
1221
1222 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1223 match self.behavior {
1224 Behavior::Wgsl(
1225 WgslRestrictions::Runtime(ref function_local_data)
1226 | WgslRestrictions::Const(Some(ref function_local_data)),
1227 )
1228 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1229 Some(function_local_data)
1230 }
1231 _ => None,
1232 }
1233 }
1234
/// Constant-evaluates `expr` and appends the result, returning the handle of
/// the evaluated expression.
///
/// For compound expressions, every operand handle is first passed through
/// `check_and_get` and the work is then delegated to the matching helper
/// (`access`, `swizzle`, `unary_op`, `binary_op`, `math`, `cast`, `select`,
/// `relational`, `array_length`). NOTE(review): `check_and_get` presumably
/// verifies that each operand was already evaluated and returns its handle;
/// confirm against its definition.
///
/// Expression kinds that cannot be constant-evaluated return a kind-specific
/// `ConstantEvaluatorError`.
1235 fn try_eval_and_append_impl(
1236 &mut self,
1237 expr: &Expression,
1238 span: Span,
1239 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1240 log::trace!("try_eval_and_append: {expr:?}");
1241 match *expr {
// In a global arena, look through the constant and hand back its
// initializer's handle instead of appending a new expression.
1242 Expression::Constant(c) if self.is_global_arena() => {
1243 Ok(self.constants[c].init)
1246 }
// Override expressions are always rejected here.
1247 Expression::Override(_) => Err(ConstantEvaluatorError::Override),
// Literals, zero values and (outside a global arena) constant references
// are registered unchanged.
1248 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1249 self.register_evaluated_expr(expr.clone(), span)
1250 }
1251 Expression::Compose { ty, ref components } => {
1252 let components = components
1253 .iter()
1254 .map(|component| self.check_and_get(*component))
1255 .collect::<Result<Vec<_>, _>>()?;
1256 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1257 }
1258 Expression::Splat { size, value } => {
1259 let value = self.check_and_get(value)?;
1260 self.register_evaluated_expr(Expression::Splat { size, value }, span)
1261 }
1262 Expression::AccessIndex { base, index } => {
1263 let base = self.check_and_get(base)?;
1264
1265 self.access(base, index as usize, span)
1266 }
1267 Expression::Access { base, index } => {
1268 let base = self.check_and_get(base)?;
1269 let index = self.check_and_get(index)?;
1270
// The dynamic index must itself be a constant `u32`. Anything else is
// reported as `InvalidAccessIndexTy`.
1271 let index_val: u32 = self
1272 .to_ctx()
1273 .get_const_val_from(index, self.expressions)
1274 .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1275 self.access(base, index_val as usize, span)
1276 }
1277 Expression::Swizzle {
1278 size,
1279 vector,
1280 pattern,
1281 } => {
1282 let vector = self.check_and_get(vector)?;
1283
1284 self.swizzle(size, span, vector, pattern)
1285 }
1286 Expression::Unary { expr, op } => {
1287 let expr = self.check_and_get(expr)?;
1288
1289 self.unary_op(op, expr, span)
1290 }
1291 Expression::Binary { left, right, op } => {
1292 let left = self.check_and_get(left)?;
1293 let right = self.check_and_get(right)?;
1294
1295 self.binary_op(op, left, right, span)
1296 }
1297 Expression::Math {
1298 fun,
1299 arg,
1300 arg1,
1301 arg2,
1302 arg3,
1303 } => {
// Optional arguments are resolved only when present.
1304 let arg = self.check_and_get(arg)?;
1305 let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1306 let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1307 let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1308
1309 self.math(arg, arg1, arg2, arg3, fun, span)
1310 }
1311 Expression::As {
1312 convert,
1313 expr,
1314 kind,
1315 } => {
1316 let expr = self.check_and_get(expr)?;
1317
// `convert: Some(width)` is a value conversion to the given scalar.
// `None` is a bitcast, which is not constant-evaluated yet.
1318 match convert {
1319 Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1320 None => Err(ConstantEvaluatorError::NotImplemented(
1321 "bitcast built-in function".into(),
1322 )),
1323 }
1324 }
1325 Expression::Select {
1326 reject,
1327 accept,
1328 condition,
1329 } => {
// The closure borrows `self` mutably, so the operands are resolved one at
// a time: reject, accept, then condition.
1330 let mut arg = |expr| self.check_and_get(expr);
1331
1332 let reject = arg(reject)?;
1333 let accept = arg(accept)?;
1334 let condition = arg(condition)?;
1335
1336 self.select(reject, accept, condition, span)
1337 }
1338 Expression::Relational { fun, argument } => {
1339 let argument = self.check_and_get(argument)?;
1340 self.relational(fun, argument, span)
1341 }
// WGSL behaviors reject `arrayLength` outright. GLSL behaviors evaluate it.
1342 Expression::ArrayLength(expr) => match self.behavior {
1343 Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1344 Behavior::Glsl(_) => {
1345 let expr = self.check_and_get(expr)?;
1346 self.array_length(expr, span)
1347 }
1348 },
// Every remaining expression kind is rejected with its own error variant.
1349 Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1350 Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1351 Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1352 Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1353 Expression::WorkGroupUniformLoadResult { .. } => {
1354 Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1355 }
1356 Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1357 Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1358 Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1359 Expression::ImageSample { .. }
1360 | Expression::ImageLoad { .. }
1361 | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1362 Expression::RayQueryProceedResult
1363 | Expression::RayQueryGetIntersection { .. }
1364 | Expression::RayQueryVertexPositions { .. } => {
1365 Err(ConstantEvaluatorError::RayQueryExpression)
1366 }
1367 Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1368 Expression::SubgroupOperationResult { .. } => {
1369 Err(ConstantEvaluatorError::SubgroupExpression)
1370 }
1371 Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1372 Err(ConstantEvaluatorError::CooperativeOperation)
1373 }
1374 }
1375 }
1376
1377 fn splat(
1390 &mut self,
1391 value: Handle<Expression>,
1392 size: crate::VectorSize,
1393 span: Span,
1394 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1395 match self.expressions[value] {
1396 Expression::Literal(literal) => {
1397 let scalar = literal.scalar();
1398 let ty = self.types.insert(
1399 Type {
1400 name: None,
1401 inner: TypeInner::Vector { size, scalar },
1402 },
1403 span,
1404 );
1405 let expr = Expression::Compose {
1406 ty,
1407 components: vec![value; size as usize],
1408 };
1409 self.register_evaluated_expr(expr, span)
1410 }
1411 Expression::ZeroValue(ty) => {
1412 let inner = match self.types[ty].inner {
1413 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1414 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1415 };
1416 let res_ty = self.types.insert(Type { name: None, inner }, span);
1417 let expr = Expression::ZeroValue(res_ty);
1418 self.register_evaluated_expr(expr, span)
1419 }
1420 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1421 }
1422 }
1423
1424 fn swizzle(
1425 &mut self,
1426 size: crate::VectorSize,
1427 span: Span,
1428 src_constant: Handle<Expression>,
1429 pattern: [crate::SwizzleComponent; 4],
1430 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1431 let mut get_dst_ty = |ty| match self.types[ty].inner {
1432 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1433 Type {
1434 name: None,
1435 inner: TypeInner::Vector { size, scalar },
1436 },
1437 span,
1438 )),
1439 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1440 };
1441
1442 match self.expressions[src_constant] {
1443 Expression::ZeroValue(ty) => {
1444 let dst_ty = get_dst_ty(ty)?;
1445 let expr = Expression::ZeroValue(dst_ty);
1446 self.register_evaluated_expr(expr, span)
1447 }
1448 Expression::Splat { value, .. } => {
1449 let expr = Expression::Splat { size, value };
1450 self.register_evaluated_expr(expr, span)
1451 }
1452 Expression::Compose { ty, ref components } => {
1453 let dst_ty = get_dst_ty(ty)?;
1454
1455 let mut flattened = [src_constant; 4]; let len =
1457 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1458 .zip(flattened.iter_mut())
1459 .map(|(component, elt)| *elt = component)
1460 .count();
1461 let flattened = &flattened[..len];
1462
1463 let swizzled_components = pattern[..size as usize]
1464 .iter()
1465 .map(|&sc| {
1466 let sc = sc as usize;
1467 if let Some(elt) = flattened.get(sc) {
1468 Ok(*elt)
1469 } else {
1470 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1471 }
1472 })
1473 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1474 let expr = Expression::Compose {
1475 ty: dst_ty,
1476 components: swizzled_components,
1477 };
1478 self.register_evaluated_expr(expr, span)
1479 }
1480 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1481 }
1482 }
1483
/// Constant-evaluates the math built-in `fun` applied to `arg` and the
/// optional `arg1`..`arg3`, returning the handle of the evaluated result.
///
/// Most built-ins are evaluated component-wise through the file's
/// `component_wise_*` helpers. `dot`, `length`, `distance`, `normalize`,
/// `cross` and the packed dot products work on whole vectors.
///
/// # Errors
///
/// - `InvalidMathArgCount` if the number of supplied arguments differs from
///   `fun.argument_count()`.
/// - `NotImplemented` for the built-ins listed in the final arm.
/// - Errors produced while evaluating, for example:
///   - `InvalidClamp` when `low > high`;
///   - `Overflow` from an integer `dot`;
///   - `InvalidMathArg` for mismatched vector lengths.
1484 fn math(
1485 &mut self,
1486 arg: Handle<Expression>,
1487 arg1: Option<Handle<Expression>>,
1488 arg2: Option<Handle<Expression>>,
1489 arg3: Option<Handle<Expression>>,
1490 fun: crate::MathFunction,
1491 span: Span,
1492 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1493 let expected = fun.argument_count();
1494 let given = Some(arg)
1495 .into_iter()
1496 .chain(arg1)
1497 .chain(arg2)
1498 .chain(arg3)
1499 .count();
1500 if expected != given {
1501 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1502 fun, expected, given,
1503 ));
1504 }
// The `arg1.unwrap()`/`arg2.unwrap()` calls below rely on this count check.
// NOTE(review): they also assume the optional arguments are filled in order
// (no `None` before a `Some`); confirm with the callers.
1505
1506 match fun {
// Signed integers use `wrapping_abs`, so the minimum value maps to itself
// instead of overflowing. Unsigned values are already non-negative.
1508 crate::MathFunction::Abs => {
1510 component_wise_scalar(self, span, [arg], |args| match args {
1511 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1512 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1513 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1514 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1515 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1516 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1518 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1519 })
1520 }
1521 crate::MathFunction::Min => {
1522 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1523 Ok([e1.min(e2)])
1524 })
1525 }
1526 crate::MathFunction::Max => {
1527 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1528 Ok([e1.max(e2)])
1529 })
1530 }
// `low > high` is reported as an error rather than passed to `clamp`.
1531 crate::MathFunction::Clamp => {
1532 component_wise_scalar!(
1533 self,
1534 span,
1535 [arg, arg1.unwrap(), arg2.unwrap()],
1536 |e, low, high| {
1537 if low > high {
1538 Err(ConstantEvaluatorError::InvalidClamp)
1539 } else {
1540 Ok([e.clamp(low, high)])
1541 }
1542 }
1543 )
1544 }
1545 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1546 Float::F16([e]) => Ok(Float::F16(
1547 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1548 )),
1549 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1550 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1551 }),
1552
// Trigonometric, hyperbolic and angle-conversion built-ins map directly
// onto the corresponding float methods.
1553 crate::MathFunction::Cos => {
1555 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1556 }
1557 crate::MathFunction::Cosh => {
1558 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1559 }
1560 crate::MathFunction::Sin => {
1561 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1562 }
1563 crate::MathFunction::Sinh => {
1564 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1565 }
1566 crate::MathFunction::Tan => {
1567 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1568 }
1569 crate::MathFunction::Tanh => {
1570 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1571 }
1572 crate::MathFunction::Acos => {
1573 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1574 }
1575 crate::MathFunction::Asin => {
1576 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1577 }
1578 crate::MathFunction::Atan => {
1579 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1580 }
// `arg` is `y` and `arg1` is `x`, matching `atan2(y, x)`.
1581 crate::MathFunction::Atan2 => {
1582 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1583 Ok([y.atan2(x)])
1584 })
1585 }
1586 crate::MathFunction::Asinh => {
1587 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1588 }
1589 crate::MathFunction::Acosh => {
1590 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1591 }
1592 crate::MathFunction::Atanh => {
1593 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1594 }
1595 crate::MathFunction::Radians => {
1596 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1597 }
1598 crate::MathFunction::Degrees => {
1599 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1600 }
1601
1602 crate::MathFunction::Ceil => {
1604 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1605 }
1606 crate::MathFunction::Floor => {
1607 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1608 }
// `round` rounds halfway cases to the nearest even integer: `libm::rint` /
// `libm::rintf` for f64/f32, and a local polyfill for f16.
1609 crate::MathFunction::Round => {
1610 component_wise_float(self, span, [arg], |e| match e {
1611 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1612 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1613 Float::F16([e]) => {
// `i` is `x` truncated toward zero and `f` the distance to it. On an exact
// `.5`, step away from zero when `i` is odd and toward zero when it is even,
// so the result is always even.
1614 fn round_ties_even(x: f64) -> f64 {
1622 let i = x as i64;
1623 let f = (x - i as f64).abs();
1624 if f == 0.5 {
1625 if i & 1 == 1 {
1626 (x.abs() + 0.5).copysign(x)
1628 } else {
1629 (x.abs() - 0.5).copysign(x)
1630 }
1631 } else {
1632 x.round()
1633 }
1634 }
1635
1636 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1637 }
1638 })
1639 }
// `fract` is computed as `e - floor(e)`, not with Rust's `fract`. Rust's
// `fract` is `e - trunc(e)`, which differs for negative inputs.
1640 crate::MathFunction::Fract => {
1641 component_wise_float!(self, span, [arg], |e| {
1642 Ok([e - e.floor()])
1645 })
1646 }
1647 crate::MathFunction::Trunc => {
1648 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1649 }
1650
1651 crate::MathFunction::Exp => {
1653 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1654 }
1655 crate::MathFunction::Exp2 => {
1656 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1657 }
1658 crate::MathFunction::Log => {
1659 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1660 }
1661 crate::MathFunction::Log2 => {
1662 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1663 }
1664 crate::MathFunction::Pow => {
1665 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1666 Ok([e1.powf(e2)])
1667 })
1668 }
1669
// Zero maps to zero explicitly. Every other value maps to its `signum`.
1670 crate::MathFunction::Sign => {
1672 component_wise_signed!(self, span, [arg], |e| {
1673 Ok([if e.is_zero() {
1674 Zero::zero()
1675 } else {
1676 e.signum()
1677 }])
1678 })
1679 }
1680 crate::MathFunction::Fma => {
1681 component_wise_float!(
1682 self,
1683 span,
1684 [arg, arg1.unwrap(), arg2.unwrap()],
1685 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1686 )
1687 }
// `step(edge, x)` is 1 when `edge <= x`, otherwise 0.
1688 crate::MathFunction::Step => {
1689 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1690 Float::Abstract([edge, x]) => {
1691 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1692 }
1693 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1694 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1695 f16::one()
1696 } else {
1697 f16::zero()
1698 }])),
1699 })
1700 }
1701 crate::MathFunction::Sqrt => {
1702 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1703 }
// The f16 inverse square root is computed in f32 and converted back.
1704 crate::MathFunction::InverseSqrt => {
1705 component_wise_float(self, span, [arg], |e| match e {
1706 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1707 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1708 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1709 })
1710 }
1711
// Bit-counting built-ins operate on concrete integers only.
1712 crate::MathFunction::CountTrailingZeros => {
1714 component_wise_concrete_int!(self, span, [arg], |e| {
1715 #[allow(clippy::useless_conversion)]
1716 Ok([e
1717 .trailing_zeros()
1718 .try_into()
1719 .expect("bit count overflowed 32 bits, somehow!?")])
1720 })
1721 }
1722 crate::MathFunction::CountLeadingZeros => {
1723 component_wise_concrete_int!(self, span, [arg], |e| {
1724 #[allow(clippy::useless_conversion)]
1725 Ok([e
1726 .leading_zeros()
1727 .try_into()
1728 .expect("bit count overflowed 32 bits, somehow!?")])
1729 })
1730 }
1731 crate::MathFunction::CountOneBits => {
1732 component_wise_concrete_int!(self, span, [arg], |e| {
1733 #[allow(clippy::useless_conversion)]
1734 Ok([e
1735 .count_ones()
1736 .try_into()
1737 .expect("bit count overflowed 32 bits, somehow!?")])
1738 })
1739 }
1740 crate::MathFunction::ReverseBits => {
1741 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1742 }
1743 crate::MathFunction::FirstTrailingBit => {
1744 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1745 }
1746 crate::MathFunction::FirstLeadingBit => {
1747 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1748 }
1749
// The final argument selects signed (i8 lanes) vs. unsigned (u8 lanes).
1750 crate::MathFunction::Dot4I8Packed => {
1752 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1753 }
1754 crate::MathFunction::Dot4U8Packed => {
1755 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1756 }
1757 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
// `dot` requires two vectors (no bare scalars) of equal length. Integer dot
// products report overflow as an error instead of wrapping.
1758 crate::MathFunction::Dot => {
1759 let e1 = self.extract_vec(arg, false)?;
1761 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1762 if e1.len() != e2.len() {
1763 return Err(ConstantEvaluatorError::InvalidMathArg);
1764 }
1765
1766 fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1767 where
1768 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1769 {
1770 a.iter()
1771 .zip(b.iter())
1772 .map(|(&aa, bb)| aa.checked_mul(bb))
1773 .try_fold(P::zero(), |acc, x| {
1774 if let Some(x) = x {
1775 acc.checked_add(&x)
1776 } else {
1777 None
1778 }
1779 })
1780 .ok_or(ConstantEvaluatorError::Overflow(
1781 "in dot built-in".to_string(),
1782 ))
1783 }
1784
1785 let result = match_literal_vector!(match (e1, e2) => Literal {
1786 Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1787 Integer => |e1, e2| { int_dot(e1, e2)? },
1788 })?;
1789 self.register_evaluated_expr(Expression::Literal(result), span)
1790 }
// `length` also accepts a bare scalar, whose length is its absolute value.
1791 crate::MathFunction::Length => {
1792 let e1 = self.extract_vec(arg, true)?;
1794
1795 fn float_length<F>(e: &[F]) -> F
1796 where
1797 F: core::ops::Mul<F>,
1798 F: num_traits::Float + iter::Sum,
1799 {
1800 if e.len() == 1 {
1801 e[0].abs()
1803 } else {
1804 e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1805 }
1806 }
1807
1808 let result = match_literal_vector!(match e1 => Literal {
1809 Float => |e1| { float_length(e1) },
1810 })?;
1811 self.register_evaluated_expr(Expression::Literal(result), span)
1812 }
// `distance` likewise accepts scalars (absolute difference) or two vectors
// of equal length.
1813 crate::MathFunction::Distance => {
1814 let e1 = self.extract_vec(arg, true)?;
1816 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1817 if e1.len() != e2.len() {
1818 return Err(ConstantEvaluatorError::InvalidMathArg);
1819 }
1820
1821 fn float_distance<F>(a: &[F], b: &[F]) -> F
1822 where
1823 F: core::ops::Mul<F>,
1824 F: num_traits::Float + iter::Sum + core::ops::Sub,
1825 {
1826 if a.len() == 1 {
1827 (a[0] - b[0]).abs()
1829 } else {
1830 a.iter()
1831 .zip(b.iter())
1832 .map(|(&aa, &bb)| aa - bb)
1833 .map(|ei| ei * ei)
1834 .sum::<F>()
1835 .sqrt()
1836 }
1837 }
1838 let result = match_literal_vector!(match (e1, e2) => Literal {
1839 Float => |e1, e2| { float_distance(e1, e2) },
1840 })?;
1841 self.register_evaluated_expr(Expression::Literal(result), span)
1842 }
// `normalize` divides each component by the Euclidean length. A zero vector
// is not special-cased here.
1843 crate::MathFunction::Normalize => {
1844 let e1 = self.extract_vec(arg, true)?;
1846
1847 fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1848 where
1849 F: core::ops::Mul<F>,
1850 F: num_traits::Float + iter::Sum,
1851 {
1852 let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1853 e.iter().map(|&ei| ei / len).collect()
1854 }
1855
1856 let result = match_literal_vector!(match e1 => LiteralVector {
1857 Float => |e1| { float_normalize(e1) },
1858 })?;
1859 result.register_as_evaluated_expr(self, span)
1860 }
1861
// Built-ins without a constant-evaluation implementation yet.
1862 crate::MathFunction::Modf
1864 | crate::MathFunction::Frexp
1865 | crate::MathFunction::Ldexp
1866 | crate::MathFunction::Outer
1867 | crate::MathFunction::FaceForward
1868 | crate::MathFunction::Reflect
1869 | crate::MathFunction::Refract
1870 | crate::MathFunction::Mix
1871 | crate::MathFunction::SmoothStep
1872 | crate::MathFunction::Inverse
1873 | crate::MathFunction::Transpose
1874 | crate::MathFunction::Determinant
1875 | crate::MathFunction::QuantizeToF16
1876 | crate::MathFunction::ExtractBits
1877 | crate::MathFunction::InsertBits
1878 | crate::MathFunction::Pack4x8snorm
1879 | crate::MathFunction::Pack4x8unorm
1880 | crate::MathFunction::Pack2x16snorm
1881 | crate::MathFunction::Pack2x16unorm
1882 | crate::MathFunction::Pack2x16float
1883 | crate::MathFunction::Pack4xI8
1884 | crate::MathFunction::Pack4xU8
1885 | crate::MathFunction::Pack4xI8Clamp
1886 | crate::MathFunction::Pack4xU8Clamp
1887 | crate::MathFunction::Unpack4x8snorm
1888 | crate::MathFunction::Unpack4x8unorm
1889 | crate::MathFunction::Unpack2x16snorm
1890 | crate::MathFunction::Unpack2x16unorm
1891 | crate::MathFunction::Unpack2x16float
1892 | crate::MathFunction::Unpack4xI8
1893 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1894 format!("{fun:?} built-in function"),
1895 )),
1896 }
1897 }
1898
1899 fn packed_dot_product(
1901 &mut self,
1902 a: Handle<Expression>,
1903 b: Handle<Expression>,
1904 span: Span,
1905 signed: bool,
1906 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1907 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1908 return Err(ConstantEvaluatorError::InvalidMathArg);
1909 };
1910 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1911 return Err(ConstantEvaluatorError::InvalidMathArg);
1912 };
1913
1914 let result = if signed {
1915 Literal::I32(
1916 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1917 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1918 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1919 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1920 )
1921 } else {
1922 Literal::U32(
1923 (a & 0xFF) * (b & 0xFF)
1924 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1925 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1926 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1927 )
1928 };
1929
1930 self.register_evaluated_expr(Expression::Literal(result), span)
1931 }
1932
1933 fn cross_product(
1935 &mut self,
1936 a: Handle<Expression>,
1937 b: Handle<Expression>,
1938 span: Span,
1939 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1940 use Literal as Li;
1941
1942 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1943 let (b, _) = self.extract_vec_with_size::<3>(b)?;
1944
1945 let product = match (a, b) {
1946 (
1947 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1948 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1949 ) => {
1950 let p = cross_product(
1955 [a0 as f64, a1 as f64, a2 as f64],
1956 [b0 as f64, b1 as f64, b2 as f64],
1957 );
1958 [
1959 Li::AbstractFloat(p[0]),
1960 Li::AbstractFloat(p[1]),
1961 Li::AbstractFloat(p[2]),
1962 ]
1963 }
1964 (
1965 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1966 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1967 ) => {
1968 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1969 [
1970 Li::AbstractFloat(p[0]),
1971 Li::AbstractFloat(p[1]),
1972 Li::AbstractFloat(p[2]),
1973 ]
1974 }
1975 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1976 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1977 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1978 }
1979 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1980 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1981 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1982 }
1983 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1984 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1985 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1986 }
1987 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1988 };
1989
1990 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1991 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1992 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1993
1994 self.register_evaluated_expr(
1995 Expression::Compose {
1996 ty,
1997 components: vec![p0, p1, p2],
1998 },
1999 span,
2000 )
2001 }
2002
2003 fn extract_vec_with_size<const N: usize>(
2011 &mut self,
2012 expr: Handle<Expression>,
2013 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2014 let span = self.expressions.get_span(expr);
2015 let expr = self.eval_zero_value_and_splat(expr, span)?;
2016 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2017 return Err(ConstantEvaluatorError::InvalidMathArg);
2018 };
2019
2020 let mut value = [Literal::Bool(false); N];
2021 for (component, elt) in
2022 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2023 .zip(value.iter_mut())
2024 {
2025 let Expression::Literal(literal) = self.expressions[component] else {
2026 return Err(ConstantEvaluatorError::InvalidMathArg);
2027 };
2028 *elt = literal;
2029 }
2030
2031 Ok((value, ty))
2032 }
2033
2034 fn extract_vec(
2042 &mut self,
2043 expr: Handle<Expression>,
2044 allow_single: bool,
2045 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2046 let span = self.expressions.get_span(expr);
2047 let expr = self.eval_zero_value_and_splat(expr, span)?;
2048
2049 match self.expressions[expr] {
2050 Expression::Literal(literal) if allow_single => {
2051 Ok(LiteralVector::from_literal(literal))
2052 }
2053 Expression::Compose { ty, ref components } => {
2054 let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2055 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2056 .map(|expr| match self.expressions[expr] {
2057 Expression::Literal(l) => Ok(l),
2058 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2059 })
2060 .collect::<Result<_, ConstantEvaluatorError>>()?;
2061 LiteralVector::from_literal_vec(components)
2062 }
2063 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2064 }
2065 }
2066
2067 fn array_length(
2068 &mut self,
2069 array: Handle<Expression>,
2070 span: Span,
2071 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2072 match self.expressions[array] {
2073 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2074 match self.types[ty].inner {
2075 TypeInner::Array { size, .. } => match size {
2076 ArraySize::Constant(len) => {
2077 let expr = Expression::Literal(Literal::U32(len.get()));
2078 self.register_evaluated_expr(expr, span)
2079 }
2080 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2081 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2082 },
2083 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2084 }
2085 }
2086 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2087 }
2088 }
2089
2090 fn access(
2091 &mut self,
2092 base: Handle<Expression>,
2093 index: usize,
2094 span: Span,
2095 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2096 match self.expressions[base] {
2097 Expression::ZeroValue(ty) => {
2098 let ty_inner = &self.types[ty].inner;
2099 let components = ty_inner
2100 .components()
2101 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2102
2103 if index >= components as usize {
2104 Err(ConstantEvaluatorError::InvalidAccessBase)
2105 } else {
2106 let ty_res = ty_inner
2107 .component_type(index)
2108 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2109 let ty = match ty_res {
2110 crate::proc::TypeResolution::Handle(ty) => ty,
2111 crate::proc::TypeResolution::Value(inner) => {
2112 self.types.insert(Type { name: None, inner }, span)
2113 }
2114 };
2115 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2116 }
2117 }
2118 Expression::Splat { size, value } => {
2119 if index >= size as usize {
2120 Err(ConstantEvaluatorError::InvalidAccessBase)
2121 } else {
2122 Ok(value)
2123 }
2124 }
2125 Expression::Compose { ty, ref components } => {
2126 let _ = self.types[ty]
2127 .inner
2128 .components()
2129 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2130
2131 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2132 .nth(index)
2133 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2134 }
2135 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2136 }
2137 }
2138
2139 fn eval_zero_value_and_splat(
2146 &mut self,
2147 mut expr: Handle<Expression>,
2148 span: Span,
2149 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2150 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2153 let components = components
2154 .clone()
2155 .iter()
2156 .map(|component| self.eval_zero_value_and_splat(*component, span))
2157 .collect::<Result<_, _>>()?;
2158 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2159 }
2160
2161 if let Expression::Splat { size, value } = self.expressions[expr] {
2165 expr = self.splat(value, size, span)?;
2166 }
2167 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2168 expr = self.eval_zero_value_impl(ty, span)?;
2169 }
2170 Ok(expr)
2171 }
2172
2173 fn eval_zero_value(
2179 &mut self,
2180 expr: Handle<Expression>,
2181 span: Span,
2182 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2183 match self.expressions[expr] {
2184 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2185 _ => Ok(expr),
2186 }
2187 }
2188
2189 fn eval_zero_value_impl(
2195 &mut self,
2196 ty: Handle<Type>,
2197 span: Span,
2198 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2199 match self.types[ty].inner {
2200 TypeInner::Scalar(scalar) => {
2201 let expr = Expression::Literal(
2202 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2203 );
2204 self.register_evaluated_expr(expr, span)
2205 }
2206 TypeInner::Vector { size, scalar } => {
2207 let scalar_ty = self.types.insert(
2208 Type {
2209 name: None,
2210 inner: TypeInner::Scalar(scalar),
2211 },
2212 span,
2213 );
2214 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2215 let expr = Expression::Compose {
2216 ty,
2217 components: vec![el; size as usize],
2218 };
2219 self.register_evaluated_expr(expr, span)
2220 }
2221 TypeInner::Matrix {
2222 columns,
2223 rows,
2224 scalar,
2225 } => {
2226 let vec_ty = self.types.insert(
2227 Type {
2228 name: None,
2229 inner: TypeInner::Vector { size: rows, scalar },
2230 },
2231 span,
2232 );
2233 let el = self.eval_zero_value_impl(vec_ty, span)?;
2234 let expr = Expression::Compose {
2235 ty,
2236 components: vec![el; columns as usize],
2237 };
2238 self.register_evaluated_expr(expr, span)
2239 }
2240 TypeInner::Array {
2241 base,
2242 size: ArraySize::Constant(size),
2243 ..
2244 } => {
2245 let el = self.eval_zero_value_impl(base, span)?;
2246 let expr = Expression::Compose {
2247 ty,
2248 components: vec![el; size.get() as usize],
2249 };
2250 self.register_evaluated_expr(expr, span)
2251 }
2252 TypeInner::Struct { ref members, .. } => {
2253 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2254 let mut components = Vec::with_capacity(members.len());
2255 for ty in types {
2256 components.push(self.eval_zero_value_impl(ty, span)?);
2257 }
2258 let expr = Expression::Compose { ty, components };
2259 self.register_evaluated_expr(expr, span)
2260 }
2261 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2262 }
2263 }
2264
2265 pub fn cast(
2269 &mut self,
2270 expr: Handle<Expression>,
2271 target: crate::Scalar,
2272 span: Span,
2273 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2274 use crate::Scalar as Sc;
2275
2276 let expr = self.eval_zero_value(expr, span)?;
2277
2278 let make_error = || -> Result<_, ConstantEvaluatorError> {
2279 let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2280
2281 #[cfg(feature = "wgsl-in")]
2282 let to = target.to_wgsl_for_diagnostics();
2283
2284 #[cfg(not(feature = "wgsl-in"))]
2285 let to = format!("{target:?}");
2286
2287 Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2288 };
2289
2290 use crate::proc::type_methods::IntFloatLimits;
2291
2292 let expr = match self.expressions[expr] {
2293 Expression::Literal(literal) => {
2294 let literal = match target {
2295 Sc::I32 => Literal::I32(match literal {
2296 Literal::I32(v) => v,
2297 Literal::U32(v) => v as i32,
2298 Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2299 Literal::F16(v) => f16::to_i32(&v).unwrap(), Literal::Bool(v) => v as i32,
2301 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2302 return make_error();
2303 }
2304 Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2305 Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2306 }),
2307 Sc::U32 => Literal::U32(match literal {
2308 Literal::I32(v) => v as u32,
2309 Literal::U32(v) => v,
2310 Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2311 Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2313 Literal::Bool(v) => v as u32,
2314 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2315 return make_error();
2316 }
2317 Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2318 Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2319 }),
2320 Sc::I64 => Literal::I64(match literal {
2321 Literal::I32(v) => v as i64,
2322 Literal::U32(v) => v as i64,
2323 Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2324 Literal::Bool(v) => v as i64,
2325 Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2326 Literal::I64(v) => v,
2327 Literal::U64(v) => v as i64,
2328 Literal::F16(v) => f16::to_i64(&v).unwrap(), Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2330 Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2331 }),
2332 Sc::U64 => Literal::U64(match literal {
2333 Literal::I32(v) => v as u64,
2334 Literal::U32(v) => v as u64,
2335 Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2336 Literal::Bool(v) => v as u64,
2337 Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2338 Literal::I64(v) => v as u64,
2339 Literal::U64(v) => v,
2340 Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2342 Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2343 Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2344 }),
2345 Sc::F16 => Literal::F16(match literal {
2346 Literal::F16(v) => v,
2347 Literal::F32(v) => f16::from_f32(v),
2348 Literal::F64(v) => f16::from_f64(v),
2349 Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2350 Literal::I64(v) => f16::from_i64(v).unwrap(),
2351 Literal::U64(v) => f16::from_u64(v).unwrap(),
2352 Literal::I32(v) => f16::from_i32(v).unwrap(),
2353 Literal::U32(v) => f16::from_u32(v).unwrap(),
2354 Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2355 Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2356 }),
2357 Sc::F32 => Literal::F32(match literal {
2358 Literal::I32(v) => v as f32,
2359 Literal::U32(v) => v as f32,
2360 Literal::F32(v) => v,
2361 Literal::Bool(v) => v as u32 as f32,
2362 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2363 return make_error();
2364 }
2365 Literal::F16(v) => f16::to_f32(v),
2366 Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2367 Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2368 }),
2369 Sc::F64 => Literal::F64(match literal {
2370 Literal::I32(v) => v as f64,
2371 Literal::U32(v) => v as f64,
2372 Literal::F16(v) => f16::to_f64(v),
2373 Literal::F32(v) => v as f64,
2374 Literal::F64(v) => v,
2375 Literal::Bool(v) => v as u32 as f64,
2376 Literal::I64(_) | Literal::U64(_) => return make_error(),
2377 Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2378 Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2379 }),
2380 Sc::BOOL => Literal::Bool(match literal {
2381 Literal::I32(v) => v != 0,
2382 Literal::U32(v) => v != 0,
2383 Literal::F32(v) => v != 0.0,
2384 Literal::F16(v) => v != f16::zero(),
2385 Literal::Bool(v) => v,
2386 Literal::AbstractInt(v) => v != 0,
2387 Literal::AbstractFloat(v) => v != 0.0,
2388 Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2389 return make_error();
2390 }
2391 }),
2392 Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2393 Literal::AbstractInt(v) => {
2394 v as f64
2399 }
2400 Literal::AbstractFloat(v) => v,
2401 _ => return make_error(),
2402 }),
2403 Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2404 Literal::AbstractInt(v) => v,
2405 _ => return make_error(),
2406 }),
2407 _ => {
2408 log::debug!("Constant evaluator refused to convert value to {target:?}");
2409 return make_error();
2410 }
2411 };
2412 Expression::Literal(literal)
2413 }
2414 Expression::Compose {
2415 ty,
2416 components: ref src_components,
2417 } => {
2418 let ty_inner = match self.types[ty].inner {
2419 TypeInner::Vector { size, .. } => TypeInner::Vector {
2420 size,
2421 scalar: target,
2422 },
2423 TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2424 columns,
2425 rows,
2426 scalar: target,
2427 },
2428 _ => return make_error(),
2429 };
2430
2431 let mut components = src_components.clone();
2432 for component in &mut components {
2433 *component = self.cast(*component, target, span)?;
2434 }
2435
2436 let ty = self.types.insert(
2437 Type {
2438 name: None,
2439 inner: ty_inner,
2440 },
2441 span,
2442 );
2443
2444 Expression::Compose { ty, components }
2445 }
2446 Expression::Splat { size, value } => {
2447 let value_span = self.expressions.get_span(value);
2448 let cast_value = self.cast(value, target, value_span)?;
2449 Expression::Splat {
2450 size,
2451 value: cast_value,
2452 }
2453 }
2454 _ => return make_error(),
2455 };
2456
2457 self.register_evaluated_expr(expr, span)
2458 }
2459
2460 pub fn cast_array(
2473 &mut self,
2474 expr: Handle<Expression>,
2475 target: crate::Scalar,
2476 span: Span,
2477 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2478 let expr = self.check_and_get(expr)?;
2479
2480 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2481 return self.cast(expr, target, span);
2482 };
2483
2484 let TypeInner::Array {
2485 base: _,
2486 size,
2487 stride: _,
2488 } = self.types[ty].inner
2489 else {
2490 return self.cast(expr, target, span);
2491 };
2492
2493 let mut components = components.clone();
2494 for component in &mut components {
2495 *component = self.cast_array(*component, target, span)?;
2496 }
2497
2498 let first = components.first().unwrap();
2499 let new_base = match self.resolve_type(*first)? {
2500 crate::proc::TypeResolution::Handle(ty) => ty,
2501 crate::proc::TypeResolution::Value(inner) => {
2502 self.types.insert(Type { name: None, inner }, span)
2503 }
2504 };
2505 let mut layouter = core::mem::take(self.layouter);
2506 layouter.update(self.to_ctx()).unwrap();
2507 *self.layouter = layouter;
2508
2509 let new_base_stride = self.layouter[new_base].to_stride();
2510 let new_array_ty = self.types.insert(
2511 Type {
2512 name: None,
2513 inner: TypeInner::Array {
2514 base: new_base,
2515 size,
2516 stride: new_base_stride,
2517 },
2518 },
2519 span,
2520 );
2521
2522 let compose = Expression::Compose {
2523 ty: new_array_ty,
2524 components,
2525 };
2526 self.register_evaluated_expr(compose, span)
2527 }
2528
2529 fn unary_op(
2530 &mut self,
2531 op: UnaryOperator,
2532 expr: Handle<Expression>,
2533 span: Span,
2534 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2535 let expr = self.eval_zero_value_and_splat(expr, span)?;
2536
2537 let expr = match self.expressions[expr] {
2538 Expression::Literal(value) => Expression::Literal(match op {
2539 UnaryOperator::Negate => match value {
2540 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2541 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2542 Literal::F32(v) => Literal::F32(-v),
2543 Literal::F16(v) => Literal::F16(-v),
2544 Literal::F64(v) => Literal::F64(-v),
2545 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2546 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2547 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2548 },
2549 UnaryOperator::LogicalNot => match value {
2550 Literal::Bool(v) => Literal::Bool(!v),
2551 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2552 },
2553 UnaryOperator::BitwiseNot => match value {
2554 Literal::I32(v) => Literal::I32(!v),
2555 Literal::I64(v) => Literal::I64(!v),
2556 Literal::U32(v) => Literal::U32(!v),
2557 Literal::U64(v) => Literal::U64(!v),
2558 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2559 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2560 },
2561 }),
2562 Expression::Compose {
2563 ty,
2564 components: ref src_components,
2565 } => {
2566 match self.types[ty].inner {
2567 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2568 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2569 }
2570
2571 let mut components = src_components.clone();
2572 for component in &mut components {
2573 *component = self.unary_op(op, *component, span)?;
2574 }
2575
2576 Expression::Compose { ty, components }
2577 }
2578 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2579 };
2580
2581 self.register_evaluated_expr(expr, span)
2582 }
2583
    /// Constant-fold the binary operation `left op right`.
    ///
    /// Both operands are first passed through `eval_zero_value_and_splat`.
    /// The following operand shapes are then handled:
    /// - `Literal` op `Literal`: folded directly (see the per-type arms below);
    /// - `Compose` op `Literal` and `Literal` op `Compose`: the operation is
    ///   applied recursively to each component, keeping the `Compose` type.
    ///   Shifts are rejected in this shape;
    /// - `Compose` op `Compose`: only when both types are vectors of the same
    ///   size, delegated to `binary_op_vector` on the flattened component pairs.
    ///
    /// # Errors
    ///
    /// - `InvalidBinaryOpArgs` for unsupported operand/operator combinations.
    /// - `DivisionByZero` / `RemainderByZero` for integer division by zero.
    /// - `Overflow` for `AbstractInt` arithmetic overflow and for left shifts
    ///   that would discard significant bits.
    /// - `ShiftedMoreThan32Bits` where a checked `i32`/`u32` shift fails.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // Expand `ZeroValue`/`Splat` operands so the match below only has to
        // deal with `Literal` and `Compose`.
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            // Scalar op scalar.
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                let literal = match op {
                    // Comparisons are answered by `Literal`'s own `PartialEq` /
                    // `PartialOrd` impls.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    _ => match (left_value, right_value) {
                        // Concrete `i32` arithmetic wraps; only division or
                        // remainder by zero is reported as an error.
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::DivisionByZero);
                                } else {
                                    a.wrapping_div(b)
                                }
                            }
                            BinaryOperator::Modulo => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::RemainderByZero);
                                } else {
                                    a.wrapping_rem(b)
                                }
                            }
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // `i32` shifted by a `u32` amount.
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // Complementing a negative `a` turns its leading
                                // ones into leading zeros. `leading_zeros() <= b`
                                // therefore means one of the top `b + 1` bits
                                // differs from the sign bit, so the shift would
                                // lose information.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Concrete `u32` arithmetic wraps; unsigned `checked_div`
                        // / `checked_rem` fail only for a zero divisor.
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // `a << b` is computed as `a * 2^b`, so shifting any
                            // set bit out of the top surfaces as `Overflow`.
                            // For `b >= 32`, `1u32.checked_shl(b)` fails instead.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Plain IEEE float arithmetic. The folded literal is then
                        // checked by `check_literal_value` in
                        // `register_evaluated_expr`.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // `AbstractInt` shifted by a `u32` amount.
                        (Literal::AbstractInt(a), Literal::U32(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::ShiftLeft => {
                                    // Same sign-preservation check as for `i32`.
                                    // Because `leading_zeros() <= 64`, any
                                    // `b >= 64` is rejected here before reaching
                                    // `checked_shl`.
                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                        return Err(ConstantEvaluatorError::Overflow(
                                            "<<".to_string(),
                                        ));
                                    }
                                    a.checked_shl(b).unwrap_or(0)
                                }
                                // `checked_shr` fails for `b >= 64`, giving 0.
                                // NOTE(review): for negative `a` an arithmetic
                                // shift would saturate to -1; confirm 0 is intended.
                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // `AbstractInt` arithmetic never wraps: every overflow is
                        // reported as an error.
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                // `checked_div`/`checked_rem` also fail for
                                // `i64::MIN / -1`; that case is an overflow, not
                                // a division by zero.
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Mixed kinds (e.g. `i32` op `f32`) are not folded.
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // Composite op scalar: apply the scalar to each component.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => match op {
                BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                _ => {
                    let mut components = src_components.clone();
                    for component in &mut components {
                        *component = self.binary_op(op, *component, right, span)?;
                    }
                    Expression::Compose { ty, components }
                }
            },
            // Scalar op composite: apply the scalar to each component.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => match op {
                BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                _ => {
                    let mut components = src_components.clone();
                    for component in &mut components {
                        *component = self.binary_op(op, left, *component, span)?;
                    }
                    Expression::Compose { ty, components }
                }
            },
            // Composite op composite: pair up the flattened components and
            // fold them pairwise (vectors of equal size only).
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // The capacity is only a hint; the pair count comes from the
                // flattened iterators.
                let mut flattened = Vec::with_capacity(left_components.len());
                flattened.extend(left_flattened.zip(right_flattened));

                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
2836
2837 fn binary_op_vector(
2838 &mut self,
2839 op: BinaryOperator,
2840 size: crate::VectorSize,
2841 components: &[(Handle<Expression>, Handle<Expression>)],
2842 left_ty: Handle<Type>,
2843 span: Span,
2844 ) -> Result<Expression, ConstantEvaluatorError> {
2845 let ty = match op {
2846 BinaryOperator::Equal
2848 | BinaryOperator::NotEqual
2849 | BinaryOperator::Less
2850 | BinaryOperator::LessEqual
2851 | BinaryOperator::Greater
2852 | BinaryOperator::GreaterEqual => self.types.insert(
2853 Type {
2854 name: None,
2855 inner: TypeInner::Vector {
2856 size,
2857 scalar: crate::Scalar::BOOL,
2858 },
2859 },
2860 span,
2861 ),
2862
2863 BinaryOperator::Add
2866 | BinaryOperator::Subtract
2867 | BinaryOperator::Multiply
2868 | BinaryOperator::Divide
2869 | BinaryOperator::Modulo
2870 | BinaryOperator::And
2871 | BinaryOperator::ExclusiveOr
2872 | BinaryOperator::InclusiveOr
2873 | BinaryOperator::ShiftLeft
2874 | BinaryOperator::ShiftRight => left_ty,
2875
2876 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2877 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2879 }
2880 };
2881
2882 let components = components
2883 .iter()
2884 .map(|&(left, right)| self.binary_op(op, left, right, span))
2885 .collect::<Result<Vec<_>, _>>()?;
2886
2887 Ok(Expression::Compose { ty, components })
2888 }
2889
2890 fn relational(
2891 &mut self,
2892 fun: RelationalFunction,
2893 arg: Handle<Expression>,
2894 span: Span,
2895 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2896 let arg = self.eval_zero_value_and_splat(arg, span)?;
2897 match fun {
2898 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2899 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2900 Expression::Compose { ty, ref components }
2901 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2902 {
2903 let components =
2904 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2905 .map(|component| match self.expressions[component] {
2906 Expression::Literal(Literal::Bool(val)) => Ok(val),
2907 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2908 })
2909 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2910 let result = match fun {
2911 RelationalFunction::All => components.iter().all(|c| *c),
2912 RelationalFunction::Any => components.iter().any(|c| *c),
2913 _ => unreachable!(),
2914 };
2915 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2916 }
2917 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2918 },
2919 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2920 "{fun:?} built-in function"
2921 ))),
2922 }
2923 }
2924
2925 fn copy_from(
2933 &mut self,
2934 expr: Handle<Expression>,
2935 expressions: &Arena<Expression>,
2936 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2937 let span = expressions.get_span(expr);
2938 match expressions[expr] {
2939 ref expr @ (Expression::Literal(_)
2940 | Expression::Constant(_)
2941 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2942 Expression::Compose { ty, ref components } => {
2943 let mut components = components.clone();
2944 for component in &mut components {
2945 *component = self.copy_from(*component, expressions)?;
2946 }
2947 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2948 }
2949 Expression::Splat { size, value } => {
2950 let value = self.copy_from(value, expressions)?;
2951 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2952 }
2953 _ => {
2954 log::debug!("copy_from: SubexpressionsAreNotConstant");
2955 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2956 }
2957 }
2958 }
2959
2960 fn vector_compose_flattened_size(
2962 &self,
2963 components: &[Handle<Expression>],
2964 ) -> Result<usize, ConstantEvaluatorError> {
2965 components
2966 .iter()
2967 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2968 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2969 TypeInner::Scalar(_) => 1,
2970 TypeInner::Vector { size, .. } => size as usize,
2974 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2975 };
2976 Ok(acc + size)
2977 })
2978 }
2979
2980 fn register_evaluated_expr(
2981 &mut self,
2982 expr: Expression,
2983 span: Span,
2984 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2985 if let Expression::Literal(literal) = expr {
2990 crate::valid::check_literal_value(literal)?;
2991 }
2992
2993 if let Expression::Compose { ty, ref components } = expr {
2997 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2998 let expected = size as usize;
2999 let actual = self.vector_compose_flattened_size(components)?;
3000 if expected != actual {
3001 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3002 expected,
3003 actual,
3004 });
3005 }
3006 }
3007 }
3008
3009 Ok(self.append_expr(expr, span, ExpressionKind::Const))
3010 }
3011
3012 fn append_expr(
3013 &mut self,
3014 expr: Expression,
3015 span: Span,
3016 expr_type: ExpressionKind,
3017 ) -> Handle<Expression> {
3018 let h = match self.behavior {
3019 Behavior::Wgsl(
3020 WgslRestrictions::Runtime(ref mut function_local_data)
3021 | WgslRestrictions::Const(Some(ref mut function_local_data)),
3022 )
3023 | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3024 let is_running = function_local_data.emitter.is_running();
3025 let needs_pre_emit = expr.needs_pre_emit();
3026 if is_running && needs_pre_emit {
3027 function_local_data
3028 .block
3029 .extend(function_local_data.emitter.finish(self.expressions));
3030 let h = self.expressions.append(expr, span);
3031 function_local_data.emitter.start(self.expressions);
3032 h
3033 } else {
3034 self.expressions.append(expr, span)
3035 }
3036 }
3037 _ => self.expressions.append(expr, span),
3038 };
3039 self.expression_kind_tracker.insert(h, expr_type);
3040 h
3041 }
3042
3043 fn resolve_type(
3048 &self,
3049 expr: Handle<Expression>,
3050 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3051 use crate::proc::TypeResolution as Tr;
3052 use crate::Expression as Ex;
3053 let resolution = match self.expressions[expr] {
3054 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3055 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3056 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3057 Ex::Splat { size, value } => {
3058 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3059 return Err(ConstantEvaluatorError::SplatScalarOnly);
3060 };
3061 Tr::Value(TypeInner::Vector { scalar, size })
3062 }
3063 _ => {
3064 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3065 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3066 }
3067 };
3068
3069 Ok(resolution)
3070 }
3071
3072 fn select(
3073 &mut self,
3074 reject: Handle<Expression>,
3075 accept: Handle<Expression>,
3076 condition: Handle<Expression>,
3077 span: Span,
3078 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3079 let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3080
3081 let reject = arg(reject)?;
3082 let accept = arg(accept)?;
3083 let condition = arg(condition)?;
3084
3085 let select_single_component =
3086 |this: &mut Self, reject_scalar, reject, accept, condition| {
3087 let accept = this.cast(accept, reject_scalar, span)?;
3088 if condition {
3089 Ok(accept)
3090 } else {
3091 Ok(reject)
3092 }
3093 };
3094
3095 match (&self.expressions[reject], &self.expressions[accept]) {
3096 (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3097 let reject_scalar = reject_lit.scalar();
3098 let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3099 else {
3100 return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3101 };
3102 select_single_component(self, reject_scalar, reject, accept, condition)
3103 }
3104 (
3105 &Expression::Compose {
3106 ty: reject_ty,
3107 components: ref reject_components,
3108 },
3109 &Expression::Compose {
3110 ty: accept_ty,
3111 components: ref accept_components,
3112 },
3113 ) => {
3114 let ty_deets = |ty| {
3115 let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3116 (size.unwrap(), scalar)
3117 };
3118
3119 let expected_vec_size = {
3120 let [(reject_vec_size, _), (accept_vec_size, _)] =
3121 [reject_ty, accept_ty].map(ty_deets);
3122
3123 if reject_vec_size != accept_vec_size {
3124 return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3125 reject: reject_vec_size,
3126 accept: accept_vec_size,
3127 });
3128 }
3129 reject_vec_size
3130 };
3131
3132 let condition_components = match self.expressions[condition] {
3133 Expression::Literal(Literal::Bool(condition)) => {
3134 vec![condition; (expected_vec_size as u8).into()]
3135 }
3136 Expression::Compose {
3137 ty: condition_ty,
3138 components: ref condition_components,
3139 } => {
3140 let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3141 if condition_scalar.kind != ScalarKind::Bool {
3142 return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3143 }
3144 if condition_vec_size != expected_vec_size {
3145 return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3146 }
3147 condition_components
3148 .iter()
3149 .copied()
3150 .map(|component| match &self.expressions[component] {
3151 &Expression::Literal(Literal::Bool(condition)) => condition,
3152 _ => unreachable!(),
3153 })
3154 .collect()
3155 }
3156
3157 _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3158 };
3159
3160 let evaluated = Expression::Compose {
3161 ty: reject_ty,
3162 components: reject_components
3163 .clone()
3164 .into_iter()
3165 .zip(accept_components.clone().into_iter())
3166 .zip(condition_components.into_iter())
3167 .map(|((reject, accept), condition)| {
3168 let reject_scalar = match &self.expressions[reject] {
3169 &Expression::Literal(lit) => lit.scalar(),
3170 _ => unreachable!(),
3171 };
3172 select_single_component(self, reject_scalar, reject, accept, condition)
3173 })
3174 .collect::<Result<_, _>>()?,
3175 };
3176 self.register_evaluated_expr(evaluated, span)
3177 }
3178 _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3179 }
3180 }
3181}
3182
3183fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3184 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3188 match e {
3189 idx @ 0..=31 => idx,
3190 32 => u32::MAX,
3191 _ => unreachable!(),
3192 }
3193 };
3194 match concrete_int {
3195 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3196 ConcreteInt::I32([e]) => {
3197 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3198 }
3199 }
3200}
3201
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected index) pairs for `i32`; -1 means "no bit set".
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // A lone set bit is its own trailing bit.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    // Same checks for `u32`; `u32::MAX` means "no bit set".
    let unsigned_cases = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3262
3263fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3264 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3268 match e {
3269 idx @ 0..=31 => 31 - idx,
3270 32 => u32::MAX,
3271 _ => unreachable!(),
3272 }
3273 };
3274 match concrete_int {
3275 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3276 let rtl_bit_index = if e.is_negative() {
3277 e.leading_ones()
3278 } else {
3279 e.leading_zeros()
3280 };
3281 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3282 }]),
3283 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3284 }
3285}
3286
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index) pairs for `i32`; -1 means "no such bit".
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // Positive powers of two; bit 31 is the sign bit, so stop before it.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negative powers of two: the highest bit differing from the sign is the
    // zero just below the lowest set bit.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // `u32`: `u32::MAX` means "no bit set".
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
3350
3351trait TryFromAbstract<T>: Sized {
3353 fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
3375}
3376
3377impl TryFromAbstract<i64> for i32 {
3378 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3379 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3380 value: format!("{value:?}"),
3381 to_type: "i32",
3382 })
3383 }
3384}
3385
3386impl TryFromAbstract<i64> for u32 {
3387 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3388 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3389 value: format!("{value:?}"),
3390 to_type: "u32",
3391 })
3392 }
3393}
3394
3395impl TryFromAbstract<i64> for u64 {
3396 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3397 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3398 value: format!("{value:?}"),
3399 to_type: "u64",
3400 })
3401 }
3402}
3403
3404impl TryFromAbstract<i64> for i64 {
3405 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3406 Ok(value)
3407 }
3408}
3409
3410impl TryFromAbstract<i64> for f32 {
3411 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3412 let f = value as f32;
3413 Ok(f)
3417 }
3418}
3419
3420impl TryFromAbstract<f64> for f32 {
3421 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3422 let f = value as f32;
3423 if f.is_infinite() {
3424 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3425 value: format!("{value:?}"),
3426 to_type: "f32",
3427 });
3428 }
3429 Ok(f)
3430 }
3431}
3432
3433impl TryFromAbstract<i64> for f64 {
3434 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3435 let f = value as f64;
3436 Ok(f)
3440 }
3441}
3442
3443impl TryFromAbstract<f64> for f64 {
3444 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3445 Ok(value)
3446 }
3447}
3448
3449impl TryFromAbstract<f64> for i32 {
3450 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3451 Ok(value as i32)
3464 }
3465}
3466
3467impl TryFromAbstract<f64> for u32 {
3468 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3469 Ok(value as u32)
3472 }
3473}
3474
3475impl TryFromAbstract<f64> for i64 {
3476 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3477 use crate::proc::type_methods::IntFloatLimits;
3480 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3481 }
3482}
3483
3484impl TryFromAbstract<f64> for u64 {
3485 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3486 use crate::proc::type_methods::IntFloatLimits;
3489 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3490 }
3491}
3492
3493impl TryFromAbstract<f64> for f16 {
3494 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3495 let f = f16::from_f64(value);
3496 if f.is_infinite() {
3497 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3498 value: format!("{value:?}"),
3499 to_type: "f16",
3500 });
3501 }
3502 Ok(f)
3503 }
3504}
3505
3506impl TryFromAbstract<i64> for f16 {
3507 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3508 let f = f16::from_i64(value);
3509 if f.is_none() {
3510 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3511 value: format!("{value:?}"),
3512 to_type: "f16",
3513 });
3514 }
3515 Ok(f.unwrap())
3516 }
3517}
3518
/// Computes the cross product `a × b` of two three-component vectors.
///
/// Works for any `Copy` scalar with `*` and `-`. Overflow and rounding behave
/// exactly as `T`'s operators do; nothing extra is checked here.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;

    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
3531
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Negating and bitwise-inverting an `i32` constant, and bitwise-inverting a
    /// `vec2<i32>` constant, fold to literals / a composed vector of literals.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Casting the `i32` constant `4` to `bool` folds to the literal `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexing a two-column, three-row `f32` matrix constant yields its second
    /// column as a composed vector. Indexing that vector again yields its third
    /// component as a literal.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// Composing a `vec2<i32>` from two references to an `i32` constant and
    /// negating the result folds to a vector whose components are all `-4`.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// Splatting an `i32` constant to `vec2<i32>` and negating the result
    /// folds to a vector whose components are all `-4`.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_splat = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// Adding a splat of an `f32` zero value to a splat of `5.0` folds to a
    /// `vec2<f32>` whose components are all `5.0`.
    #[test]
    fn splat_of_zero_value() {
        let mut types = UniqueArena::new();
        let constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::F32),
            },
            Default::default(),
        );

        let vec2_f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let five =
            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
        let five_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: five,
            },
            Default::default(),
        );
        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
        let zero_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: zero,
            },
            Default::default(),
        );

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_add = solver
            .try_eval_and_append(
                Expression::Binary {
                    op: crate::BinaryOperator::Add,
                    left: zero_splat,
                    right: five_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_add] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_f32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }
}
4130}