naga/valid/expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::valid::expression::builtin::validate_zero_value;
4use crate::{
5    arena::Handle,
6    proc::OverloadSet as _,
7    proc::{IndexableLengthError, ResolveError},
8};
9
10pub mod builtin;
11
12#[derive(Clone, Debug, thiserror::Error)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum ExpressionError {
15    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
16    NotInScope,
17    #[error("Base type {0:?} is not compatible with this expression")]
18    InvalidBaseType(Handle<crate::Expression>),
19    #[error("Accessing with index {0:?} can't be done")]
20    InvalidIndexType(Handle<crate::Expression>),
21    #[error("Accessing {0:?} via a negative index is invalid")]
22    NegativeIndex(Handle<crate::Expression>),
23    #[error("Accessing index {1} is out of {0:?} bounds")]
24    IndexOutOfBounds(Handle<crate::Expression>, u32),
25    #[error("Function argument {0:?} doesn't exist")]
26    FunctionArgumentDoesntExist(u32),
27    #[error("Loading of {0:?} can't be done")]
28    InvalidPointerType(Handle<crate::Expression>),
29    #[error("Array length of {0:?} can't be done")]
30    InvalidArrayType(Handle<crate::Expression>),
31    #[error("Get intersection of {0:?} can't be done")]
32    InvalidRayQueryType(Handle<crate::Expression>),
33    #[error("Splatting {0:?} can't be done")]
34    InvalidSplatType(Handle<crate::Expression>),
35    #[error("Swizzling {0:?} can't be done")]
36    InvalidVectorType(Handle<crate::Expression>),
37    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
38    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
39    #[error(transparent)]
40    Compose(#[from] super::ComposeError),
41    #[error(transparent)]
42    ZeroValue(#[from] super::ZeroValueError),
43    #[error(transparent)]
44    IndexableLength(#[from] IndexableLengthError),
45    #[error("Operation {0:?} can't work with {1:?}")]
46    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
47    #[error(
48        "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
49        op,
50        lhs_expr,
51        lhs_type,
52        rhs_expr,
53        rhs_type
54    )]
55    InvalidBinaryOperandTypes {
56        op: crate::BinaryOperator,
57        lhs_expr: Handle<crate::Expression>,
58        lhs_type: crate::TypeInner,
59        rhs_expr: Handle<crate::Expression>,
60        rhs_type: crate::TypeInner,
61    },
62    #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
63    SelectValuesTypeMismatch {
64        accept: crate::TypeInner,
65        reject: crate::TypeInner,
66    },
67    #[error("Expected selection condition to be a boolean value, got {actual:?}")]
68    SelectConditionNotABool { actual: crate::TypeInner },
69    #[error("Relational argument {0:?} is not a boolean vector")]
70    InvalidBooleanVector(Handle<crate::Expression>),
71    #[error("Relational argument {0:?} is not a float")]
72    InvalidFloatArgument(Handle<crate::Expression>),
73    #[error("Type resolution failed")]
74    Type(#[from] ResolveError),
75    #[error("Not a global variable")]
76    ExpectedGlobalVariable,
77    #[error("Not a global variable or a function argument")]
78    ExpectedGlobalOrArgument,
79    #[error("Needs to be an binding array instead of {0:?}")]
80    ExpectedBindingArrayType(Handle<crate::Type>),
81    #[error("Needs to be an image instead of {0:?}")]
82    ExpectedImageType(Handle<crate::Type>),
83    #[error("Needs to be an image instead of {0:?}")]
84    ExpectedSamplerType(Handle<crate::Type>),
85    #[error("Unable to operate on image class {0:?}")]
86    InvalidImageClass(crate::ImageClass),
87    #[error("Image atomics are not supported for storage format {0:?}")]
88    InvalidImageFormat(crate::StorageFormat),
89    #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
90    InvalidImageStorageAccess(crate::StorageAccess),
91    #[error("Derivatives can only be taken from scalar and vector floats")]
92    InvalidDerivative,
93    #[error("Image array index parameter is misplaced")]
94    InvalidImageArrayIndex,
95    #[error("Inappropriate sample or level-of-detail index for texel access")]
96    InvalidImageOtherIndex,
97    #[error("Image array index type of {0:?} is not an integer scalar")]
98    InvalidImageArrayIndexType(Handle<crate::Expression>),
99    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
100    InvalidImageOtherIndexType(Handle<crate::Expression>),
101    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
102    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
103    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
104    ComparisonSamplingMismatch {
105        image: crate::ImageClass,
106        sampler: bool,
107        has_ref: bool,
108    },
109    #[error("Sample offset must be a const-expression")]
110    InvalidSampleOffsetExprType,
111    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
112    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
113    #[error("Depth reference {0:?} is not a scalar float")]
114    InvalidDepthReference(Handle<crate::Expression>),
115    #[error("Depth sample level can only be Auto or Zero")]
116    InvalidDepthSampleLevel,
117    #[error("Gather level can only be Zero")]
118    InvalidGatherLevel,
119    #[error("Gather component {0:?} doesn't exist in the image")]
120    InvalidGatherComponent(crate::SwizzleComponent),
121    #[error("Gather can't be done for image dimension {0:?}")]
122    InvalidGatherDimension(crate::ImageDimension),
123    #[error("Sample level (exact) type {0:?} has an invalid type")]
124    InvalidSampleLevelExactType(Handle<crate::Expression>),
125    #[error("Sample level (bias) type {0:?} is not a scalar float")]
126    InvalidSampleLevelBiasType(Handle<crate::Expression>),
127    #[error("Bias can't be done for image dimension {0:?}")]
128    InvalidSampleLevelBiasDimension(crate::ImageDimension),
129    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
130    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
131    #[error("Clamping sample coordinate to edge is not supported with {0}")]
132    InvalidSampleClampCoordinateToEdge(alloc::string::String),
133    #[error("Unable to cast")]
134    InvalidCastArgument,
135    #[error("Invalid argument count for {0:?}")]
136    WrongArgumentCount(crate::MathFunction),
137    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
138    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
139    #[error(
140        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
141    )]
142    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
143    #[error("Shader requires capability {0:?}")]
144    MissingCapabilities(super::Capabilities),
145    #[error(transparent)]
146    Literal(#[from] LiteralError),
147    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
148    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
149    #[error("Invalid operand for cooperative op")]
150    InvalidCooperativeOperand(Handle<crate::Expression>),
151}
152
153#[derive(Clone, Debug, thiserror::Error)]
154#[cfg_attr(test, derive(PartialEq))]
155pub enum ConstExpressionError {
156    #[error("The expression is not a constant or override expression")]
157    NonConstOrOverride,
158    #[error("The expression is not a fully evaluated constant expression")]
159    NonFullyEvaluatedConst,
160    #[error(transparent)]
161    Compose(#[from] super::ComposeError),
162    #[error("Splatting {0:?} can't be done")]
163    InvalidSplatType(Handle<crate::Expression>),
164    #[error("Type resolution failed")]
165    Type(#[from] ResolveError),
166    #[error(transparent)]
167    Literal(#[from] LiteralError),
168    #[error(transparent)]
169    Width(#[from] super::r#type::WidthError),
170}
171
172#[derive(Clone, Debug, thiserror::Error)]
173#[cfg_attr(test, derive(PartialEq))]
174pub enum LiteralError {
175    #[error("Float literal is NaN")]
176    NaN,
177    #[error("Float literal is infinite")]
178    Infinity,
179    #[error(transparent)]
180    Width(#[from] super::r#type::WidthError),
181}
182
183struct ExpressionTypeResolver<'a> {
184    root: Handle<crate::Expression>,
185    types: &'a UniqueArena<crate::Type>,
186    info: &'a FunctionInfo,
187}
188
189impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
190    type Output = crate::TypeInner;
191
192    #[allow(clippy::panic)]
193    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
194        if handle < self.root {
195            self.info[handle].ty.inner_with(self.types)
196        } else {
197            // `Validator::validate_module_handles` should have caught this.
198            panic!(
199                "Depends on {:?}, which has not been processed yet",
200                self.root
201            )
202        }
203    }
204}
205
206impl super::Validator {
207    pub(super) fn validate_const_expression(
208        &self,
209        handle: Handle<crate::Expression>,
210        gctx: crate::proc::GlobalCtx,
211        mod_info: &ModuleInfo,
212        global_expr_kind: &crate::proc::ExpressionKindTracker,
213    ) -> Result<(), ConstExpressionError> {
214        use crate::Expression as E;
215
216        if !global_expr_kind.is_const_or_override(handle) {
217            return Err(ConstExpressionError::NonConstOrOverride);
218        }
219
220        match gctx.global_expressions[handle] {
221            E::Literal(literal) => {
222                self.validate_literal(literal)?;
223            }
224            E::Constant(_) | E::ZeroValue(_) => {}
225            E::Compose { ref components, ty } => {
226                validate_compose(
227                    ty,
228                    gctx,
229                    components.iter().map(|&handle| mod_info[handle].clone()),
230                )?;
231            }
232            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
233                crate::TypeInner::Scalar { .. } => {}
234                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
235            },
236            _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
237                return Err(ConstExpressionError::NonFullyEvaluatedConst)
238            }
239            // the constant evaluator will report errors about override-expressions
240            _ => {}
241        }
242
243        Ok(())
244    }
245
246    #[allow(clippy::too_many_arguments)]
247    pub(super) fn validate_expression(
248        &self,
249        root: Handle<crate::Expression>,
250        expression: &crate::Expression,
251        function: &crate::Function,
252        module: &crate::Module,
253        info: &FunctionInfo,
254        mod_info: &ModuleInfo,
255        expr_kind: &crate::proc::ExpressionKindTracker,
256    ) -> Result<ShaderStages, ExpressionError> {
257        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
258
259        let resolver = ExpressionTypeResolver {
260            root,
261            types: &module.types,
262            info,
263        };
264
265        let stages = match *expression {
266            E::Access { base, index } => {
267                let base_type = &resolver[base];
268                match *base_type {
269                    Ti::Matrix { .. }
270                    | Ti::Vector { .. }
271                    | Ti::Array { .. }
272                    | Ti::Pointer { .. }
273                    | Ti::ValuePointer { size: Some(_), .. }
274                    | Ti::BindingArray { .. } => {}
275                    ref other => {
276                        log::error!("Indexing of {other:?}");
277                        return Err(ExpressionError::InvalidBaseType(base));
278                    }
279                };
280                match resolver[index] {
281                    //TODO: only allow one of these
282                    Ti::Scalar(Sc {
283                        kind: Sk::Sint | Sk::Uint,
284                        ..
285                    }) => {}
286                    ref other => {
287                        log::error!("Indexing by {other:?}");
288                        return Err(ExpressionError::InvalidIndexType(index));
289                    }
290                }
291
292                // If index is const we can do check for non-negative index
293                match module
294                    .to_ctx()
295                    .get_const_val_from(index, &function.expressions)
296                {
297                    Ok(value) => {
298                        let length = if self.overrides_resolved {
299                            base_type.indexable_length_resolved(module)
300                        } else {
301                            base_type.indexable_length_pending(module)
302                        }?;
303                        // If we know both the length and the index, we can do the
304                        // bounds check now.
305                        if let crate::proc::IndexableLength::Known(known_length) = length {
306                            if value >= known_length {
307                                return Err(ExpressionError::IndexOutOfBounds(base, value));
308                            }
309                        }
310                    }
311                    Err(crate::proc::ConstValueError::Negative) => {
312                        return Err(ExpressionError::NegativeIndex(base))
313                    }
314                    Err(crate::proc::ConstValueError::NonConst) => {}
315                    Err(crate::proc::ConstValueError::InvalidType) => {
316                        return Err(ExpressionError::InvalidIndexType(index))
317                    }
318                }
319
320                ShaderStages::all()
321            }
322            E::AccessIndex { base, index } => {
323                fn resolve_index_limit(
324                    module: &crate::Module,
325                    top: Handle<crate::Expression>,
326                    ty: &crate::TypeInner,
327                    top_level: bool,
328                ) -> Result<u32, ExpressionError> {
329                    let limit = match *ty {
330                        Ti::Vector { size, .. }
331                        | Ti::ValuePointer {
332                            size: Some(size), ..
333                        } => size as u32,
334                        Ti::Matrix { columns, .. } => columns as u32,
335                        Ti::Array {
336                            size: crate::ArraySize::Constant(len),
337                            ..
338                        } => len.get(),
339                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
340                        Ti::Pointer { base, .. } if top_level => {
341                            resolve_index_limit(module, top, &module.types[base].inner, false)?
342                        }
343                        Ti::Struct { ref members, .. } => members.len() as u32,
344                        ref other => {
345                            log::error!("Indexing of {other:?}");
346                            return Err(ExpressionError::InvalidBaseType(top));
347                        }
348                    };
349                    Ok(limit)
350                }
351
352                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
353                if index >= limit {
354                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
355                }
356                ShaderStages::all()
357            }
358            E::Splat { size: _, value } => match resolver[value] {
359                Ti::Scalar { .. } => ShaderStages::all(),
360                ref other => {
361                    log::error!("Splat scalar type {other:?}");
362                    return Err(ExpressionError::InvalidSplatType(value));
363                }
364            },
365            E::Swizzle {
366                size,
367                vector,
368                pattern,
369            } => {
370                let vec_size = match resolver[vector] {
371                    Ti::Vector { size: vec_size, .. } => vec_size,
372                    ref other => {
373                        log::error!("Swizzle vector type {other:?}");
374                        return Err(ExpressionError::InvalidVectorType(vector));
375                    }
376                };
377                for &sc in pattern[..size as usize].iter() {
378                    if sc as u8 >= vec_size as u8 {
379                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
380                    }
381                }
382                ShaderStages::all()
383            }
384            E::Literal(literal) => {
385                self.validate_literal(literal)?;
386                ShaderStages::all()
387            }
388            E::Constant(_) | E::Override(_) => ShaderStages::all(),
389            E::ZeroValue(ty) => {
390                validate_zero_value(ty, module.to_ctx())?;
391                ShaderStages::all()
392            }
393            E::Compose { ref components, ty } => {
394                validate_compose(
395                    ty,
396                    module.to_ctx(),
397                    components.iter().map(|&handle| info[handle].ty.clone()),
398                )?;
399                ShaderStages::all()
400            }
401            E::FunctionArgument(index) => {
402                if index >= function.arguments.len() as u32 {
403                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
404                }
405                ShaderStages::all()
406            }
407            E::GlobalVariable(_handle) => ShaderStages::all(),
408            E::LocalVariable(_handle) => ShaderStages::all(),
409            E::Load { pointer } => {
410                match resolver[pointer] {
411                    Ti::Pointer { base, .. }
412                        if self.types[base.index()]
413                            .flags
414                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
415                    Ti::ValuePointer { .. } => {}
416                    ref other => {
417                        log::error!("Loading {other:?}");
418                        return Err(ExpressionError::InvalidPointerType(pointer));
419                    }
420                }
421                ShaderStages::all()
422            }
423            E::ImageSample {
424                image,
425                sampler,
426                gather,
427                coordinate,
428                array_index,
429                offset,
430                level,
431                depth_ref,
432                clamp_to_edge,
433            } => {
434                // check the validity of expressions
435                let image_ty = Self::global_var_ty(module, function, image)?;
436                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
437
438                let comparison = match module.types[sampler_ty].inner {
439                    Ti::Sampler { comparison } => comparison,
440                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
441                };
442
443                let (class, dim) = match module.types[image_ty].inner {
444                    Ti::Image {
445                        class,
446                        arrayed,
447                        dim,
448                    } => {
449                        // check the array property
450                        if arrayed != array_index.is_some() {
451                            return Err(ExpressionError::InvalidImageArrayIndex);
452                        }
453                        if let Some(expr) = array_index {
454                            match resolver[expr] {
455                                Ti::Scalar(Sc {
456                                    kind: Sk::Sint | Sk::Uint,
457                                    ..
458                                }) => {}
459                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
460                            }
461                        }
462                        (class, dim)
463                    }
464                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
465                };
466
467                // check sampling and comparison properties
468                let image_depth = match class {
469                    crate::ImageClass::Sampled {
470                        kind: crate::ScalarKind::Float,
471                        multi: false,
472                    } => false,
473                    crate::ImageClass::Sampled {
474                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
475                        multi: false,
476                    } if gather.is_some() => false,
477                    crate::ImageClass::External => false,
478                    crate::ImageClass::Depth { multi: false } => true,
479                    _ => return Err(ExpressionError::InvalidImageClass(class)),
480                };
481                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
482                    return Err(ExpressionError::ComparisonSamplingMismatch {
483                        image: class,
484                        sampler: comparison,
485                        has_ref: depth_ref.is_some(),
486                    });
487                }
488
489                // check texture coordinates type
490                let num_components = match dim {
491                    crate::ImageDimension::D1 => 1,
492                    crate::ImageDimension::D2 => 2,
493                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
494                };
495                match resolver[coordinate] {
496                    Ti::Scalar(Sc {
497                        kind: Sk::Float, ..
498                    }) if num_components == 1 => {}
499                    Ti::Vector {
500                        size,
501                        scalar:
502                            Sc {
503                                kind: Sk::Float, ..
504                            },
505                    } if size as u32 == num_components => {}
506                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
507                }
508
509                // check constant offset
510                if let Some(const_expr) = offset {
511                    if !expr_kind.is_const(const_expr) {
512                        return Err(ExpressionError::InvalidSampleOffsetExprType);
513                    }
514
515                    match resolver[const_expr] {
516                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
517                        Ti::Vector {
518                            size,
519                            scalar: Sc { kind: Sk::Sint, .. },
520                        } if size as u32 == num_components => {}
521                        _ => {
522                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
523                        }
524                    }
525                }
526
527                // check depth reference type
528                if let Some(expr) = depth_ref {
529                    match resolver[expr] {
530                        Ti::Scalar(Sc {
531                            kind: Sk::Float, ..
532                        }) => {}
533                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
534                    }
535                    match level {
536                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
537                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
538                    }
539                }
540
541                if let Some(component) = gather {
542                    match dim {
543                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
544                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
545                            return Err(ExpressionError::InvalidGatherDimension(dim))
546                        }
547                    };
548                    let max_component = match class {
549                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
550                        _ => crate::SwizzleComponent::W,
551                    };
552                    if component > max_component {
553                        return Err(ExpressionError::InvalidGatherComponent(component));
554                    }
555                    match level {
556                        crate::SampleLevel::Zero => {}
557                        _ => return Err(ExpressionError::InvalidGatherLevel),
558                    }
559                }
560
561                // Clamping coordinate to edge is only supported with 2d non-arrayed, sampled images
562                // when sampling from level Zero without any offset, gather, or depth comparison.
563                if clamp_to_edge {
564                    if !matches!(
565                        class,
566                        crate::ImageClass::Sampled {
567                            kind: crate::ScalarKind::Float,
568                            multi: false
569                        } | crate::ImageClass::External
570                    ) {
571                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
572                            alloc::format!("image class `{class:?}`"),
573                        ));
574                    }
575                    if dim != crate::ImageDimension::D2 {
576                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
577                            alloc::format!("image dimension `{dim:?}`"),
578                        ));
579                    }
580                    if gather.is_some() {
581                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
582                            "gather".into(),
583                        ));
584                    }
585                    if array_index.is_some() {
586                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
587                            "array index".into(),
588                        ));
589                    }
590                    if offset.is_some() {
591                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
592                            "offset".into(),
593                        ));
594                    }
595                    if level != crate::SampleLevel::Zero {
596                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
597                            "non-zero level".into(),
598                        ));
599                    }
600                    if depth_ref.is_some() {
601                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
602                            "depth comparison".into(),
603                        ));
604                    }
605                }
606
607                // External textures can only be sampled using clamp_to_edge.
608                if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
609                    return Err(ExpressionError::InvalidImageClass(class));
610                }
611
612                // check level properties
613                match level {
614                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
615                    crate::SampleLevel::Zero => ShaderStages::all(),
616                    crate::SampleLevel::Exact(expr) => {
617                        match class {
618                            crate::ImageClass::Depth { .. } => match resolver[expr] {
619                                Ti::Scalar(Sc {
620                                    kind: Sk::Sint | Sk::Uint,
621                                    ..
622                                }) => {}
623                                _ => {
624                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
625                                }
626                            },
627                            _ => match resolver[expr] {
628                                Ti::Scalar(Sc {
629                                    kind: Sk::Float, ..
630                                }) => {}
631                                _ => {
632                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
633                                }
634                            },
635                        }
636                        ShaderStages::all()
637                    }
638                    crate::SampleLevel::Bias(expr) => {
639                        match resolver[expr] {
640                            Ti::Scalar(Sc {
641                                kind: Sk::Float, ..
642                            }) => {}
643                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
644                        }
645                        match class {
646                            crate::ImageClass::Sampled {
647                                kind: Sk::Float,
648                                multi: false,
649                            } => {
650                                if dim == crate::ImageDimension::D1 {
651                                    return Err(ExpressionError::InvalidSampleLevelBiasDimension(
652                                        dim,
653                                    ));
654                                }
655                            }
656                            _ => return Err(ExpressionError::InvalidImageClass(class)),
657                        }
658                        ShaderStages::FRAGMENT
659                    }
660                    crate::SampleLevel::Gradient { x, y } => {
661                        match resolver[x] {
662                            Ti::Scalar(Sc {
663                                kind: Sk::Float, ..
664                            }) if num_components == 1 => {}
665                            Ti::Vector {
666                                size,
667                                scalar:
668                                    Sc {
669                                        kind: Sk::Float, ..
670                                    },
671                            } if size as u32 == num_components => {}
672                            _ => {
673                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
674                            }
675                        }
676                        match resolver[y] {
677                            Ti::Scalar(Sc {
678                                kind: Sk::Float, ..
679                            }) if num_components == 1 => {}
680                            Ti::Vector {
681                                size,
682                                scalar:
683                                    Sc {
684                                        kind: Sk::Float, ..
685                                    },
686                            } if size as u32 == num_components => {}
687                            _ => {
688                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
689                            }
690                        }
691                        ShaderStages::all()
692                    }
693                }
694            }
695            E::ImageLoad {
696                image,
697                coordinate,
698                array_index,
699                sample,
700                level,
701            } => {
702                let ty = Self::global_var_ty(module, function, image)?;
703                let Ti::Image {
704                    class,
705                    arrayed,
706                    dim,
707                } = module.types[ty].inner
708                else {
709                    return Err(ExpressionError::ExpectedImageType(ty));
710                };
711
712                match resolver[coordinate].image_storage_coordinates() {
713                    Some(coord_dim) if coord_dim == dim => {}
714                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
715                };
716                if arrayed != array_index.is_some() {
717                    return Err(ExpressionError::InvalidImageArrayIndex);
718                }
719                if let Some(expr) = array_index {
720                    if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
721                        return Err(ExpressionError::InvalidImageArrayIndexType(expr));
722                    }
723                }
724
725                match (sample, class.is_multisampled()) {
726                    (None, false) => {}
727                    (Some(sample), true) => {
728                        if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
729                            return Err(ExpressionError::InvalidImageOtherIndexType(sample));
730                        }
731                    }
732                    _ => {
733                        return Err(ExpressionError::InvalidImageOtherIndex);
734                    }
735                }
736
737                match (level, class.is_mipmapped()) {
738                    (None, false) => {}
739                    (Some(level), true) => match resolver[level] {
740                        Ti::Scalar(Sc {
741                            kind: Sk::Sint | Sk::Uint,
742                            width: _,
743                        }) => {}
744                        _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
745                    },
746                    _ => {
747                        return Err(ExpressionError::InvalidImageOtherIndex);
748                    }
749                }
750                ShaderStages::all()
751            }
752            E::ImageQuery { image, query } => {
753                let ty = Self::global_var_ty(module, function, image)?;
754                match module.types[ty].inner {
755                    Ti::Image { class, arrayed, .. } => {
756                        let good = match query {
757                            crate::ImageQuery::NumLayers => arrayed,
758                            crate::ImageQuery::Size { level: None } => true,
759                            crate::ImageQuery::Size { level: Some(level) } => {
760                                match resolver[level] {
761                                    Ti::Scalar(Sc::I32 | Sc::U32) => {}
762                                    _ => {
763                                        return Err(ExpressionError::InvalidImageOtherIndexType(
764                                            level,
765                                        ))
766                                    }
767                                }
768                                class.is_mipmapped()
769                            }
770                            crate::ImageQuery::NumLevels => class.is_mipmapped(),
771                            crate::ImageQuery::NumSamples => class.is_multisampled(),
772                        };
773                        if !good {
774                            return Err(ExpressionError::InvalidImageClass(class));
775                        }
776                    }
777                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
778                }
779                ShaderStages::all()
780            }
781            E::Unary { op, expr } => {
782                use crate::UnaryOperator as Uo;
783                let inner = &resolver[expr];
784                match (op, inner.scalar_kind()) {
785                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
786                    | (Uo::LogicalNot, Some(Sk::Bool))
787                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
788                    other => {
789                        log::error!("Op {op:?} kind {other:?}");
790                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
791                    }
792                }
793                ShaderStages::all()
794            }
795            E::Binary { op, left, right } => {
796                use crate::BinaryOperator as Bo;
797                let left_inner = &resolver[left];
798                let right_inner = &resolver[right];
799                let good = match op {
800                    Bo::Add | Bo::Subtract => match *left_inner {
801                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
802                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
803                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
804                        },
805                        Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => {
806                            left_inner == right_inner
807                        }
808                        _ => false,
809                    },
810                    Bo::Divide | Bo::Modulo => match *left_inner {
811                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
812                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
813                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
814                        },
815                        _ => false,
816                    },
817                    Bo::Multiply => {
818                        let kind_allowed = match left_inner.scalar_kind() {
819                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
820                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
821                        };
822                        let types_match = match (left_inner, right_inner) {
823                            // Straight scalar and mixed scalar/vector.
824                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
825                            | (
826                                &Ti::Vector {
827                                    scalar: scalar1, ..
828                                },
829                                &Ti::Scalar(scalar2),
830                            )
831                            | (
832                                &Ti::Scalar(scalar1),
833                                &Ti::Vector {
834                                    scalar: scalar2, ..
835                                },
836                            ) => scalar1 == scalar2,
837                            // Scalar * matrix.
838                            (
839                                &Ti::Scalar(Sc {
840                                    kind: Sk::Float, ..
841                                }),
842                                &Ti::Matrix { .. },
843                            )
844                            | (
845                                &Ti::Matrix { .. },
846                                &Ti::Scalar(Sc {
847                                    kind: Sk::Float, ..
848                                }),
849                            ) => true,
850                            // Vector * vector.
851                            (
852                                &Ti::Vector {
853                                    size: size1,
854                                    scalar: scalar1,
855                                },
856                                &Ti::Vector {
857                                    size: size2,
858                                    scalar: scalar2,
859                                },
860                            ) => scalar1 == scalar2 && size1 == size2,
861                            // Matrix * vector.
862                            (
863                                &Ti::Matrix { columns, .. },
864                                &Ti::Vector {
865                                    size,
866                                    scalar:
867                                        Sc {
868                                            kind: Sk::Float, ..
869                                        },
870                                },
871                            ) => columns == size,
872                            // Vector * matrix.
873                            (
874                                &Ti::Vector {
875                                    size,
876                                    scalar:
877                                        Sc {
878                                            kind: Sk::Float, ..
879                                        },
880                                },
881                                &Ti::Matrix { rows, .. },
882                            ) => size == rows,
883                            // Matrix * matrix.
884                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
885                                columns == rows
886                            }
887                            // Scalar * coop matrix.
888                            (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. })
889                            | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => {
890                                s1 == s2
891                            }
892                            _ => false,
893                        };
894                        let left_width = left_inner.scalar_width().unwrap_or(0);
895                        let right_width = right_inner.scalar_width().unwrap_or(0);
896                        kind_allowed && types_match && left_width == right_width
897                    }
898                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
899                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
900                        match *left_inner {
901                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
902                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
903                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
904                            },
905                            ref other => {
906                                log::error!("Op {op:?} left type {other:?}");
907                                false
908                            }
909                        }
910                    }
911                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
912                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
913                        | Ti::Vector {
914                            scalar: Sc { kind: Sk::Bool, .. },
915                            ..
916                        } => left_inner == right_inner,
917                        ref other => {
918                            log::error!("Op {op:?} left type {other:?}");
919                            false
920                        }
921                    },
922                    Bo::And | Bo::InclusiveOr => match *left_inner {
923                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
924                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
925                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
926                        },
927                        ref other => {
928                            log::error!("Op {op:?} left type {other:?}");
929                            false
930                        }
931                    },
932                    Bo::ExclusiveOr => match *left_inner {
933                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
934                            Sk::Sint | Sk::Uint => left_inner == right_inner,
935                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
936                        },
937                        ref other => {
938                            log::error!("Op {op:?} left type {other:?}");
939                            false
940                        }
941                    },
942                    Bo::ShiftLeft | Bo::ShiftRight => {
943                        let (base_size, base_scalar) = match *left_inner {
944                            Ti::Scalar(scalar) => (Ok(None), scalar),
945                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
946                            ref other => {
947                                log::error!("Op {op:?} base type {other:?}");
948                                (Err(()), Sc::BOOL)
949                            }
950                        };
951                        let shift_size = match *right_inner {
952                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
953                            Ti::Vector {
954                                size,
955                                scalar: Sc { kind: Sk::Uint, .. },
956                            } => Ok(Some(size)),
957                            ref other => {
958                                log::error!("Op {op:?} shift type {other:?}");
959                                Err(())
960                            }
961                        };
962                        match base_scalar.kind {
963                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
964                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
965                        }
966                    }
967                };
968                if !good {
969                    log::error!(
970                        "Left: {:?} of type {:?}",
971                        function.expressions[left],
972                        left_inner
973                    );
974                    log::error!(
975                        "Right: {:?} of type {:?}",
976                        function.expressions[right],
977                        right_inner
978                    );
979                    return Err(ExpressionError::InvalidBinaryOperandTypes {
980                        op,
981                        lhs_expr: left,
982                        lhs_type: left_inner.clone(),
983                        rhs_expr: right,
984                        rhs_type: right_inner.clone(),
985                    });
986                }
987                ShaderStages::all()
988            }
989            E::Select {
990                condition,
991                accept,
992                reject,
993            } => {
994                let accept_inner = &resolver[accept];
995                let reject_inner = &resolver[reject];
996                let condition_ty = &resolver[condition];
997                let condition_good = match *condition_ty {
998                    Ti::Scalar(Sc {
999                        kind: Sk::Bool,
1000                        width: _,
1001                    }) => {
1002                        // When `condition` is a single boolean, `accept` and
1003                        // `reject` can be vectors or scalars.
1004                        match *accept_inner {
1005                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
1006                            _ => false,
1007                        }
1008                    }
1009                    Ti::Vector {
1010                        size,
1011                        scalar:
1012                            Sc {
1013                                kind: Sk::Bool,
1014                                width: _,
1015                            },
1016                    } => match *accept_inner {
1017                        Ti::Vector {
1018                            size: other_size, ..
1019                        } => size == other_size,
1020                        _ => false,
1021                    },
1022                    _ => false,
1023                };
1024                if accept_inner != reject_inner {
1025                    return Err(ExpressionError::SelectValuesTypeMismatch {
1026                        accept: accept_inner.clone(),
1027                        reject: reject_inner.clone(),
1028                    });
1029                }
1030                if !condition_good {
1031                    return Err(ExpressionError::SelectConditionNotABool {
1032                        actual: condition_ty.clone(),
1033                    });
1034                }
1035                ShaderStages::all()
1036            }
1037            E::Derivative { expr, .. } => {
1038                match resolver[expr] {
1039                    Ti::Scalar(Sc {
1040                        kind: Sk::Float, ..
1041                    })
1042                    | Ti::Vector {
1043                        scalar:
1044                            Sc {
1045                                kind: Sk::Float, ..
1046                            },
1047                        ..
1048                    } => {}
1049                    _ => return Err(ExpressionError::InvalidDerivative),
1050                }
1051                ShaderStages::FRAGMENT
1052            }
1053            E::Relational { fun, argument } => {
1054                use crate::RelationalFunction as Rf;
1055                let argument_inner = &resolver[argument];
1056                match fun {
1057                    Rf::All | Rf::Any => match *argument_inner {
1058                        Ti::Vector {
1059                            scalar: Sc { kind: Sk::Bool, .. },
1060                            ..
1061                        } => {}
1062                        ref other => {
1063                            log::error!("All/Any of type {other:?}");
1064                            return Err(ExpressionError::InvalidBooleanVector(argument));
1065                        }
1066                    },
1067                    Rf::IsNan | Rf::IsInf => match *argument_inner {
1068                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1069                            if scalar.kind == Sk::Float => {}
1070                        ref other => {
1071                            log::error!("Float test of type {other:?}");
1072                            return Err(ExpressionError::InvalidFloatArgument(argument));
1073                        }
1074                    },
1075                }
1076                ShaderStages::all()
1077            }
1078            E::Math {
1079                fun,
1080                arg,
1081                arg1,
1082                arg2,
1083                arg3,
1084            } => {
1085                if matches!(
1086                    fun,
1087                    crate::MathFunction::QuantizeToF16
1088                        | crate::MathFunction::Pack2x16float
1089                        | crate::MathFunction::Unpack2x16float
1090                ) && !self
1091                    .capabilities
1092                    .contains(crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32)
1093                {
1094                    return Err(ExpressionError::MissingCapabilities(
1095                        crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32,
1096                    ));
1097                }
1098
1099                let actuals: &[_] = match (arg1, arg2, arg3) {
1100                    (None, None, None) => &[arg],
1101                    (Some(arg1), None, None) => &[arg, arg1],
1102                    (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1103                    (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1104                    _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1105                };
1106
1107                let resolve = |arg| &resolver[arg];
1108                let actual_types: &[_] = match *actuals {
1109                    [arg0] => &[resolve(arg0)],
1110                    [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1111                    [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1112                    [arg0, arg1, arg2, arg3] => {
1113                        &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1114                    }
1115                    _ => unreachable!(),
1116                };
1117
1118                // Start with the set of all overloads available for `fun`.
1119                let mut overloads = fun.overloads();
1120                log::debug!(
1121                    "initial overloads for {:?}: {:#?}",
1122                    fun,
1123                    overloads.for_debug(&module.types)
1124                );
1125
1126                // If any argument is not a constant expression, then no
1127                // overloads that accept abstract values should be considered.
1128                // `OverloadSet::concrete_only` is supposed to help impose this
1129                // restriction. However, no `MathFunction` accepts a mix of
1130                // abstract and concrete arguments, so we don't need to worry
1131                // about that here.
1132
1133                for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1134                    // Remove overloads that cannot accept an `i`'th
1135                    // argument arguments of type `ty`.
1136                    overloads = overloads.arg(i, ty, &module.types);
1137                    log::debug!(
1138                        "overloads after arg {i}: {:#?}",
1139                        overloads.for_debug(&module.types)
1140                    );
1141
1142                    if overloads.is_empty() {
1143                        log::debug!("all overloads eliminated");
1144                        return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1145                    }
1146                }
1147
1148                if actuals.len() < overloads.min_arguments() {
1149                    return Err(ExpressionError::WrongArgumentCount(fun));
1150                }
1151
1152                ShaderStages::all()
1153            }
1154            E::As {
1155                expr,
1156                kind,
1157                convert,
1158            } => {
1159                let mut base_scalar = match resolver[expr] {
1160                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1161                        scalar
1162                    }
1163                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1164                    _ => return Err(ExpressionError::InvalidCastArgument),
1165                };
1166                base_scalar.kind = kind;
1167                if let Some(width) = convert {
1168                    base_scalar.width = width;
1169                }
1170                if self.check_width(base_scalar).is_err() {
1171                    return Err(ExpressionError::InvalidCastArgument);
1172                }
1173                ShaderStages::all()
1174            }
1175            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1176            E::AtomicResult { .. } => {
1177                // These expressions are validated when we check the `Atomic` statement
1178                // that refers to them, because we have all the information we need at
1179                // that point. The checks driven by `Validator::needs_visit` ensure
1180                // that this expression is indeed visited by one `Atomic` statement.
1181                ShaderStages::all()
1182            }
1183            E::WorkGroupUniformLoadResult { ty } => {
1184                if self.types[ty.index()]
1185                    .flags
1186                    // Sized | Constructible is exactly the types currently supported by
1187                    // WorkGroupUniformLoad
1188                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1189                {
1190                    ShaderStages::COMPUTE_LIKE
1191                } else {
1192                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1193                }
1194            }
1195            E::ArrayLength(expr) => match resolver[expr] {
1196                Ti::Pointer { base, .. } => {
1197                    let base_ty = &resolver.types[base];
1198                    if let Ti::Array {
1199                        size: crate::ArraySize::Dynamic,
1200                        ..
1201                    } = base_ty.inner
1202                    {
1203                        ShaderStages::all()
1204                    } else {
1205                        return Err(ExpressionError::InvalidArrayType(expr));
1206                    }
1207                }
1208                ref other => {
1209                    log::error!("Array length of {other:?}");
1210                    return Err(ExpressionError::InvalidArrayType(expr));
1211                }
1212            },
1213            E::RayQueryProceedResult => ShaderStages::all(),
1214            E::RayQueryGetIntersection {
1215                query,
1216                committed: _,
1217            } => match resolver[query] {
1218                Ti::Pointer {
1219                    base,
1220                    space: crate::AddressSpace::Function,
1221                } => match resolver.types[base].inner {
1222                    Ti::RayQuery { .. } => ShaderStages::all(),
1223                    ref other => {
1224                        log::error!("Intersection result of a pointer to {other:?}");
1225                        return Err(ExpressionError::InvalidRayQueryType(query));
1226                    }
1227                },
1228                ref other => {
1229                    log::error!("Intersection result of {other:?}");
1230                    return Err(ExpressionError::InvalidRayQueryType(query));
1231                }
1232            },
1233            E::RayQueryVertexPositions {
1234                query,
1235                committed: _,
1236            } => match resolver[query] {
1237                Ti::Pointer {
1238                    base,
1239                    space: crate::AddressSpace::Function,
1240                } => match resolver.types[base].inner {
1241                    Ti::RayQuery {
1242                        vertex_return: true,
1243                    } => ShaderStages::all(),
1244                    ref other => {
1245                        log::error!("Intersection result of a pointer to {other:?}");
1246                        return Err(ExpressionError::InvalidRayQueryType(query));
1247                    }
1248                },
1249                ref other => {
1250                    log::error!("Intersection result of {other:?}");
1251                    return Err(ExpressionError::InvalidRayQueryType(query));
1252                }
1253            },
1254            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1255            E::CooperativeLoad { ref data, .. } => {
1256                if resolver[data.pointer]
1257                    .pointer_base_type()
1258                    .and_then(|tr| tr.inner_with(&module.types).scalar())
1259                    .is_none()
1260                {
1261                    return Err(ExpressionError::InvalidPointerType(data.pointer));
1262                }
1263                ShaderStages::COMPUTE
1264            }
1265            E::CooperativeMultiplyAdd { a, b, c } => {
1266                let roles = [
1267                    crate::CooperativeRole::A,
1268                    crate::CooperativeRole::B,
1269                    crate::CooperativeRole::C,
1270                ];
1271                for (operand, expected_role) in [a, b, c].into_iter().zip(roles) {
1272                    match resolver[operand] {
1273                        Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
1274                        ref other => {
1275                            log::error!("{expected_role:?} operand type: {other:?}");
1276                            return Err(ExpressionError::InvalidCooperativeOperand(a));
1277                        }
1278                    }
1279                }
1280                ShaderStages::COMPUTE
1281            }
1282        };
1283        Ok(stages)
1284    }
1285
1286    fn global_var_ty(
1287        module: &crate::Module,
1288        function: &crate::Function,
1289        expr: Handle<crate::Expression>,
1290    ) -> Result<Handle<crate::Type>, ExpressionError> {
1291        use crate::Expression as Ex;
1292
1293        match function.expressions[expr] {
1294            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1295            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1296            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1297                match function.expressions[base] {
1298                    Ex::GlobalVariable(var_handle) => {
1299                        let array_ty = module.global_variables[var_handle].ty;
1300
1301                        match module.types[array_ty].inner {
1302                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1303                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1304                        }
1305                    }
1306                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1307                }
1308            }
1309            _ => Err(ExpressionError::ExpectedGlobalVariable),
1310        }
1311    }
1312
1313    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1314        let _ = self.check_width(literal.scalar())?;
1315        check_literal_value(literal)?;
1316
1317        Ok(())
1318    }
1319}
1320
1321pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1322    let is_nan = match literal {
1323        crate::Literal::F64(v) => v.is_nan(),
1324        crate::Literal::F32(v) => v.is_nan(),
1325        _ => false,
1326    };
1327    if is_nan {
1328        return Err(LiteralError::NaN);
1329    }
1330
1331    let is_infinite = match literal {
1332        crate::Literal::F64(v) => v.is_infinite(),
1333        crate::Literal::F32(v) => v.is_infinite(),
1334        _ => false,
1335    };
1336    if is_infinite {
1337        return Err(LiteralError::Infinity);
1338    }
1339
1340    Ok(())
1341}
1342
#[cfg(test)]
/// Run expression validation over a module whose only function evaluates `expr`.
///
/// The expression is appended to a fresh function's arena and covered by a
/// single `Emit` statement in the function body. The validator is built with
/// `ValidationFlags::EXPRESSIONS` and `caps`, and its result is returned
/// unchanged so tests can inspect any error it reports.
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let span = crate::span::Span::default();

    let mut function = crate::Function::default();
    function.expressions.append(expr, span);
    let emitted = function.expressions.range_from(0);
    function.body.push(crate::Statement::Emit(emitted), span);

    let mut module = crate::Module::default();
    module.functions.append(function, span);

    let mut checker = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
    checker.validate(&module)
}
1365
/// Build a module whose global (constant) expression arena contains `expr`,
/// then validate it with constant validation enabled and the given `caps`.
///
/// The validator's result is returned unchanged. Callers may assert either a
/// successful validation or a specific error.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();
    module.global_expressions.append(expr, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

    validator.validate(&module)
}
1381
/// Using F64 in a function's expression arena is forbidden unless the
/// `FLOAT64` capability is enabled.
#[test]
fn f64_runtime_literals() {
    // Without FLOAT64, the literal's width check must fail.
    let error = validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    )
    .expect_err("an f64 literal must be rejected without the FLOAT64 capability")
    .into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::Function {
                source: super::FunctionError::Expression {
                    source: ExpressionError::Literal(LiteralError::Width(
                        super::r#type::WidthError::MissingCapability {
                            name: "f64",
                            flag: "FLOAT64",
                        }
                    )),
                    ..
                },
                ..
            }
        ),
        "unexpected validation error: {:?}",
        error
    );

    // With FLOAT64, the same literal validates cleanly.
    validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    )
    .expect("an f64 literal must validate when the FLOAT64 capability is enabled");
}
1412
/// Using F64 in a module's constant expression arena is forbidden unless the
/// `FLOAT64` capability is enabled.
#[test]
fn f64_const_literals() {
    // Without FLOAT64, the literal's width check must fail.
    let error = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    )
    .expect_err("an f64 constant literal must be rejected without the FLOAT64 capability")
    .into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::ConstExpression {
                source: ConstExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            }
        ),
        "unexpected validation error: {:?}",
        error
    );

    // With FLOAT64, the same literal validates cleanly.
    validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    )
    .expect("an f64 constant literal must validate when the FLOAT64 capability is enabled");
}