naga/valid/
expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4    arena::Handle,
5    proc::OverloadSet as _,
6    proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13    NotInScope,
14    #[error("Base type {0:?} is not compatible with this expression")]
15    InvalidBaseType(Handle<crate::Expression>),
16    #[error("Accessing with index {0:?} can't be done")]
17    InvalidIndexType(Handle<crate::Expression>),
18    #[error("Accessing {0:?} via a negative index is invalid")]
19    NegativeIndex(Handle<crate::Expression>),
20    #[error("Accessing index {1} is out of {0:?} bounds")]
21    IndexOutOfBounds(Handle<crate::Expression>, u32),
22    #[error("Function argument {0:?} doesn't exist")]
23    FunctionArgumentDoesntExist(u32),
24    #[error("Loading of {0:?} can't be done")]
25    InvalidPointerType(Handle<crate::Expression>),
26    #[error("Array length of {0:?} can't be done")]
27    InvalidArrayType(Handle<crate::Expression>),
28    #[error("Get intersection of {0:?} can't be done")]
29    InvalidRayQueryType(Handle<crate::Expression>),
30    #[error("Splatting {0:?} can't be done")]
31    InvalidSplatType(Handle<crate::Expression>),
32    #[error("Swizzling {0:?} can't be done")]
33    InvalidVectorType(Handle<crate::Expression>),
34    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36    #[error(transparent)]
37    Compose(#[from] super::ComposeError),
38    #[error("Cannot construct zero value of {0:?} because it is not a constructible type")]
39    InvalidZeroValue(Handle<crate::Type>),
40    #[error(transparent)]
41    IndexableLength(#[from] IndexableLengthError),
42    #[error("Operation {0:?} can't work with {1:?}")]
43    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
44    #[error(
45        "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
46        op,
47        lhs_expr,
48        lhs_type,
49        rhs_expr,
50        rhs_type
51    )]
52    InvalidBinaryOperandTypes {
53        op: crate::BinaryOperator,
54        lhs_expr: Handle<crate::Expression>,
55        lhs_type: crate::TypeInner,
56        rhs_expr: Handle<crate::Expression>,
57        rhs_type: crate::TypeInner,
58    },
59    #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
60    SelectValuesTypeMismatch {
61        accept: crate::TypeInner,
62        reject: crate::TypeInner,
63    },
64    #[error("Expected selection condition to be a boolean value, got {actual:?}")]
65    SelectConditionNotABool { actual: crate::TypeInner },
66    #[error("Relational argument {0:?} is not a boolean vector")]
67    InvalidBooleanVector(Handle<crate::Expression>),
68    #[error("Relational argument {0:?} is not a float")]
69    InvalidFloatArgument(Handle<crate::Expression>),
70    #[error("Type resolution failed")]
71    Type(#[from] ResolveError),
72    #[error("Not a global variable")]
73    ExpectedGlobalVariable,
74    #[error("Not a global variable or a function argument")]
75    ExpectedGlobalOrArgument,
76    #[error("Needs to be an binding array instead of {0:?}")]
77    ExpectedBindingArrayType(Handle<crate::Type>),
78    #[error("Needs to be an image instead of {0:?}")]
79    ExpectedImageType(Handle<crate::Type>),
80    #[error("Needs to be an image instead of {0:?}")]
81    ExpectedSamplerType(Handle<crate::Type>),
82    #[error("Unable to operate on image class {0:?}")]
83    InvalidImageClass(crate::ImageClass),
84    #[error("Image atomics are not supported for storage format {0:?}")]
85    InvalidImageFormat(crate::StorageFormat),
86    #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
87    InvalidImageStorageAccess(crate::StorageAccess),
88    #[error("Derivatives can only be taken from scalar and vector floats")]
89    InvalidDerivative,
90    #[error("Image array index parameter is misplaced")]
91    InvalidImageArrayIndex,
92    #[error("Cannot textureLoad from a specific multisample sample on a non-multisampled image.")]
93    InvalidImageSampleSelector,
94    #[error("Cannot textureLoad from a multisampled image without specifying a sample.")]
95    MissingImageSampleSelector,
96    #[error("Cannot textureLoad with a specific mip level on a non-mipmapped image.")]
97    InvalidImageLevelSelector,
98    #[error("Cannot textureLoad from a mipmapped image without specifying a level.")]
99    MissingImageLevelSelector,
100    #[error("Image array index type of {0:?} is not an integer scalar")]
101    InvalidImageArrayIndexType(Handle<crate::Expression>),
102    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
103    InvalidImageOtherIndexType(Handle<crate::Expression>),
104    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
105    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
106    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
107    ComparisonSamplingMismatch {
108        image: crate::ImageClass,
109        sampler: bool,
110        has_ref: bool,
111    },
112    #[error("Sample offset must be a const-expression")]
113    InvalidSampleOffsetExprType,
114    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
115    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
116    #[error("Depth reference {0:?} is not a scalar float")]
117    InvalidDepthReference(Handle<crate::Expression>),
118    #[error("Depth sample level can only be Auto or Zero")]
119    InvalidDepthSampleLevel,
120    #[error("Gather level can only be Zero")]
121    InvalidGatherLevel,
122    #[error("Gather component {0:?} doesn't exist in the image")]
123    InvalidGatherComponent(crate::SwizzleComponent),
124    #[error("Gather can't be done for image dimension {0:?}")]
125    InvalidGatherDimension(crate::ImageDimension),
126    #[error("Sample level (exact) type {0:?} has an invalid type")]
127    InvalidSampleLevelExactType(Handle<crate::Expression>),
128    #[error("Sample level (bias) type {0:?} is not a scalar float")]
129    InvalidSampleLevelBiasType(Handle<crate::Expression>),
130    #[error("Bias can't be done for image dimension {0:?}")]
131    InvalidSampleLevelBiasDimension(crate::ImageDimension),
132    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
133    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
134    #[error("Clamping sample coordinate to edge is not supported with {0}")]
135    InvalidSampleClampCoordinateToEdge(alloc::string::String),
136    #[error("Unable to cast")]
137    InvalidCastArgument,
138    #[error("Invalid argument count for {0:?}")]
139    WrongArgumentCount(crate::MathFunction),
140    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
141    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
142    #[error(
143        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
144    )]
145    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
146    #[error("Shader requires capability {0:?}")]
147    MissingCapabilities(super::Capabilities),
148    #[error(transparent)]
149    Literal(#[from] LiteralError),
150    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
151    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
152    #[error("Invalid operand for cooperative op")]
153    InvalidCooperativeOperand(Handle<crate::Expression>),
154    #[error("Shift amount exceeds the bit width of {lhs_type:?}")]
155    ShiftAmountTooLarge {
156        lhs_type: crate::TypeInner,
157        rhs_expr: Handle<crate::Expression>,
158    },
159}
160
/// Errors reported by `Validator::validate_const_expression` when checking
/// an expression in the module's global expression arena.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstExpressionError {
    /// The expression kind tracker does not classify the expression as
    /// either constant or override.
    #[error("The expression is not a constant or override expression")]
    NonConstOrOverride,
    /// The expression is of a kind that should already have been reduced to
    /// a literal, constant, zero value, compose or splat.
    #[error("The expression is not a fully evaluated constant expression")]
    NonFullyEvaluatedConst,
    /// A `Compose` expression failed compose validation.
    #[error(transparent)]
    Compose(#[from] super::ComposeError),
    /// The operand of a `Splat` expression is not a scalar.
    #[error("Splatting {0:?} can't be done")]
    InvalidSplatType(Handle<crate::Expression>),
    /// Wraps a [`ResolveError`].
    #[error("Type resolution failed")]
    Type(#[from] ResolveError),
    /// A literal value was rejected; see [`LiteralError`].
    #[error(transparent)]
    Literal(#[from] LiteralError),
    /// Wraps a width error from the type validator.
    #[error(transparent)]
    Width(#[from] super::r#type::WidthError),
}
179
/// Reasons a literal value can be rejected by validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LiteralError {
    /// The float literal is not a number.
    #[error("Float literal is NaN")]
    NaN,
    /// The float literal is infinite.
    #[error("Float literal is infinite")]
    Infinity,
    /// Wraps a width error from the type validator.
    #[error(transparent)]
    Width(#[from] super::r#type::WidthError),
}
190
191struct ExpressionTypeResolver<'a> {
192    root: Handle<crate::Expression>,
193    types: &'a UniqueArena<crate::Type>,
194    info: &'a FunctionInfo,
195}
196
197impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
198    type Output = crate::TypeInner;
199
200    #[allow(clippy::panic)]
201    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
202        if handle < self.root {
203            self.info[handle].ty.inner_with(self.types)
204        } else {
205            // `Validator::validate_module_handles` should have caught this.
206            panic!(
207                "Depends on {:?}, which has not been processed yet",
208                self.root
209            )
210        }
211    }
212}
213
214impl super::Validator {
215    pub(super) fn validate_const_expression(
216        &self,
217        handle: Handle<crate::Expression>,
218        gctx: crate::proc::GlobalCtx,
219        mod_info: &ModuleInfo,
220        global_expr_kind: &crate::proc::ExpressionKindTracker,
221    ) -> Result<(), ConstExpressionError> {
222        use crate::Expression as E;
223
224        if !global_expr_kind.is_const_or_override(handle) {
225            return Err(ConstExpressionError::NonConstOrOverride);
226        }
227
228        match gctx.global_expressions[handle] {
229            E::Literal(literal) => {
230                self.validate_literal(literal)?;
231            }
232            E::Constant(_) | E::ZeroValue(_) => {}
233            E::Compose { ref components, ty } => {
234                validate_compose(
235                    ty,
236                    gctx,
237                    components.iter().map(|&handle| mod_info[handle].clone()),
238                )?;
239            }
240            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
241                crate::TypeInner::Scalar { .. } => {}
242                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
243            },
244            _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
245                return Err(ConstExpressionError::NonFullyEvaluatedConst)
246            }
247            // the constant evaluator will report errors about override-expressions
248            _ => {}
249        }
250
251        Ok(())
252    }
253
254    /// Return an error if a constant shift amount in `right` exceeds the bit
255    /// width of `left_ty`.
256    ///
257    /// This function promises to return an error in cases where (1) the
258    /// expression is well-typed, (2) `left_ty` is a concrete integer, and
259    /// (3) the shift will overflow. It does not return an error in cases where
260    /// the expression is not well-typed (e.g. vector dimension mismatch),
261    /// because those will be rejected elsewhere.
262    fn validate_constant_shift_amounts(
263        left_ty: &crate::TypeInner,
264        right: Handle<crate::Expression>,
265        module: &crate::Module,
266        function: &crate::Function,
267    ) -> Result<(), ExpressionError> {
268        fn is_overflowing_shift(
269            left_ty: &crate::TypeInner,
270            right: Handle<crate::Expression>,
271            module: &crate::Module,
272            function: &crate::Function,
273        ) -> bool {
274            let Some((vec_size, scalar)) = left_ty.vector_size_and_scalar() else {
275                return false;
276            };
277            if !matches!(
278                scalar.kind,
279                crate::ScalarKind::Sint | crate::ScalarKind::Uint
280            ) {
281                return false;
282            }
283            let lhs_bits = u32::from(8 * scalar.width);
284            if vec_size.is_none() {
285                let shift_amount = module
286                    .to_ctx()
287                    .get_const_val_from::<u32, _>(right, &function.expressions);
288                shift_amount.ok().is_some_and(|s| s >= lhs_bits)
289            } else {
290                match function.expressions[right] {
291                    crate::Expression::ZeroValue(_) => false, // zero shift does not overflow
292                    crate::Expression::Splat { value, .. } => module
293                        .to_ctx()
294                        .get_const_val_from::<u32, _>(value, &function.expressions)
295                        .ok()
296                        .is_some_and(|s| s >= lhs_bits),
297                    crate::Expression::Compose {
298                        ty: _,
299                        ref components,
300                    } => components.iter().any(|comp| {
301                        module
302                            .to_ctx()
303                            .get_const_val_from::<u32, _>(*comp, &function.expressions)
304                            .ok()
305                            .is_some_and(|s| s >= lhs_bits)
306                    }),
307                    _ => false,
308                }
309            }
310        }
311
312        if is_overflowing_shift(left_ty, right, module, function) {
313            Err(ExpressionError::ShiftAmountTooLarge {
314                lhs_type: left_ty.clone(),
315                rhs_expr: right,
316            })
317        } else {
318            Ok(())
319        }
320    }
321
322    #[allow(clippy::too_many_arguments)]
323    pub(super) fn validate_expression(
324        &self,
325        root: Handle<crate::Expression>,
326        expression: &crate::Expression,
327        function: &crate::Function,
328        module: &crate::Module,
329        info: &FunctionInfo,
330        mod_info: &ModuleInfo,
331        expr_kind: &crate::proc::ExpressionKindTracker,
332    ) -> Result<ShaderStages, ExpressionError> {
333        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
334
335        let resolver = ExpressionTypeResolver {
336            root,
337            types: &module.types,
338            info,
339        };
340
341        let stages = match *expression {
342            E::Access { base, index } => {
343                let base_type = &resolver[base];
344                match *base_type {
345                    Ti::Matrix { .. }
346                    | Ti::Vector { .. }
347                    | Ti::Array { .. }
348                    | Ti::Pointer { .. }
349                    | Ti::ValuePointer { size: Some(_), .. }
350                    | Ti::BindingArray { .. } => {}
351                    ref other => {
352                        log::debug!("Indexing of {other:?}");
353                        return Err(ExpressionError::InvalidBaseType(base));
354                    }
355                };
356                match resolver[index] {
357                    //TODO: only allow one of these
358                    Ti::Scalar(Sc {
359                        kind: Sk::Sint | Sk::Uint,
360                        ..
361                    }) => {}
362                    ref other => {
363                        log::debug!("Indexing by {other:?}");
364                        return Err(ExpressionError::InvalidIndexType(index));
365                    }
366                }
367
368                // If index is const we can do check for non-negative index
369                match module
370                    .to_ctx()
371                    .get_const_val_from(index, &function.expressions)
372                {
373                    Ok(value) => {
374                        let length = if self.overrides_resolved {
375                            base_type.indexable_length_resolved(module)
376                        } else {
377                            base_type.indexable_length_pending(module)
378                        }?;
379                        // If we know both the length and the index, we can do the
380                        // bounds check now.
381                        if let crate::proc::IndexableLength::Known(known_length) = length {
382                            if value >= known_length {
383                                return Err(ExpressionError::IndexOutOfBounds(base, value));
384                            }
385                        }
386                    }
387                    Err(crate::proc::ConstValueError::Negative) => {
388                        return Err(ExpressionError::NegativeIndex(base))
389                    }
390                    Err(crate::proc::ConstValueError::NonConst) => {}
391                    Err(crate::proc::ConstValueError::InvalidType) => {
392                        return Err(ExpressionError::InvalidIndexType(index))
393                    }
394                }
395
396                ShaderStages::all()
397            }
398            E::AccessIndex { base, index } => {
399                fn resolve_index_limit(
400                    module: &crate::Module,
401                    top: Handle<crate::Expression>,
402                    ty: &crate::TypeInner,
403                    top_level: bool,
404                ) -> Result<u32, ExpressionError> {
405                    let limit = match *ty {
406                        Ti::Vector { size, .. }
407                        | Ti::ValuePointer {
408                            size: Some(size), ..
409                        } => size as u32,
410                        Ti::Matrix { columns, .. } => columns as u32,
411                        Ti::Array {
412                            size: crate::ArraySize::Constant(len),
413                            ..
414                        } => len.get(),
415                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
416                        Ti::Pointer { base, .. } if top_level => {
417                            resolve_index_limit(module, top, &module.types[base].inner, false)?
418                        }
419                        Ti::Struct { ref members, .. } => members.len() as u32,
420                        ref other => {
421                            log::debug!("Indexing of {other:?}");
422                            return Err(ExpressionError::InvalidBaseType(top));
423                        }
424                    };
425                    Ok(limit)
426                }
427
428                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
429                if index >= limit {
430                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
431                }
432                ShaderStages::all()
433            }
434            E::Splat { size: _, value } => match resolver[value] {
435                Ti::Scalar { .. } => ShaderStages::all(),
436                ref other => {
437                    log::debug!("Splat scalar type {other:?}");
438                    return Err(ExpressionError::InvalidSplatType(value));
439                }
440            },
441            E::Swizzle {
442                size,
443                vector,
444                pattern,
445            } => {
446                let vec_size = match resolver[vector] {
447                    Ti::Vector { size: vec_size, .. } => vec_size,
448                    ref other => {
449                        log::debug!("Swizzle vector type {other:?}");
450                        return Err(ExpressionError::InvalidVectorType(vector));
451                    }
452                };
453                for &sc in pattern[..size as usize].iter() {
454                    if sc as u8 >= vec_size as u8 {
455                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
456                    }
457                }
458                ShaderStages::all()
459            }
460            E::Literal(literal) => {
461                self.validate_literal(literal)?;
462                ShaderStages::all()
463            }
464            E::Constant(_) | E::Override(_) => ShaderStages::all(),
465            E::ZeroValue(ty) => {
466                if !mod_info[ty].contains(TypeFlags::CONSTRUCTIBLE) {
467                    return Err(ExpressionError::InvalidZeroValue(ty));
468                }
469                ShaderStages::all()
470            }
471            E::Compose { ref components, ty } => {
472                validate_compose(
473                    ty,
474                    module.to_ctx(),
475                    components.iter().map(|&handle| info[handle].ty.clone()),
476                )?;
477                ShaderStages::all()
478            }
479            E::FunctionArgument(index) => {
480                if index >= function.arguments.len() as u32 {
481                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
482                }
483                ShaderStages::all()
484            }
485            E::GlobalVariable(_handle) => ShaderStages::all(),
486            E::LocalVariable(_handle) => ShaderStages::all(),
487            E::Load { pointer } => {
488                match resolver[pointer] {
489                    Ti::Pointer { base, .. }
490                        if self.types[base.index()]
491                            .flags
492                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
493                    Ti::ValuePointer { .. } => {}
494                    ref other => {
495                        log::debug!("Loading {other:?}");
496                        return Err(ExpressionError::InvalidPointerType(pointer));
497                    }
498                }
499                ShaderStages::all()
500            }
501            E::ImageSample {
502                image,
503                sampler,
504                gather,
505                coordinate,
506                array_index,
507                offset,
508                level,
509                depth_ref,
510                clamp_to_edge,
511            } => {
512                // check the validity of expressions
513                let image_ty = Self::global_var_ty(module, function, image)?;
514                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
515
516                let comparison = match module.types[sampler_ty].inner {
517                    Ti::Sampler { comparison } => comparison,
518                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
519                };
520
521                let (class, dim) = match module.types[image_ty].inner {
522                    Ti::Image {
523                        class,
524                        arrayed,
525                        dim,
526                    } => {
527                        // check the array property
528                        if arrayed != array_index.is_some() {
529                            return Err(ExpressionError::InvalidImageArrayIndex);
530                        }
531                        if let Some(expr) = array_index {
532                            match resolver[expr] {
533                                Ti::Scalar(Sc {
534                                    kind: Sk::Sint | Sk::Uint,
535                                    ..
536                                }) => {}
537                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
538                            }
539                        }
540                        (class, dim)
541                    }
542                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
543                };
544
545                // check sampling and comparison properties
546                let image_depth = match class {
547                    crate::ImageClass::Sampled {
548                        kind: crate::ScalarKind::Float,
549                        multi: false,
550                    } => false,
551                    crate::ImageClass::Sampled {
552                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
553                        multi: false,
554                    } if gather.is_some() => false,
555                    crate::ImageClass::External => false,
556                    crate::ImageClass::Depth { multi: false } => true,
557                    _ => return Err(ExpressionError::InvalidImageClass(class)),
558                };
559                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
560                    return Err(ExpressionError::ComparisonSamplingMismatch {
561                        image: class,
562                        sampler: comparison,
563                        has_ref: depth_ref.is_some(),
564                    });
565                }
566
567                // check texture coordinates type
568                let num_components = match dim {
569                    crate::ImageDimension::D1 => 1,
570                    crate::ImageDimension::D2 => 2,
571                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
572                };
573                match resolver[coordinate] {
574                    Ti::Scalar(Sc {
575                        kind: Sk::Float, ..
576                    }) if num_components == 1 => {}
577                    Ti::Vector {
578                        size,
579                        scalar:
580                            Sc {
581                                kind: Sk::Float, ..
582                            },
583                    } if size as u32 == num_components => {}
584                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
585                }
586
587                // check constant offset
588                if let Some(const_expr) = offset {
589                    if !expr_kind.is_const(const_expr) {
590                        return Err(ExpressionError::InvalidSampleOffsetExprType);
591                    }
592
593                    match resolver[const_expr] {
594                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
595                        Ti::Vector {
596                            size,
597                            scalar: Sc { kind: Sk::Sint, .. },
598                        } if size as u32 == num_components => {}
599                        _ => {
600                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
601                        }
602                    }
603                }
604
605                // check depth reference type
606                if let Some(expr) = depth_ref {
607                    match resolver[expr] {
608                        Ti::Scalar(Sc {
609                            kind: Sk::Float, ..
610                        }) => {}
611                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
612                    }
613                    match level {
614                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
615                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
616                    }
617                }
618
619                if let Some(component) = gather {
620                    match dim {
621                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
622                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
623                            return Err(ExpressionError::InvalidGatherDimension(dim))
624                        }
625                    };
626                    let max_component = match class {
627                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
628                        _ => crate::SwizzleComponent::W,
629                    };
630                    if component > max_component {
631                        return Err(ExpressionError::InvalidGatherComponent(component));
632                    }
633                    match level {
634                        crate::SampleLevel::Zero => {}
635                        _ => return Err(ExpressionError::InvalidGatherLevel),
636                    }
637                }
638
639                // Clamping coordinate to edge is only supported with 2d non-arrayed, sampled images
640                // when sampling from level Zero without any offset, gather, or depth comparison.
641                if clamp_to_edge {
642                    if !matches!(
643                        class,
644                        crate::ImageClass::Sampled {
645                            kind: crate::ScalarKind::Float,
646                            multi: false
647                        } | crate::ImageClass::External
648                    ) {
649                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
650                            alloc::format!("image class `{class:?}`"),
651                        ));
652                    }
653                    if dim != crate::ImageDimension::D2 {
654                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
655                            alloc::format!("image dimension `{dim:?}`"),
656                        ));
657                    }
658                    if gather.is_some() {
659                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
660                            "gather".into(),
661                        ));
662                    }
663                    if array_index.is_some() {
664                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
665                            "array index".into(),
666                        ));
667                    }
668                    if offset.is_some() {
669                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
670                            "offset".into(),
671                        ));
672                    }
673                    if level != crate::SampleLevel::Zero {
674                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
675                            "non-zero level".into(),
676                        ));
677                    }
678                    if depth_ref.is_some() {
679                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
680                            "depth comparison".into(),
681                        ));
682                    }
683                }
684
685                // External textures can only be sampled using clamp_to_edge.
686                if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
687                    return Err(ExpressionError::InvalidImageClass(class));
688                }
689
690                // check level properties
691                match level {
692                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
693                    crate::SampleLevel::Zero => ShaderStages::all(),
694                    crate::SampleLevel::Exact(expr) => {
695                        match class {
696                            crate::ImageClass::Depth { .. } => match resolver[expr] {
697                                Ti::Scalar(Sc {
698                                    kind: Sk::Sint | Sk::Uint,
699                                    ..
700                                }) => {}
701                                _ => {
702                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
703                                }
704                            },
705                            _ => match resolver[expr] {
706                                Ti::Scalar(Sc {
707                                    kind: Sk::Float, ..
708                                }) => {}
709                                _ => {
710                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
711                                }
712                            },
713                        }
714                        ShaderStages::all()
715                    }
716                    crate::SampleLevel::Bias(expr) => {
717                        match resolver[expr] {
718                            Ti::Scalar(Sc {
719                                kind: Sk::Float, ..
720                            }) => {}
721                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
722                        }
723                        match class {
724                            crate::ImageClass::Sampled {
725                                kind: Sk::Float,
726                                multi: false,
727                            } => {
728                                if dim == crate::ImageDimension::D1 {
729                                    return Err(ExpressionError::InvalidSampleLevelBiasDimension(
730                                        dim,
731                                    ));
732                                }
733                            }
734                            _ => return Err(ExpressionError::InvalidImageClass(class)),
735                        }
736                        ShaderStages::FRAGMENT
737                    }
738                    crate::SampleLevel::Gradient { x, y } => {
739                        match resolver[x] {
740                            Ti::Scalar(Sc {
741                                kind: Sk::Float, ..
742                            }) if num_components == 1 => {}
743                            Ti::Vector {
744                                size,
745                                scalar:
746                                    Sc {
747                                        kind: Sk::Float, ..
748                                    },
749                            } if size as u32 == num_components => {}
750                            _ => {
751                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
752                            }
753                        }
754                        match resolver[y] {
755                            Ti::Scalar(Sc {
756                                kind: Sk::Float, ..
757                            }) if num_components == 1 => {}
758                            Ti::Vector {
759                                size,
760                                scalar:
761                                    Sc {
762                                        kind: Sk::Float, ..
763                                    },
764                            } if size as u32 == num_components => {}
765                            _ => {
766                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
767                            }
768                        }
769                        ShaderStages::all()
770                    }
771                }
772            }
773            E::ImageLoad {
774                image,
775                coordinate,
776                array_index,
777                sample,
778                level,
779            } => {
780                let ty = Self::global_var_ty(module, function, image)?;
781                let Ti::Image {
782                    class,
783                    arrayed,
784                    dim,
785                } = module.types[ty].inner
786                else {
787                    return Err(ExpressionError::ExpectedImageType(ty));
788                };
789
790                match resolver[coordinate].image_storage_coordinates() {
791                    Some(coord_dim) if coord_dim == dim => {}
792                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
793                };
794                if arrayed != array_index.is_some() {
795                    return Err(ExpressionError::InvalidImageArrayIndex);
796                }
797                if let Some(expr) = array_index {
798                    if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
799                        return Err(ExpressionError::InvalidImageArrayIndexType(expr));
800                    }
801                }
802
803                match (sample, class.is_multisampled()) {
804                    (None, false) => {}
805                    (Some(sample), true) => {
806                        if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
807                            return Err(ExpressionError::InvalidImageOtherIndexType(sample));
808                        }
809                    }
810                    (Some(_), false) => {
811                        return Err(ExpressionError::InvalidImageSampleSelector);
812                    }
813                    (None, true) => {
814                        return Err(ExpressionError::MissingImageSampleSelector);
815                    }
816                }
817
818                match (level, class.is_mipmapped()) {
819                    (None, false) => {}
820                    (Some(level), true) => match resolver[level] {
821                        Ti::Scalar(Sc {
822                            kind: Sk::Sint | Sk::Uint,
823                            width: _,
824                        }) => {}
825                        _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
826                    },
827                    (Some(_), false) => {
828                        return Err(ExpressionError::InvalidImageLevelSelector);
829                    }
830                    (None, true) => {
831                        return Err(ExpressionError::MissingImageLevelSelector);
832                    }
833                }
834                ShaderStages::all()
835            }
836            E::ImageQuery { image, query } => {
837                let ty = Self::global_var_ty(module, function, image)?;
838                match module.types[ty].inner {
839                    Ti::Image { class, arrayed, .. } => {
840                        let good = match query {
841                            crate::ImageQuery::NumLayers => arrayed,
842                            crate::ImageQuery::Size { level: None } => true,
843                            crate::ImageQuery::Size { level: Some(level) } => {
844                                match resolver[level] {
845                                    Ti::Scalar(Sc::I32 | Sc::U32) => {}
846                                    _ => {
847                                        return Err(ExpressionError::InvalidImageOtherIndexType(
848                                            level,
849                                        ))
850                                    }
851                                }
852                                class.is_mipmapped()
853                            }
854                            crate::ImageQuery::NumLevels => class.is_mipmapped(),
855                            crate::ImageQuery::NumSamples => class.is_multisampled(),
856                        };
857                        if !good {
858                            return Err(ExpressionError::InvalidImageClass(class));
859                        }
860                    }
861                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
862                }
863                ShaderStages::all()
864            }
865            E::Unary { op, expr } => {
866                use crate::UnaryOperator as Uo;
867                let inner = &resolver[expr];
868                match (op, inner.scalar_kind()) {
869                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
870                    | (Uo::LogicalNot, Some(Sk::Bool))
871                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
872                    other => {
873                        log::debug!("Op {op:?} kind {other:?}");
874                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
875                    }
876                }
877                ShaderStages::all()
878            }
879            E::Binary { op, left, right } => {
880                use crate::BinaryOperator as Bo;
881                let left_inner = &resolver[left];
882                let right_inner = &resolver[right];
883                let good = match op {
884                    Bo::Add | Bo::Subtract => match *left_inner {
885                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
886                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
887                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
888                        },
889                        Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => {
890                            left_inner == right_inner
891                        }
892                        _ => false,
893                    },
894                    Bo::Divide | Bo::Modulo => match *left_inner {
895                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
896                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
897                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
898                        },
899                        _ => false,
900                    },
901                    Bo::Multiply => {
902                        let kind_allowed = match left_inner.scalar_kind() {
903                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
904                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
905                        };
906                        let types_match = match (left_inner, right_inner) {
907                            // Straight scalar and mixed scalar/vector.
908                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
909                            | (
910                                &Ti::Vector {
911                                    scalar: scalar1, ..
912                                },
913                                &Ti::Scalar(scalar2),
914                            )
915                            | (
916                                &Ti::Scalar(scalar1),
917                                &Ti::Vector {
918                                    scalar: scalar2, ..
919                                },
920                            ) => scalar1 == scalar2,
921                            // Scalar * matrix.
922                            (
923                                &Ti::Scalar(Sc {
924                                    kind: Sk::Float, ..
925                                }),
926                                &Ti::Matrix { .. },
927                            )
928                            | (
929                                &Ti::Matrix { .. },
930                                &Ti::Scalar(Sc {
931                                    kind: Sk::Float, ..
932                                }),
933                            ) => true,
934                            // Vector * vector.
935                            (
936                                &Ti::Vector {
937                                    size: size1,
938                                    scalar: scalar1,
939                                },
940                                &Ti::Vector {
941                                    size: size2,
942                                    scalar: scalar2,
943                                },
944                            ) => scalar1 == scalar2 && size1 == size2,
945                            // Matrix * vector.
946                            (
947                                &Ti::Matrix { columns, .. },
948                                &Ti::Vector {
949                                    size,
950                                    scalar:
951                                        Sc {
952                                            kind: Sk::Float, ..
953                                        },
954                                },
955                            ) => columns == size,
956                            // Vector * matrix.
957                            (
958                                &Ti::Vector {
959                                    size,
960                                    scalar:
961                                        Sc {
962                                            kind: Sk::Float, ..
963                                        },
964                                },
965                                &Ti::Matrix { rows, .. },
966                            ) => size == rows,
967                            // Matrix * matrix.
968                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
969                                columns == rows
970                            }
971                            // Scalar * coop matrix.
972                            (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. })
973                            | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => {
974                                s1 == s2
975                            }
976                            _ => false,
977                        };
978                        let left_width = left_inner.scalar_width().unwrap_or(0);
979                        let right_width = right_inner.scalar_width().unwrap_or(0);
980                        kind_allowed && types_match && left_width == right_width
981                    }
982                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
983                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
984                        match *left_inner {
985                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
986                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
987                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
988                            },
989                            ref other => {
990                                log::debug!("Op {op:?} left type {other:?}");
991                                false
992                            }
993                        }
994                    }
995                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
996                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
997                        | Ti::Vector {
998                            scalar: Sc { kind: Sk::Bool, .. },
999                            ..
1000                        } => left_inner == right_inner,
1001                        ref other => {
1002                            log::debug!("Op {op:?} left type {other:?}");
1003                            false
1004                        }
1005                    },
1006                    Bo::And | Bo::InclusiveOr => match *left_inner {
1007                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1008                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
1009                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1010                        },
1011                        ref other => {
1012                            log::debug!("Op {op:?} left type {other:?}");
1013                            false
1014                        }
1015                    },
1016                    Bo::ExclusiveOr => match *left_inner {
1017                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1018                            Sk::Sint | Sk::Uint => left_inner == right_inner,
1019                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1020                        },
1021                        ref other => {
1022                            log::debug!("Op {op:?} left type {other:?}");
1023                            false
1024                        }
1025                    },
1026                    Bo::ShiftLeft | Bo::ShiftRight => {
1027                        let (base_size, base_scalar) = match *left_inner {
1028                            Ti::Scalar(scalar) => (Ok(None), scalar),
1029                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
1030                            ref other => {
1031                                log::debug!("Op {op:?} base type {other:?}");
1032                                (Err(()), Sc::BOOL)
1033                            }
1034                        };
1035                        let shift_size = match *right_inner {
1036                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
1037                            Ti::Vector {
1038                                size,
1039                                scalar: Sc { kind: Sk::Uint, .. },
1040                            } => Ok(Some(size)),
1041                            ref other => {
1042                                log::debug!("Op {op:?} shift type {other:?}");
1043                                Err(())
1044                            }
1045                        };
1046                        match base_scalar.kind {
1047                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
1048                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
1049                        }
1050                    }
1051                };
1052                if !good {
1053                    log::debug!(
1054                        "Left: {:?} of type {:?}",
1055                        function.expressions[left],
1056                        left_inner
1057                    );
1058                    log::debug!(
1059                        "Right: {:?} of type {:?}",
1060                        function.expressions[right],
1061                        right_inner
1062                    );
1063                    return Err(ExpressionError::InvalidBinaryOperandTypes {
1064                        op,
1065                        lhs_expr: left,
1066                        lhs_type: left_inner.clone(),
1067                        rhs_expr: right,
1068                        rhs_type: right_inner.clone(),
1069                    });
1070                }
1071                // For shift operations, check if the constant shift amount exceeds the bit width
1072                if matches!(op, Bo::ShiftLeft | Bo::ShiftRight) {
1073                    Self::validate_constant_shift_amounts(left_inner, right, module, function)?;
1074                }
1075                ShaderStages::all()
1076            }
1077            E::Select {
1078                condition,
1079                accept,
1080                reject,
1081            } => {
1082                let accept_inner = &resolver[accept];
1083                let reject_inner = &resolver[reject];
1084                let condition_ty = &resolver[condition];
1085                let condition_good = match *condition_ty {
1086                    Ti::Scalar(Sc {
1087                        kind: Sk::Bool,
1088                        width: _,
1089                    }) => {
1090                        // When `condition` is a single boolean, `accept` and
1091                        // `reject` can be vectors or scalars.
1092                        match *accept_inner {
1093                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
1094                            _ => false,
1095                        }
1096                    }
1097                    Ti::Vector {
1098                        size,
1099                        scalar:
1100                            Sc {
1101                                kind: Sk::Bool,
1102                                width: _,
1103                            },
1104                    } => match *accept_inner {
1105                        Ti::Vector {
1106                            size: other_size, ..
1107                        } => size == other_size,
1108                        _ => false,
1109                    },
1110                    _ => false,
1111                };
1112                if accept_inner != reject_inner {
1113                    return Err(ExpressionError::SelectValuesTypeMismatch {
1114                        accept: accept_inner.clone(),
1115                        reject: reject_inner.clone(),
1116                    });
1117                }
1118                if !condition_good {
1119                    return Err(ExpressionError::SelectConditionNotABool {
1120                        actual: condition_ty.clone(),
1121                    });
1122                }
1123                ShaderStages::all()
1124            }
1125            E::Derivative { expr, .. } => {
1126                match resolver[expr] {
1127                    Ti::Scalar(Sc {
1128                        kind: Sk::Float, ..
1129                    })
1130                    | Ti::Vector {
1131                        scalar:
1132                            Sc {
1133                                kind: Sk::Float, ..
1134                            },
1135                        ..
1136                    } => {}
1137                    _ => return Err(ExpressionError::InvalidDerivative),
1138                }
1139                ShaderStages::FRAGMENT
1140            }
1141            E::Relational { fun, argument } => {
1142                use crate::RelationalFunction as Rf;
1143                let argument_inner = &resolver[argument];
1144                match fun {
1145                    Rf::All | Rf::Any => match *argument_inner {
1146                        Ti::Vector {
1147                            scalar: Sc { kind: Sk::Bool, .. },
1148                            ..
1149                        } => {}
1150                        ref other => {
1151                            log::debug!("All/Any of type {other:?}");
1152                            return Err(ExpressionError::InvalidBooleanVector(argument));
1153                        }
1154                    },
1155                    Rf::IsNan | Rf::IsInf => match *argument_inner {
1156                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1157                            if scalar.kind == Sk::Float => {}
1158                        ref other => {
1159                            log::debug!("Float test of type {other:?}");
1160                            return Err(ExpressionError::InvalidFloatArgument(argument));
1161                        }
1162                    },
1163                }
1164                ShaderStages::all()
1165            }
1166            E::Math {
1167                fun,
1168                arg,
1169                arg1,
1170                arg2,
1171                arg3,
1172            } => {
1173                if matches!(
1174                    fun,
1175                    crate::MathFunction::QuantizeToF16
1176                        | crate::MathFunction::Pack2x16float
1177                        | crate::MathFunction::Unpack2x16float
1178                ) && !self
1179                    .capabilities
1180                    .contains(crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32)
1181                {
1182                    return Err(ExpressionError::MissingCapabilities(
1183                        crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32,
1184                    ));
1185                }
1186
1187                let actuals: &[_] = match (arg1, arg2, arg3) {
1188                    (None, None, None) => &[arg],
1189                    (Some(arg1), None, None) => &[arg, arg1],
1190                    (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1191                    (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1192                    _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1193                };
1194
1195                let resolve = |arg| &resolver[arg];
1196                let actual_types: &[_] = match *actuals {
1197                    [arg0] => &[resolve(arg0)],
1198                    [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1199                    [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1200                    [arg0, arg1, arg2, arg3] => {
1201                        &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1202                    }
1203                    _ => unreachable!(),
1204                };
1205
1206                // Start with the set of all overloads available for `fun`.
1207                let mut overloads = fun.overloads();
1208                log::debug!(
1209                    "initial overloads for {:?}: {:#?}",
1210                    fun,
1211                    overloads.for_debug(&module.types)
1212                );
1213
1214                // If any argument is not a constant expression, then no
1215                // overloads that accept abstract values should be considered.
1216                // `OverloadSet::concrete_only` is supposed to help impose this
1217                // restriction. However, no `MathFunction` accepts a mix of
1218                // abstract and concrete arguments, so we don't need to worry
1219                // about that here.
1220
1221                for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1222                    // Remove overloads that cannot accept an `i`'th
1223                    // argument of type `ty`.
1224                    overloads = overloads.arg(i, ty, &module.types);
1225                    log::debug!(
1226                        "overloads after arg {i}: {:#?}",
1227                        overloads.for_debug(&module.types)
1228                    );
1229
1230                    if overloads.is_empty() {
1231                        log::debug!("all overloads eliminated");
1232                        return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1233                    }
1234                }
1235
1236                if actuals.len() < overloads.min_arguments() {
1237                    return Err(ExpressionError::WrongArgumentCount(fun));
1238                }
1239
1240                ShaderStages::all()
1241            }
1242            E::As {
1243                expr,
1244                kind,
1245                convert,
1246            } => {
1247                let mut base_scalar = match resolver[expr] {
1248                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1249                        scalar
1250                    }
1251                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1252                    _ => return Err(ExpressionError::InvalidCastArgument),
1253                };
1254                base_scalar.kind = kind;
1255                if let Some(width) = convert {
1256                    base_scalar.width = width;
1257                }
1258                if self.check_width(base_scalar).is_err() {
1259                    return Err(ExpressionError::InvalidCastArgument);
1260                }
1261                ShaderStages::all()
1262            }
1263            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1264            E::AtomicResult { .. } => {
1265                // These expressions are validated when we check the `Atomic` statement
1266                // that refers to them, because we have all the information we need at
1267                // that point. The checks driven by `Validator::needs_visit` ensure
1268                // that this expression is indeed visited by one `Atomic` statement.
1269                ShaderStages::all()
1270            }
1271            E::WorkGroupUniformLoadResult { ty } => {
1272                if self.types[ty.index()]
1273                    .flags
1274                    // Sized | Constructible is exactly the types currently supported by
1275                    // WorkGroupUniformLoad
1276                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1277                {
1278                    ShaderStages::COMPUTE_LIKE
1279                } else {
1280                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1281                }
1282            }
1283            E::ArrayLength(expr) => match resolver[expr] {
1284                Ti::Pointer { base, .. } => {
1285                    let base_ty = &resolver.types[base];
1286                    if let Ti::Array {
1287                        size: crate::ArraySize::Dynamic,
1288                        ..
1289                    } = base_ty.inner
1290                    {
1291                        ShaderStages::all()
1292                    } else {
1293                        return Err(ExpressionError::InvalidArrayType(expr));
1294                    }
1295                }
1296                ref other => {
1297                    log::debug!("Array length of {other:?}");
1298                    return Err(ExpressionError::InvalidArrayType(expr));
1299                }
1300            },
1301            E::RayQueryProceedResult => ShaderStages::all(),
1302            E::RayQueryGetIntersection {
1303                query,
1304                committed: _,
1305            } => match resolver[query] {
1306                Ti::Pointer {
1307                    base,
1308                    space: crate::AddressSpace::Function,
1309                } => match resolver.types[base].inner {
1310                    Ti::RayQuery { .. } => ShaderStages::all(),
1311                    ref other => {
1312                        log::debug!("Intersection result of a pointer to {other:?}");
1313                        return Err(ExpressionError::InvalidRayQueryType(query));
1314                    }
1315                },
1316                ref other => {
1317                    log::debug!("Intersection result of {other:?}");
1318                    return Err(ExpressionError::InvalidRayQueryType(query));
1319                }
1320            },
1321            E::RayQueryVertexPositions {
1322                query,
1323                committed: _,
1324            } => match resolver[query] {
1325                Ti::Pointer {
1326                    base,
1327                    space: crate::AddressSpace::Function,
1328                } => match resolver.types[base].inner {
1329                    Ti::RayQuery {
1330                        vertex_return: true,
1331                    } => ShaderStages::all(),
1332                    ref other => {
1333                        log::debug!("Intersection result of a pointer to {other:?}");
1334                        return Err(ExpressionError::InvalidRayQueryType(query));
1335                    }
1336                },
1337                ref other => {
1338                    log::debug!("Intersection result of {other:?}");
1339                    return Err(ExpressionError::InvalidRayQueryType(query));
1340                }
1341            },
1342            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1343            E::CooperativeLoad { ref data, .. } => {
1344                if resolver[data.pointer]
1345                    .pointer_base_type()
1346                    .and_then(|tr| tr.inner_with(&module.types).scalar())
1347                    .is_none()
1348                {
1349                    return Err(ExpressionError::InvalidPointerType(data.pointer));
1350                }
1351                ShaderStages::COMPUTE
1352            }
1353            E::CooperativeMultiplyAdd { a, b, c } => {
1354                let roles = [
1355                    crate::CooperativeRole::A,
1356                    crate::CooperativeRole::B,
1357                    crate::CooperativeRole::C,
1358                ];
1359                for (operand, expected_role) in [a, b, c].into_iter().zip(roles) {
1360                    match resolver[operand] {
1361                        Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
1362                        ref other => {
1363                            log::debug!("{expected_role:?} operand type: {other:?}");
1364                            return Err(ExpressionError::InvalidCooperativeOperand(a));
1365                        }
1366                    }
1367                }
1368                ShaderStages::COMPUTE
1369            }
1370        };
1371        Ok(stages)
1372    }
1373
1374    fn global_var_ty(
1375        module: &crate::Module,
1376        function: &crate::Function,
1377        expr: Handle<crate::Expression>,
1378    ) -> Result<Handle<crate::Type>, ExpressionError> {
1379        use crate::Expression as Ex;
1380
1381        match function.expressions[expr] {
1382            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1383            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1384            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1385                match function.expressions[base] {
1386                    Ex::GlobalVariable(var_handle) => {
1387                        let array_ty = module.global_variables[var_handle].ty;
1388
1389                        match module.types[array_ty].inner {
1390                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1391                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1392                        }
1393                    }
1394                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1395                }
1396            }
1397            _ => Err(ExpressionError::ExpectedGlobalVariable),
1398        }
1399    }
1400
1401    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1402        let _ = self.check_width(literal.scalar())?;
1403        check_literal_value(literal)?;
1404
1405        Ok(())
1406    }
1407}
1408
1409pub const fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1410    let is_nan = match literal {
1411        crate::Literal::F64(v) => v.is_nan(),
1412        crate::Literal::F32(v) => v.is_nan(),
1413        _ => false,
1414    };
1415    if is_nan {
1416        return Err(LiteralError::NaN);
1417    }
1418
1419    let is_infinite = match literal {
1420        crate::Literal::F64(v) => v.is_infinite(),
1421        crate::Literal::F32(v) => v.is_infinite(),
1422        _ => false,
1423    };
1424    if is_infinite {
1425        return Err(LiteralError::Infinity);
1426    }
1427
1428    Ok(())
1429}
1430
/// Test helper: validate a module whose single function contains only `expr`.
///
/// The expression is appended to a default function and covered by one `Emit`
/// statement; the function is then added to a default module. Validation runs
/// with `ValidationFlags::EXPRESSIONS` and the given `caps`, and the
/// validator's result is returned as-is so callers can assert on success or
/// on a specific error.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut func = crate::Function::default();
    func.expressions.append(expr, Span::default());
    let emitted = func.expressions.range_from(0);
    func.body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(func, Span::default());

    let mut checker = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
    checker.validate(&module)
}
1453
/// Test helper: validate a module whose global expression arena contains only
/// `expr`.
///
/// Validation runs with `ValidationFlags::CONSTANTS` and the given `caps`;
/// the validator's result is returned as-is so callers can assert on success
/// or on a specific error.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let module = {
        let mut module = crate::Module::default();
        module.global_expressions.append(expr, Span::default());
        module
    };

    let mut checker = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
    checker.validate(&module)
}
1469
/// An `F64` literal in a function's expression arena is rejected with a
/// missing-capability error unless `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_runtime_literals() {
    let literal_expr = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the width check must report the missing `FLOAT64` flag.
    let error = validate_with_expression(literal_expr(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    // With `FLOAT64` enabled the same literal validates.
    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(literal_expr(), caps).is_ok());
}
1500
/// An `F64` literal in a module's global (constant) expression arena is
/// rejected with a missing-capability error unless `Capabilities::FLOAT64`
/// is enabled.
#[test]
fn f64_const_literals() {
    let literal_expr = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the width check must report the missing `FLOAT64` flag.
    let error = validate_with_const_expression(literal_expr(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With `FLOAT64` enabled the same literal validates.
    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(literal_expr(), caps).is_ok());
}