naga/valid/
expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4    arena::Handle,
5    proc::OverloadSet as _,
6    proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13    NotInScope,
14    #[error("Base type {0:?} is not compatible with this expression")]
15    InvalidBaseType(Handle<crate::Expression>),
16    #[error("Accessing with index {0:?} can't be done")]
17    InvalidIndexType(Handle<crate::Expression>),
18    #[error("Accessing {0:?} via a negative index is invalid")]
19    NegativeIndex(Handle<crate::Expression>),
20    #[error("Accessing index {1} is out of {0:?} bounds")]
21    IndexOutOfBounds(Handle<crate::Expression>, u32),
22    #[error("Function argument {0:?} doesn't exist")]
23    FunctionArgumentDoesntExist(u32),
24    #[error("Loading of {0:?} can't be done")]
25    InvalidPointerType(Handle<crate::Expression>),
26    #[error("Array length of {0:?} can't be done")]
27    InvalidArrayType(Handle<crate::Expression>),
28    #[error("Get intersection of {0:?} can't be done")]
29    InvalidRayQueryType(Handle<crate::Expression>),
30    #[error("Splatting {0:?} can't be done")]
31    InvalidSplatType(Handle<crate::Expression>),
32    #[error("Swizzling {0:?} can't be done")]
33    InvalidVectorType(Handle<crate::Expression>),
34    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36    #[error(transparent)]
37    Compose(#[from] super::ComposeError),
38    #[error("Cannot construct zero value of {0:?} because it is not a constructible type")]
39    InvalidZeroValue(Handle<crate::Type>),
40    #[error(transparent)]
41    IndexableLength(#[from] IndexableLengthError),
42    #[error("Operation {0:?} can't work with {1:?}")]
43    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
44    #[error(
45        "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
46        op,
47        lhs_expr,
48        lhs_type,
49        rhs_expr,
50        rhs_type
51    )]
52    InvalidBinaryOperandTypes {
53        op: crate::BinaryOperator,
54        lhs_expr: Handle<crate::Expression>,
55        lhs_type: crate::TypeInner,
56        rhs_expr: Handle<crate::Expression>,
57        rhs_type: crate::TypeInner,
58    },
59    #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
60    SelectValuesTypeMismatch {
61        accept: crate::TypeInner,
62        reject: crate::TypeInner,
63    },
64    #[error("Expected selection condition to be a boolean value, got {actual:?}")]
65    SelectConditionNotABool { actual: crate::TypeInner },
66    #[error("Relational argument {0:?} is not a boolean vector")]
67    InvalidBooleanVector(Handle<crate::Expression>),
68    #[error("Relational argument {0:?} is not a float")]
69    InvalidFloatArgument(Handle<crate::Expression>),
70    #[error("Type resolution failed")]
71    Type(#[from] ResolveError),
72    #[error("Not a global variable")]
73    ExpectedGlobalVariable,
74    #[error("Not a global variable or a function argument")]
75    ExpectedGlobalOrArgument,
76    #[error("Needs to be an binding array instead of {0:?}")]
77    ExpectedBindingArrayType(Handle<crate::Type>),
78    #[error("Needs to be an image instead of {0:?}")]
79    ExpectedImageType(Handle<crate::Type>),
80    #[error("Needs to be an image instead of {0:?}")]
81    ExpectedSamplerType(Handle<crate::Type>),
82    #[error("Unable to operate on image class {0:?}")]
83    InvalidImageClass(crate::ImageClass),
84    #[error("Image atomics are not supported for storage format {0:?}")]
85    InvalidImageFormat(crate::StorageFormat),
86    #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
87    InvalidImageStorageAccess(crate::StorageAccess),
88    #[error("Derivatives can only be taken from scalar and vector floats")]
89    InvalidDerivative,
90    #[error("Image array index parameter is misplaced")]
91    InvalidImageArrayIndex,
92    #[error("Cannot textureLoad from a specific multisample sample on a non-multisampled image.")]
93    InvalidImageSampleSelector,
94    #[error("Cannot textureLoad from a multisampled image without specifying a sample.")]
95    MissingImageSampleSelector,
96    #[error("Cannot textureLoad with a specific mip level on a non-mipmapped image.")]
97    InvalidImageLevelSelector,
98    #[error("Cannot textureLoad from a mipmapped image without specifying a level.")]
99    MissingImageLevelSelector,
100    #[error("Image array index type of {0:?} is not an integer scalar")]
101    InvalidImageArrayIndexType(Handle<crate::Expression>),
102    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
103    InvalidImageOtherIndexType(Handle<crate::Expression>),
104    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
105    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
106    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
107    ComparisonSamplingMismatch {
108        image: crate::ImageClass,
109        sampler: bool,
110        has_ref: bool,
111    },
112    #[error("Sample offset must be a const-expression")]
113    InvalidSampleOffsetExprType,
114    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
115    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
116    #[error("Depth reference {0:?} is not a scalar float")]
117    InvalidDepthReference(Handle<crate::Expression>),
118    #[error("Depth sample level can only be Auto or Zero")]
119    InvalidDepthSampleLevel,
120    #[error("Gather level can only be Zero")]
121    InvalidGatherLevel,
122    #[error("Gather component {0:?} doesn't exist in the image")]
123    InvalidGatherComponent(crate::SwizzleComponent),
124    #[error("Gather can't be done for image dimension {0:?}")]
125    InvalidGatherDimension(crate::ImageDimension),
126    #[error("Sample level (exact) type {0:?} has an invalid type")]
127    InvalidSampleLevelExactType(Handle<crate::Expression>),
128    #[error("Sample level (bias) type {0:?} is not a scalar float")]
129    InvalidSampleLevelBiasType(Handle<crate::Expression>),
130    #[error("Bias can't be done for image dimension {0:?}")]
131    InvalidSampleLevelBiasDimension(crate::ImageDimension),
132    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
133    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
134    #[error("Clamping sample coordinate to edge is not supported with {0}")]
135    InvalidSampleClampCoordinateToEdge(alloc::string::String),
136    #[error("Unable to cast")]
137    InvalidCastArgument,
138    #[error("Invalid argument count for {0:?}")]
139    WrongArgumentCount(crate::MathFunction),
140    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
141    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
142    #[error(
143        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
144    )]
145    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
146    #[error("Shader requires capability {0:?}")]
147    MissingCapabilities(super::Capabilities),
148    #[error(transparent)]
149    Literal(#[from] LiteralError),
150    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
151    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
152    #[error("Invalid operand for cooperative op")]
153    InvalidCooperativeOperand(Handle<crate::Expression>),
154    #[error("Shift amount exceeds the bit width of {lhs_type:?}")]
155    ShiftAmountTooLarge {
156        lhs_type: crate::TypeInner,
157        rhs_expr: Handle<crate::Expression>,
158    },
159    #[error("Division by zero")]
160    DivideByZero,
161}
162
163#[derive(Clone, Debug, thiserror::Error)]
164#[cfg_attr(test, derive(PartialEq))]
165pub enum ConstExpressionError {
166    #[error("The expression is not a constant or override expression")]
167    NonConstOrOverride,
168    #[error("The expression is not a fully evaluated constant expression")]
169    NonFullyEvaluatedConst,
170    #[error(transparent)]
171    Compose(#[from] super::ComposeError),
172    #[error("Splatting {0:?} can't be done")]
173    InvalidSplatType(Handle<crate::Expression>),
174    #[error("Type resolution failed")]
175    Type(#[from] ResolveError),
176    #[error(transparent)]
177    Literal(#[from] LiteralError),
178    #[error(transparent)]
179    Width(#[from] super::r#type::WidthError),
180}
181
182#[derive(Clone, Debug, thiserror::Error)]
183#[cfg_attr(test, derive(PartialEq))]
184pub enum LiteralError {
185    #[error("Float literal is NaN")]
186    NaN,
187    #[error("Float literal is infinite")]
188    Infinity,
189    #[error(transparent)]
190    Width(#[from] super::r#type::WidthError),
191}
192
193struct ExpressionTypeResolver<'a> {
194    root: Handle<crate::Expression>,
195    types: &'a UniqueArena<crate::Type>,
196    info: &'a FunctionInfo,
197}
198
199impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
200    type Output = crate::TypeInner;
201
202    #[allow(clippy::panic)]
203    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
204        if handle < self.root {
205            self.info[handle].ty.inner_with(self.types)
206        } else {
207            // `Validator::validate_module_handles` should have caught this.
208            panic!(
209                "Depends on {:?}, which has not been processed yet",
210                self.root
211            )
212        }
213    }
214}
215
216impl super::Validator {
217    pub(super) fn validate_const_expression(
218        &self,
219        handle: Handle<crate::Expression>,
220        gctx: crate::proc::GlobalCtx,
221        mod_info: &ModuleInfo,
222        global_expr_kind: &crate::proc::ExpressionKindTracker,
223    ) -> Result<(), ConstExpressionError> {
224        use crate::Expression as E;
225
226        if !global_expr_kind.is_const_or_override(handle) {
227            return Err(ConstExpressionError::NonConstOrOverride);
228        }
229
230        match gctx.global_expressions[handle] {
231            E::Literal(literal) => {
232                self.validate_literal(literal)?;
233            }
234            E::Constant(_) | E::ZeroValue(_) => {}
235            E::Compose { ref components, ty } => {
236                validate_compose(
237                    ty,
238                    gctx,
239                    components.iter().map(|&handle| mod_info[handle].clone()),
240                )?;
241            }
242            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
243                crate::TypeInner::Scalar { .. } => {}
244                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
245            },
246            _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
247                return Err(ConstExpressionError::NonFullyEvaluatedConst)
248            }
249            // the constant evaluator will report errors about override-expressions
250            _ => {}
251        }
252
253        Ok(())
254    }
255
256    /// Return an error if a constant shift amount in `right` exceeds the bit
257    /// width of `left_ty`.
258    ///
259    /// This function promises to return an error in cases where (1) the
260    /// expression is well-typed, (2) `left_ty` is a concrete integer, and
261    /// (3) the shift will overflow. It does not return an error in cases where
262    /// the expression is not well-typed (e.g. vector dimension mismatch),
263    /// because those will be rejected elsewhere.
264    fn validate_constant_shift_amounts(
265        left_ty: &crate::TypeInner,
266        right: Handle<crate::Expression>,
267        module: &crate::Module,
268        function: &crate::Function,
269    ) -> Result<(), ExpressionError> {
270        fn is_overflowing_shift(
271            left_ty: &crate::TypeInner,
272            right: Handle<crate::Expression>,
273            module: &crate::Module,
274            function: &crate::Function,
275        ) -> bool {
276            let Some((vec_size, scalar)) = left_ty.vector_size_and_scalar() else {
277                return false;
278            };
279            if !matches!(
280                scalar.kind,
281                crate::ScalarKind::Sint | crate::ScalarKind::Uint
282            ) {
283                return false;
284            }
285            let lhs_bits = u32::from(8 * scalar.width);
286            if vec_size.is_none() {
287                let shift_amount = module
288                    .to_ctx()
289                    .get_const_val_from::<u32, _>(right, &function.expressions);
290                shift_amount.ok().is_some_and(|s| s >= lhs_bits)
291            } else {
292                match function.expressions[right] {
293                    crate::Expression::ZeroValue(_) => false, // zero shift does not overflow
294                    crate::Expression::Splat { value, .. } => module
295                        .to_ctx()
296                        .get_const_val_from::<u32, _>(value, &function.expressions)
297                        .ok()
298                        .is_some_and(|s| s >= lhs_bits),
299                    crate::Expression::Compose {
300                        ty: _,
301                        ref components,
302                    } => components.iter().any(|comp| {
303                        module
304                            .to_ctx()
305                            .get_const_val_from::<u32, _>(*comp, &function.expressions)
306                            .ok()
307                            .is_some_and(|s| s >= lhs_bits)
308                    }),
309                    _ => false,
310                }
311            }
312        }
313
314        if is_overflowing_shift(left_ty, right, module, function) {
315            Err(ExpressionError::ShiftAmountTooLarge {
316                lhs_type: left_ty.clone(),
317                rhs_expr: right,
318            })
319        } else {
320            Ok(())
321        }
322    }
323
324    /// Return an error if a constant divisor in `right` evaluates to zero for
325    /// an integer division or remainder operation.
326    ///
327    /// This function promises to return an error in cases where (1) the
328    /// expression is well-typed, (2) `left_ty` is a concrete integer or a
329    /// vector, and (3) `right` is a const-expression that evaluates to zero.
330    /// It does not return an error in cases where the expression is not
331    /// well-typed (e.g. vector dimension mismatch), because those will be
332    /// rejected elsewhere.
333    fn validate_constant_divisor(
334        left_ty: &crate::TypeInner,
335        right: Handle<crate::Expression>,
336        module: &crate::Module,
337        function: &crate::Function,
338    ) -> Result<(), ExpressionError> {
339        fn contains_zero(
340            handle: Handle<crate::Expression>,
341            expressions: &crate::Arena<crate::Expression>,
342            module: &crate::Module,
343        ) -> bool {
344            match expressions[handle] {
345                crate::Expression::Literal(_) | crate::Expression::ZeroValue(_) => module
346                    .to_ctx()
347                    .get_const_val_from::<u32, _>(handle, expressions)
348                    .ok()
349                    .is_some_and(|v| v == 0),
350                crate::Expression::Splat { value, .. } => contains_zero(value, expressions, module),
351                crate::Expression::Compose { ref components, .. } => components
352                    .iter()
353                    .any(|&comp| contains_zero(comp, expressions, module)),
354                crate::Expression::Constant(c) => {
355                    contains_zero(module.constants[c].init, &module.global_expressions, module)
356                }
357                _ => false,
358            }
359        }
360
361        let Some((_, scalar)) = left_ty.vector_size_and_scalar() else {
362            return Ok(());
363        };
364        if !matches!(
365            scalar.kind,
366            crate::ScalarKind::Sint | crate::ScalarKind::Uint
367        ) {
368            return Ok(());
369        }
370
371        if contains_zero(right, &function.expressions, module) {
372            Err(ExpressionError::DivideByZero)
373        } else {
374            Ok(())
375        }
376    }
377
378    #[allow(clippy::too_many_arguments)]
379    pub(super) fn validate_expression(
380        &self,
381        root: Handle<crate::Expression>,
382        expression: &crate::Expression,
383        function: &crate::Function,
384        module: &crate::Module,
385        info: &FunctionInfo,
386        mod_info: &ModuleInfo,
387        expr_kind: &crate::proc::ExpressionKindTracker,
388    ) -> Result<ShaderStages, ExpressionError> {
389        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
390
391        let resolver = ExpressionTypeResolver {
392            root,
393            types: &module.types,
394            info,
395        };
396
397        let stages = match *expression {
398            E::Access { base, index } => {
399                let base_type = &resolver[base];
400                match *base_type {
401                    Ti::Matrix { .. }
402                    | Ti::Vector { .. }
403                    | Ti::Array { .. }
404                    | Ti::Pointer { .. }
405                    | Ti::ValuePointer { size: Some(_), .. }
406                    | Ti::BindingArray { .. } => {}
407                    ref other => {
408                        log::debug!("Indexing of {other:?}");
409                        return Err(ExpressionError::InvalidBaseType(base));
410                    }
411                };
412                match resolver[index] {
413                    //TODO: only allow one of these
414                    Ti::Scalar(Sc {
415                        kind: Sk::Sint | Sk::Uint,
416                        ..
417                    }) => {}
418                    ref other => {
419                        log::debug!("Indexing by {other:?}");
420                        return Err(ExpressionError::InvalidIndexType(index));
421                    }
422                }
423
424                // If index is const we can do check for non-negative index
425                match module
426                    .to_ctx()
427                    .get_const_val_from(index, &function.expressions)
428                {
429                    Ok(value) => {
430                        let length = if self.overrides_resolved {
431                            base_type.indexable_length_resolved(module)
432                        } else {
433                            base_type.indexable_length_pending(module)
434                        }?;
435                        // If we know both the length and the index, we can do the
436                        // bounds check now.
437                        if let crate::proc::IndexableLength::Known(known_length) = length {
438                            if value >= known_length {
439                                return Err(ExpressionError::IndexOutOfBounds(base, value));
440                            }
441                        }
442                    }
443                    Err(crate::proc::ConstValueError::Negative) => {
444                        return Err(ExpressionError::NegativeIndex(base))
445                    }
446                    Err(crate::proc::ConstValueError::NonConst) => {}
447                    Err(crate::proc::ConstValueError::InvalidType) => {
448                        return Err(ExpressionError::InvalidIndexType(index))
449                    }
450                }
451
452                ShaderStages::all()
453            }
454            E::AccessIndex { base, index } => {
455                fn resolve_index_limit(
456                    module: &crate::Module,
457                    top: Handle<crate::Expression>,
458                    ty: &crate::TypeInner,
459                    top_level: bool,
460                ) -> Result<u32, ExpressionError> {
461                    let limit = match *ty {
462                        Ti::Vector { size, .. }
463                        | Ti::ValuePointer {
464                            size: Some(size), ..
465                        } => size as u32,
466                        Ti::Matrix { columns, .. } => columns as u32,
467                        Ti::Array {
468                            size: crate::ArraySize::Constant(len),
469                            ..
470                        } => len.get(),
471                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
472                        Ti::Pointer { base, .. } if top_level => {
473                            resolve_index_limit(module, top, &module.types[base].inner, false)?
474                        }
475                        Ti::Struct { ref members, .. } => members.len() as u32,
476                        ref other => {
477                            log::debug!("Indexing of {other:?}");
478                            return Err(ExpressionError::InvalidBaseType(top));
479                        }
480                    };
481                    Ok(limit)
482                }
483
484                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
485                if index >= limit {
486                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
487                }
488                ShaderStages::all()
489            }
490            E::Splat { size: _, value } => match resolver[value] {
491                Ti::Scalar { .. } => ShaderStages::all(),
492                ref other => {
493                    log::debug!("Splat scalar type {other:?}");
494                    return Err(ExpressionError::InvalidSplatType(value));
495                }
496            },
497            E::Swizzle {
498                size,
499                vector,
500                pattern,
501            } => {
502                let vec_size = match resolver[vector] {
503                    Ti::Vector { size: vec_size, .. } => vec_size,
504                    ref other => {
505                        log::debug!("Swizzle vector type {other:?}");
506                        return Err(ExpressionError::InvalidVectorType(vector));
507                    }
508                };
509                for &sc in pattern[..size as usize].iter() {
510                    if sc as u8 >= vec_size as u8 {
511                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
512                    }
513                }
514                ShaderStages::all()
515            }
516            E::Literal(literal) => {
517                self.validate_literal(literal)?;
518                ShaderStages::all()
519            }
520            E::Constant(_) | E::Override(_) => ShaderStages::all(),
521            E::ZeroValue(ty) => {
522                if !mod_info[ty].contains(TypeFlags::CONSTRUCTIBLE) {
523                    return Err(ExpressionError::InvalidZeroValue(ty));
524                }
525                ShaderStages::all()
526            }
527            E::Compose { ref components, ty } => {
528                validate_compose(
529                    ty,
530                    module.to_ctx(),
531                    components.iter().map(|&handle| info[handle].ty.clone()),
532                )?;
533                ShaderStages::all()
534            }
535            E::FunctionArgument(index) => {
536                if index >= function.arguments.len() as u32 {
537                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
538                }
539                ShaderStages::all()
540            }
541            E::GlobalVariable(_handle) => ShaderStages::all(),
542            E::LocalVariable(_handle) => ShaderStages::all(),
543            E::Load { pointer } => {
544                match resolver[pointer] {
545                    Ti::Pointer { base, .. }
546                        if self.types[base.index()]
547                            .flags
548                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
549                    Ti::ValuePointer { .. } => {}
550                    ref other => {
551                        log::debug!("Loading {other:?}");
552                        return Err(ExpressionError::InvalidPointerType(pointer));
553                    }
554                }
555                ShaderStages::all()
556            }
557            E::ImageSample {
558                image,
559                sampler,
560                gather,
561                coordinate,
562                array_index,
563                offset,
564                level,
565                depth_ref,
566                clamp_to_edge,
567            } => {
568                // check the validity of expressions
569                let image_ty = Self::global_var_ty(module, function, image)?;
570                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
571
572                let comparison = match module.types[sampler_ty].inner {
573                    Ti::Sampler { comparison } => comparison,
574                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
575                };
576
577                let (class, dim) = match module.types[image_ty].inner {
578                    Ti::Image {
579                        class,
580                        arrayed,
581                        dim,
582                    } => {
583                        // check the array property
584                        if arrayed != array_index.is_some() {
585                            return Err(ExpressionError::InvalidImageArrayIndex);
586                        }
587                        if let Some(expr) = array_index {
588                            match resolver[expr] {
589                                Ti::Scalar(Sc {
590                                    kind: Sk::Sint | Sk::Uint,
591                                    ..
592                                }) => {}
593                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
594                            }
595                        }
596                        (class, dim)
597                    }
598                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
599                };
600
601                // check sampling and comparison properties
602                let image_depth = match class {
603                    crate::ImageClass::Sampled {
604                        kind: crate::ScalarKind::Float,
605                        multi: false,
606                    } => false,
607                    crate::ImageClass::Sampled {
608                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
609                        multi: false,
610                    } if gather.is_some() => false,
611                    crate::ImageClass::External => false,
612                    crate::ImageClass::Depth { multi: false } => true,
613                    _ => return Err(ExpressionError::InvalidImageClass(class)),
614                };
615                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
616                    return Err(ExpressionError::ComparisonSamplingMismatch {
617                        image: class,
618                        sampler: comparison,
619                        has_ref: depth_ref.is_some(),
620                    });
621                }
622
623                // check texture coordinates type
624                let num_components = match dim {
625                    crate::ImageDimension::D1 => 1,
626                    crate::ImageDimension::D2 => 2,
627                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
628                };
629                match resolver[coordinate] {
630                    Ti::Scalar(Sc {
631                        kind: Sk::Float, ..
632                    }) if num_components == 1 => {}
633                    Ti::Vector {
634                        size,
635                        scalar:
636                            Sc {
637                                kind: Sk::Float, ..
638                            },
639                    } if size as u32 == num_components => {}
640                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
641                }
642
643                // check constant offset
644                if let Some(const_expr) = offset {
645                    if !expr_kind.is_const(const_expr) {
646                        return Err(ExpressionError::InvalidSampleOffsetExprType);
647                    }
648
649                    match resolver[const_expr] {
650                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
651                        Ti::Vector {
652                            size,
653                            scalar: Sc { kind: Sk::Sint, .. },
654                        } if size as u32 == num_components => {}
655                        _ => {
656                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
657                        }
658                    }
659                }
660
661                // check depth reference type
662                if let Some(expr) = depth_ref {
663                    match resolver[expr] {
664                        Ti::Scalar(Sc {
665                            kind: Sk::Float, ..
666                        }) => {}
667                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
668                    }
669                    match level {
670                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
671                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
672                    }
673                }
674
675                if let Some(component) = gather {
676                    match dim {
677                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
678                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
679                            return Err(ExpressionError::InvalidGatherDimension(dim))
680                        }
681                    };
682                    let max_component = match class {
683                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
684                        _ => crate::SwizzleComponent::W,
685                    };
686                    if component > max_component {
687                        return Err(ExpressionError::InvalidGatherComponent(component));
688                    }
689                    match level {
690                        crate::SampleLevel::Zero => {}
691                        _ => return Err(ExpressionError::InvalidGatherLevel),
692                    }
693                }
694
695                // Clamping coordinate to edge is only supported with 2d non-arrayed, sampled images
696                // when sampling from level Zero without any offset, gather, or depth comparison.
697                if clamp_to_edge {
698                    if !matches!(
699                        class,
700                        crate::ImageClass::Sampled {
701                            kind: crate::ScalarKind::Float,
702                            multi: false
703                        } | crate::ImageClass::External
704                    ) {
705                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
706                            alloc::format!("image class `{class:?}`"),
707                        ));
708                    }
709                    if dim != crate::ImageDimension::D2 {
710                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
711                            alloc::format!("image dimension `{dim:?}`"),
712                        ));
713                    }
714                    if gather.is_some() {
715                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
716                            "gather".into(),
717                        ));
718                    }
719                    if array_index.is_some() {
720                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
721                            "array index".into(),
722                        ));
723                    }
724                    if offset.is_some() {
725                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
726                            "offset".into(),
727                        ));
728                    }
729                    if level != crate::SampleLevel::Zero {
730                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
731                            "non-zero level".into(),
732                        ));
733                    }
734                    if depth_ref.is_some() {
735                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
736                            "depth comparison".into(),
737                        ));
738                    }
739                }
740
741                // External textures can only be sampled using clamp_to_edge.
742                if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
743                    return Err(ExpressionError::InvalidImageClass(class));
744                }
745
746                // check level properties
747                match level {
748                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
749                    crate::SampleLevel::Zero => ShaderStages::all(),
750                    crate::SampleLevel::Exact(expr) => {
751                        match class {
752                            crate::ImageClass::Depth { .. } => match resolver[expr] {
753                                Ti::Scalar(Sc {
754                                    kind: Sk::Sint | Sk::Uint,
755                                    ..
756                                }) => {}
757                                _ => {
758                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
759                                }
760                            },
761                            _ => match resolver[expr] {
762                                Ti::Scalar(Sc {
763                                    kind: Sk::Float, ..
764                                }) => {}
765                                _ => {
766                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
767                                }
768                            },
769                        }
770                        ShaderStages::all()
771                    }
772                    crate::SampleLevel::Bias(expr) => {
773                        match resolver[expr] {
774                            Ti::Scalar(Sc {
775                                kind: Sk::Float, ..
776                            }) => {}
777                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
778                        }
779                        match class {
780                            crate::ImageClass::Sampled {
781                                kind: Sk::Float,
782                                multi: false,
783                            } => {
784                                if dim == crate::ImageDimension::D1 {
785                                    return Err(ExpressionError::InvalidSampleLevelBiasDimension(
786                                        dim,
787                                    ));
788                                }
789                            }
790                            _ => return Err(ExpressionError::InvalidImageClass(class)),
791                        }
792                        ShaderStages::FRAGMENT
793                    }
794                    crate::SampleLevel::Gradient { x, y } => {
795                        match resolver[x] {
796                            Ti::Scalar(Sc {
797                                kind: Sk::Float, ..
798                            }) if num_components == 1 => {}
799                            Ti::Vector {
800                                size,
801                                scalar:
802                                    Sc {
803                                        kind: Sk::Float, ..
804                                    },
805                            } if size as u32 == num_components => {}
806                            _ => {
807                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
808                            }
809                        }
810                        match resolver[y] {
811                            Ti::Scalar(Sc {
812                                kind: Sk::Float, ..
813                            }) if num_components == 1 => {}
814                            Ti::Vector {
815                                size,
816                                scalar:
817                                    Sc {
818                                        kind: Sk::Float, ..
819                                    },
820                            } if size as u32 == num_components => {}
821                            _ => {
822                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
823                            }
824                        }
825                        ShaderStages::all()
826                    }
827                }
828            }
829            E::ImageLoad {
830                image,
831                coordinate,
832                array_index,
833                sample,
834                level,
835            } => {
836                let ty = Self::global_var_ty(module, function, image)?;
837                let Ti::Image {
838                    class,
839                    arrayed,
840                    dim,
841                } = module.types[ty].inner
842                else {
843                    return Err(ExpressionError::ExpectedImageType(ty));
844                };
845
846                match resolver[coordinate].image_storage_coordinates() {
847                    Some(coord_dim) if coord_dim == dim => {}
848                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
849                };
850                if arrayed != array_index.is_some() {
851                    return Err(ExpressionError::InvalidImageArrayIndex);
852                }
853                if let Some(expr) = array_index {
854                    if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
855                        return Err(ExpressionError::InvalidImageArrayIndexType(expr));
856                    }
857                }
858
859                match (sample, class.is_multisampled()) {
860                    (None, false) => {}
861                    (Some(sample), true) => {
862                        if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
863                            return Err(ExpressionError::InvalidImageOtherIndexType(sample));
864                        }
865                    }
866                    (Some(_), false) => {
867                        return Err(ExpressionError::InvalidImageSampleSelector);
868                    }
869                    (None, true) => {
870                        return Err(ExpressionError::MissingImageSampleSelector);
871                    }
872                }
873
874                match (level, class.is_mipmapped()) {
875                    (None, false) => {}
876                    (Some(level), true) => match resolver[level] {
877                        Ti::Scalar(Sc {
878                            kind: Sk::Sint | Sk::Uint,
879                            width: _,
880                        }) => {}
881                        _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
882                    },
883                    (Some(_), false) => {
884                        return Err(ExpressionError::InvalidImageLevelSelector);
885                    }
886                    (None, true) => {
887                        return Err(ExpressionError::MissingImageLevelSelector);
888                    }
889                }
890                ShaderStages::all()
891            }
892            E::ImageQuery { image, query } => {
893                let ty = Self::global_var_ty(module, function, image)?;
894                match module.types[ty].inner {
895                    Ti::Image { class, arrayed, .. } => {
896                        let good = match query {
897                            crate::ImageQuery::NumLayers => arrayed,
898                            crate::ImageQuery::Size { level: None } => true,
899                            crate::ImageQuery::Size { level: Some(level) } => {
900                                match resolver[level] {
901                                    Ti::Scalar(Sc::I32 | Sc::U32) => {}
902                                    _ => {
903                                        return Err(ExpressionError::InvalidImageOtherIndexType(
904                                            level,
905                                        ))
906                                    }
907                                }
908                                class.is_mipmapped()
909                            }
910                            crate::ImageQuery::NumLevels => class.is_mipmapped(),
911                            crate::ImageQuery::NumSamples => class.is_multisampled(),
912                        };
913                        if !good {
914                            return Err(ExpressionError::InvalidImageClass(class));
915                        }
916                    }
917                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
918                }
919                ShaderStages::all()
920            }
921            E::Unary { op, expr } => {
922                use crate::UnaryOperator as Uo;
923                let Some((_, scalar)) = resolver[expr].vector_size_and_scalar() else {
924                    return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
925                };
926                match (op, scalar.kind) {
927                    (Uo::Negate, Sk::Float | Sk::Sint) => {}
928                    (Uo::LogicalNot, Sk::Bool) => {}
929                    (Uo::BitwiseNot, Sk::Sint | Sk::Uint) => {}
930                    _ => return Err(ExpressionError::InvalidUnaryOperandType(op, expr)),
931                }
932                ShaderStages::all()
933            }
934            E::Binary { op, left, right } => {
935                use crate::BinaryOperator as Bo;
936                let left_inner = &resolver[left];
937                let right_inner = &resolver[right];
938                let good = match op {
939                    Bo::Add | Bo::Subtract => match *left_inner {
940                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
941                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
942                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
943                        },
944                        Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => {
945                            left_inner == right_inner
946                        }
947                        _ => false,
948                    },
949                    Bo::Divide | Bo::Modulo => match *left_inner {
950                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
951                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
952                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
953                        },
954                        _ => false,
955                    },
956                    Bo::Multiply => {
957                        let kind_allowed = match left_inner.scalar_kind() {
958                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
959                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
960                        };
961                        let types_match = match (left_inner, right_inner) {
962                            // Straight scalar and mixed scalar/vector.
963                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
964                            | (
965                                &Ti::Vector {
966                                    scalar: scalar1, ..
967                                },
968                                &Ti::Scalar(scalar2),
969                            )
970                            | (
971                                &Ti::Scalar(scalar1),
972                                &Ti::Vector {
973                                    scalar: scalar2, ..
974                                },
975                            ) => scalar1 == scalar2,
976                            // Scalar * matrix.
977                            (
978                                &Ti::Scalar(Sc {
979                                    kind: Sk::Float, ..
980                                }),
981                                &Ti::Matrix { .. },
982                            )
983                            | (
984                                &Ti::Matrix { .. },
985                                &Ti::Scalar(Sc {
986                                    kind: Sk::Float, ..
987                                }),
988                            ) => true,
989                            // Vector * vector.
990                            (
991                                &Ti::Vector {
992                                    size: size1,
993                                    scalar: scalar1,
994                                },
995                                &Ti::Vector {
996                                    size: size2,
997                                    scalar: scalar2,
998                                },
999                            ) => scalar1 == scalar2 && size1 == size2,
1000                            // Matrix * vector.
1001                            (
1002                                &Ti::Matrix { columns, .. },
1003                                &Ti::Vector {
1004                                    size,
1005                                    scalar:
1006                                        Sc {
1007                                            kind: Sk::Float, ..
1008                                        },
1009                                },
1010                            ) => columns == size,
1011                            // Vector * matrix.
1012                            (
1013                                &Ti::Vector {
1014                                    size,
1015                                    scalar:
1016                                        Sc {
1017                                            kind: Sk::Float, ..
1018                                        },
1019                                },
1020                                &Ti::Matrix { rows, .. },
1021                            ) => size == rows,
1022                            // Matrix * matrix.
1023                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
1024                                columns == rows
1025                            }
1026                            // Scalar * coop matrix.
1027                            (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. })
1028                            | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => {
1029                                s1 == s2
1030                            }
1031                            _ => false,
1032                        };
1033                        let left_width = left_inner.scalar_width().unwrap_or(0);
1034                        let right_width = right_inner.scalar_width().unwrap_or(0);
1035                        kind_allowed && types_match && left_width == right_width
1036                    }
1037                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
1038                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
1039                        match *left_inner {
1040                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1041                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
1042                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
1043                            },
1044                            ref other => {
1045                                log::debug!("Op {op:?} left type {other:?}");
1046                                false
1047                            }
1048                        }
1049                    }
1050                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
1051                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
1052                        | Ti::Vector {
1053                            scalar: Sc { kind: Sk::Bool, .. },
1054                            ..
1055                        } => left_inner == right_inner,
1056                        ref other => {
1057                            log::debug!("Op {op:?} left type {other:?}");
1058                            false
1059                        }
1060                    },
1061                    Bo::And | Bo::InclusiveOr => match *left_inner {
1062                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1063                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
1064                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1065                        },
1066                        ref other => {
1067                            log::debug!("Op {op:?} left type {other:?}");
1068                            false
1069                        }
1070                    },
1071                    Bo::ExclusiveOr => match *left_inner {
1072                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1073                            Sk::Sint | Sk::Uint => left_inner == right_inner,
1074                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1075                        },
1076                        ref other => {
1077                            log::debug!("Op {op:?} left type {other:?}");
1078                            false
1079                        }
1080                    },
1081                    Bo::ShiftLeft | Bo::ShiftRight => {
1082                        let (base_size, base_scalar) = match *left_inner {
1083                            Ti::Scalar(scalar) => (Ok(None), scalar),
1084                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
1085                            ref other => {
1086                                log::debug!("Op {op:?} base type {other:?}");
1087                                (Err(()), Sc::BOOL)
1088                            }
1089                        };
1090                        let shift_size = match *right_inner {
1091                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
1092                            Ti::Vector {
1093                                size,
1094                                scalar: Sc { kind: Sk::Uint, .. },
1095                            } => Ok(Some(size)),
1096                            ref other => {
1097                                log::debug!("Op {op:?} shift type {other:?}");
1098                                Err(())
1099                            }
1100                        };
1101                        match base_scalar.kind {
1102                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
1103                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
1104                        }
1105                    }
1106                };
1107                if !good {
1108                    log::debug!(
1109                        "Left: {:?} of type {:?}",
1110                        function.expressions[left],
1111                        left_inner
1112                    );
1113                    log::debug!(
1114                        "Right: {:?} of type {:?}",
1115                        function.expressions[right],
1116                        right_inner
1117                    );
1118                    return Err(ExpressionError::InvalidBinaryOperandTypes {
1119                        op,
1120                        lhs_expr: left,
1121                        lhs_type: left_inner.clone(),
1122                        rhs_expr: right,
1123                        rhs_type: right_inner.clone(),
1124                    });
1125                }
1126                // For shift operations, check if the constant shift amount exceeds the bit width
1127                if matches!(op, Bo::ShiftLeft | Bo::ShiftRight) {
1128                    Self::validate_constant_shift_amounts(left_inner, right, module, function)?;
1129                }
1130                // For integer division or remainder, check if the constant divisor is zero
1131                if matches!(op, Bo::Divide | Bo::Modulo) {
1132                    Self::validate_constant_divisor(left_inner, right, module, function)?;
1133                }
1134                ShaderStages::all()
1135            }
1136            E::Select {
1137                condition,
1138                accept,
1139                reject,
1140            } => {
1141                let accept_inner = &resolver[accept];
1142                let reject_inner = &resolver[reject];
1143                let condition_ty = &resolver[condition];
1144                let condition_good = match *condition_ty {
1145                    Ti::Scalar(Sc {
1146                        kind: Sk::Bool,
1147                        width: _,
1148                    }) => {
1149                        // When `condition` is a single boolean, `accept` and
1150                        // `reject` can be vectors or scalars.
1151                        match *accept_inner {
1152                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
1153                            _ => false,
1154                        }
1155                    }
1156                    Ti::Vector {
1157                        size,
1158                        scalar:
1159                            Sc {
1160                                kind: Sk::Bool,
1161                                width: _,
1162                            },
1163                    } => match *accept_inner {
1164                        Ti::Vector {
1165                            size: other_size, ..
1166                        } => size == other_size,
1167                        _ => false,
1168                    },
1169                    _ => false,
1170                };
1171                if accept_inner != reject_inner {
1172                    return Err(ExpressionError::SelectValuesTypeMismatch {
1173                        accept: accept_inner.clone(),
1174                        reject: reject_inner.clone(),
1175                    });
1176                }
1177                if !condition_good {
1178                    return Err(ExpressionError::SelectConditionNotABool {
1179                        actual: condition_ty.clone(),
1180                    });
1181                }
1182                ShaderStages::all()
1183            }
1184            E::Derivative { expr, .. } => {
1185                let Some((_, scalar)) = resolver[expr].vector_size_and_scalar() else {
1186                    return Err(ExpressionError::InvalidDerivative);
1187                };
1188                if scalar.kind != Sk::Float || scalar.width < 4 {
1189                    // Derivatives are not supported for `f16`, although support may be added
1190                    // in the future, see https://github.com/gpuweb/gpuweb/issues/5482.
1191                    return Err(ExpressionError::InvalidDerivative);
1192                }
1193                ShaderStages::FRAGMENT
1194            }
1195            E::Relational { fun, argument } => {
1196                use crate::RelationalFunction as Rf;
1197                let argument_inner = &resolver[argument];
1198                match fun {
1199                    Rf::All | Rf::Any => match *argument_inner {
1200                        Ti::Vector {
1201                            scalar: Sc { kind: Sk::Bool, .. },
1202                            ..
1203                        } => {}
1204                        ref other => {
1205                            log::debug!("All/Any of type {other:?}");
1206                            return Err(ExpressionError::InvalidBooleanVector(argument));
1207                        }
1208                    },
1209                    Rf::IsNan | Rf::IsInf => match *argument_inner {
1210                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1211                            if scalar.kind == Sk::Float => {}
1212                        ref other => {
1213                            log::debug!("Float test of type {other:?}");
1214                            return Err(ExpressionError::InvalidFloatArgument(argument));
1215                        }
1216                    },
1217                }
1218                ShaderStages::all()
1219            }
1220            E::Math {
1221                fun,
1222                arg,
1223                arg1,
1224                arg2,
1225                arg3,
1226            } => {
1227                if matches!(
1228                    fun,
1229                    crate::MathFunction::QuantizeToF16
1230                        | crate::MathFunction::Pack2x16float
1231                        | crate::MathFunction::Unpack2x16float
1232                ) && !self
1233                    .capabilities
1234                    .contains(crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32)
1235                {
1236                    return Err(ExpressionError::MissingCapabilities(
1237                        crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32,
1238                    ));
1239                }
1240
1241                let actuals: &[_] = match (arg1, arg2, arg3) {
1242                    (None, None, None) => &[arg],
1243                    (Some(arg1), None, None) => &[arg, arg1],
1244                    (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1245                    (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1246                    _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1247                };
1248
1249                let resolve = |arg| &resolver[arg];
1250                let actual_types: &[_] = match *actuals {
1251                    [arg0] => &[resolve(arg0)],
1252                    [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1253                    [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1254                    [arg0, arg1, arg2, arg3] => {
1255                        &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1256                    }
1257                    _ => unreachable!(),
1258                };
1259
1260                // Start with the set of all overloads available for `fun`.
1261                let mut overloads = fun.overloads();
1262                log::debug!(
1263                    "initial overloads for {:?}: {:#?}",
1264                    fun,
1265                    overloads.for_debug(&module.types)
1266                );
1267
1268                // If any argument is not a constant expression, then no
1269                // overloads that accept abstract values should be considered.
1270                // `OverloadSet::concrete_only` is supposed to help impose this
1271                // restriction. However, no `MathFunction` accepts a mix of
1272                // abstract and concrete arguments, so we don't need to worry
1273                // about that here.
1274
1275                for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1276                    // Remove overloads that cannot accept an `i`'th
1277                    // argument of type `ty`.
1278                    overloads = overloads.arg(i, ty, &module.types);
1279                    log::debug!(
1280                        "overloads after arg {i}: {:#?}",
1281                        overloads.for_debug(&module.types)
1282                    );
1283
1284                    if overloads.is_empty() {
1285                        log::debug!("all overloads eliminated");
1286                        return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1287                    }
1288                }
1289
1290                if actuals.len() < overloads.min_arguments() {
1291                    return Err(ExpressionError::WrongArgumentCount(fun));
1292                }
1293
1294                ShaderStages::all()
1295            }
1296            E::As {
1297                expr,
1298                kind,
1299                convert,
1300            } => {
1301                let mut base_scalar = match resolver[expr] {
1302                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1303                        scalar
1304                    }
1305                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1306                    _ => return Err(ExpressionError::InvalidCastArgument),
1307                };
1308                base_scalar.kind = kind;
1309                if let Some(width) = convert {
1310                    base_scalar.width = width;
1311                }
1312                if self.check_width(base_scalar).is_err() {
1313                    return Err(ExpressionError::InvalidCastArgument);
1314                }
1315                ShaderStages::all()
1316            }
1317            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1318            E::AtomicResult { .. } => {
1319                // These expressions are validated when we check the `Atomic` statement
1320                // that refers to them, because we have all the information we need at
1321                // that point. The checks driven by `Validator::needs_visit` ensure
1322                // that this expression is indeed visited by one `Atomic` statement.
1323                ShaderStages::all()
1324            }
1325            E::WorkGroupUniformLoadResult { ty } => {
1326                if self.types[ty.index()]
1327                    .flags
1328                    // Sized | Constructible is exactly the types currently supported by
1329                    // WorkGroupUniformLoad
1330                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1331                {
1332                    ShaderStages::COMPUTE_LIKE
1333                } else {
1334                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1335                }
1336            }
1337            E::ArrayLength(expr) => match resolver[expr] {
1338                Ti::Pointer { base, .. } => {
1339                    let base_ty = &resolver.types[base];
1340                    if let Ti::Array {
1341                        size: crate::ArraySize::Dynamic,
1342                        ..
1343                    } = base_ty.inner
1344                    {
1345                        ShaderStages::all()
1346                    } else {
1347                        return Err(ExpressionError::InvalidArrayType(expr));
1348                    }
1349                }
1350                ref other => {
1351                    log::debug!("Array length of {other:?}");
1352                    return Err(ExpressionError::InvalidArrayType(expr));
1353                }
1354            },
1355            E::RayQueryProceedResult => ShaderStages::all(),
1356            E::RayQueryGetIntersection {
1357                query,
1358                committed: _,
1359            } => match resolver[query] {
1360                Ti::Pointer {
1361                    base,
1362                    space: crate::AddressSpace::Function,
1363                } => match resolver.types[base].inner {
1364                    Ti::RayQuery { .. } => ShaderStages::all(),
1365                    ref other => {
1366                        log::debug!("Intersection result of a pointer to {other:?}");
1367                        return Err(ExpressionError::InvalidRayQueryType(query));
1368                    }
1369                },
1370                ref other => {
1371                    log::debug!("Intersection result of {other:?}");
1372                    return Err(ExpressionError::InvalidRayQueryType(query));
1373                }
1374            },
1375            E::RayQueryVertexPositions {
1376                query,
1377                committed: _,
1378            } => match resolver[query] {
1379                Ti::Pointer {
1380                    base,
1381                    space: crate::AddressSpace::Function,
1382                } => match resolver.types[base].inner {
1383                    Ti::RayQuery {
1384                        vertex_return: true,
1385                    } => ShaderStages::all(),
1386                    ref other => {
1387                        log::debug!("Intersection result of a pointer to {other:?}");
1388                        return Err(ExpressionError::InvalidRayQueryType(query));
1389                    }
1390                },
1391                ref other => {
1392                    log::debug!("Intersection result of {other:?}");
1393                    return Err(ExpressionError::InvalidRayQueryType(query));
1394                }
1395            },
1396            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1397            E::CooperativeLoad { ref data, .. } => {
1398                if resolver[data.pointer]
1399                    .pointer_base_type()
1400                    .and_then(|tr| tr.inner_with(&module.types).scalar())
1401                    .is_none()
1402                {
1403                    return Err(ExpressionError::InvalidPointerType(data.pointer));
1404                }
1405                ShaderStages::COMPUTE
1406            }
1407            E::CooperativeMultiplyAdd { a, b, c } => {
1408                let roles = [
1409                    crate::CooperativeRole::A,
1410                    crate::CooperativeRole::B,
1411                    crate::CooperativeRole::C,
1412                ];
1413                for (operand, expected_role) in [a, b, c].into_iter().zip(roles) {
1414                    match resolver[operand] {
1415                        Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
1416                        ref other => {
1417                            log::debug!("{expected_role:?} operand type: {other:?}");
1418                            return Err(ExpressionError::InvalidCooperativeOperand(a));
1419                        }
1420                    }
1421                }
1422                ShaderStages::COMPUTE
1423            }
1424        };
1425        Ok(stages)
1426    }
1427
1428    fn global_var_ty(
1429        module: &crate::Module,
1430        function: &crate::Function,
1431        expr: Handle<crate::Expression>,
1432    ) -> Result<Handle<crate::Type>, ExpressionError> {
1433        use crate::Expression as Ex;
1434
1435        match function.expressions[expr] {
1436            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1437            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1438            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1439                match function.expressions[base] {
1440                    Ex::GlobalVariable(var_handle) => {
1441                        let array_ty = module.global_variables[var_handle].ty;
1442
1443                        match module.types[array_ty].inner {
1444                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1445                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1446                        }
1447                    }
1448                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1449                }
1450            }
1451            _ => Err(ExpressionError::ExpectedGlobalVariable),
1452        }
1453    }
1454
1455    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1456        let _ = self.check_width(literal.scalar())?;
1457        check_literal_value(literal)?;
1458
1459        Ok(())
1460    }
1461}
1462
1463pub const fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1464    let is_nan = match literal {
1465        crate::Literal::F64(v) => v.is_nan(),
1466        crate::Literal::F32(v) => v.is_nan(),
1467        _ => false,
1468    };
1469    if is_nan {
1470        return Err(LiteralError::NaN);
1471    }
1472
1473    let is_infinite = match literal {
1474        crate::Literal::F64(v) => v.is_infinite(),
1475        crate::Literal::F32(v) => v.is_infinite(),
1476        _ => false,
1477    };
1478    if is_infinite {
1479        return Err(LiteralError::Infinity);
1480    }
1481
1482    Ok(())
1483}
1484
/// Validate a module whose only content is one function that holds `expr` in
/// its expression arena and emits it, using the capabilities in `caps`.
///
/// Only the `EXPRESSIONS` validation flag is enabled. The validator's outcome
/// is returned as-is so that callers can assert on success or on a specific
/// error.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut test_fn = crate::Function::default();
    test_fn.expressions.append(expr, Span::default());

    // A single `Emit` over the arena range starting at index 0, i.e. the
    // expression appended above.
    let emitted = test_fn.expressions.range_from(0);
    test_fn
        .body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(test_fn, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1507
/// Validate a module whose only content is `expr`, placed in the module's
/// global expression arena, using the capabilities in `caps`.
///
/// Only the `CONSTANTS` validation flag is enabled. The validator's outcome is
/// returned as-is so that callers can assert on success or on a specific error.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();
    module.global_expressions.append(expr, Span::default());

    super::Validator::new(super::ValidationFlags::CONSTANTS, caps).validate(&module)
}
1523
/// An `F64` literal in a function's expression arena is rejected with a
/// missing-capability width error unless `FLOAT64` is enabled.
#[test]
fn f64_runtime_literals() {
    // Built through a closure so each validation run gets a fresh expression.
    let f64_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: validation must fail, naming the missing capability.
    let error = validate_with_expression(f64_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    // With `FLOAT64` added, the same literal validates cleanly.
    let with_float64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(f64_literal(), with_float64).is_ok());
}
1554
/// An `F64` literal in a module's global expression arena is rejected with a
/// missing-capability width error unless `FLOAT64` is enabled.
#[test]
fn f64_const_literals() {
    // Built through a closure so each validation run gets a fresh expression.
    let f64_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: validation must fail, naming the missing capability.
    let error = validate_with_const_expression(f64_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With `FLOAT64` added, the same literal validates cleanly.
    let with_float64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(f64_literal(), with_float64).is_ok());
}