naga/valid/
expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4    arena::Handle,
5    proc::OverloadSet as _,
6    proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13    NotInScope,
14    #[error("Base type {0:?} is not compatible with this expression")]
15    InvalidBaseType(Handle<crate::Expression>),
16    #[error("Accessing with index {0:?} can't be done")]
17    InvalidIndexType(Handle<crate::Expression>),
18    #[error("Accessing {0:?} via a negative index is invalid")]
19    NegativeIndex(Handle<crate::Expression>),
20    #[error("Accessing index {1} is out of {0:?} bounds")]
21    IndexOutOfBounds(Handle<crate::Expression>, u32),
22    #[error("Function argument {0:?} doesn't exist")]
23    FunctionArgumentDoesntExist(u32),
24    #[error("Loading of {0:?} can't be done")]
25    InvalidPointerType(Handle<crate::Expression>),
26    #[error("Array length of {0:?} can't be done")]
27    InvalidArrayType(Handle<crate::Expression>),
28    #[error("Get intersection of {0:?} can't be done")]
29    InvalidRayQueryType(Handle<crate::Expression>),
30    #[error("Splatting {0:?} can't be done")]
31    InvalidSplatType(Handle<crate::Expression>),
32    #[error("Swizzling {0:?} can't be done")]
33    InvalidVectorType(Handle<crate::Expression>),
34    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36    #[error(transparent)]
37    Compose(#[from] super::ComposeError),
38    #[error(transparent)]
39    IndexableLength(#[from] IndexableLengthError),
40    #[error("Operation {0:?} can't work with {1:?}")]
41    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42    #[error(
43        "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
44        op,
45        lhs_expr,
46        lhs_type,
47        rhs_expr,
48        rhs_type
49    )]
50    InvalidBinaryOperandTypes {
51        op: crate::BinaryOperator,
52        lhs_expr: Handle<crate::Expression>,
53        lhs_type: crate::TypeInner,
54        rhs_expr: Handle<crate::Expression>,
55        rhs_type: crate::TypeInner,
56    },
57    #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
58    SelectValuesTypeMismatch {
59        accept: crate::TypeInner,
60        reject: crate::TypeInner,
61    },
62    #[error("Expected selection condition to be a boolean value, got {actual:?}")]
63    SelectConditionNotABool { actual: crate::TypeInner },
64    #[error("Relational argument {0:?} is not a boolean vector")]
65    InvalidBooleanVector(Handle<crate::Expression>),
66    #[error("Relational argument {0:?} is not a float")]
67    InvalidFloatArgument(Handle<crate::Expression>),
68    #[error("Type resolution failed")]
69    Type(#[from] ResolveError),
70    #[error("Not a global variable")]
71    ExpectedGlobalVariable,
72    #[error("Not a global variable or a function argument")]
73    ExpectedGlobalOrArgument,
74    #[error("Needs to be an binding array instead of {0:?}")]
75    ExpectedBindingArrayType(Handle<crate::Type>),
76    #[error("Needs to be an image instead of {0:?}")]
77    ExpectedImageType(Handle<crate::Type>),
78    #[error("Needs to be an image instead of {0:?}")]
79    ExpectedSamplerType(Handle<crate::Type>),
80    #[error("Unable to operate on image class {0:?}")]
81    InvalidImageClass(crate::ImageClass),
82    #[error("Image atomics are not supported for storage format {0:?}")]
83    InvalidImageFormat(crate::StorageFormat),
84    #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
85    InvalidImageStorageAccess(crate::StorageAccess),
86    #[error("Derivatives can only be taken from scalar and vector floats")]
87    InvalidDerivative,
88    #[error("Image array index parameter is misplaced")]
89    InvalidImageArrayIndex,
90    #[error("Inappropriate sample or level-of-detail index for texel access")]
91    InvalidImageOtherIndex,
92    #[error("Image array index type of {0:?} is not an integer scalar")]
93    InvalidImageArrayIndexType(Handle<crate::Expression>),
94    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
95    InvalidImageOtherIndexType(Handle<crate::Expression>),
96    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
97    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
98    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
99    ComparisonSamplingMismatch {
100        image: crate::ImageClass,
101        sampler: bool,
102        has_ref: bool,
103    },
104    #[error("Sample offset must be a const-expression")]
105    InvalidSampleOffsetExprType,
106    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
107    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
108    #[error("Depth reference {0:?} is not a scalar float")]
109    InvalidDepthReference(Handle<crate::Expression>),
110    #[error("Depth sample level can only be Auto or Zero")]
111    InvalidDepthSampleLevel,
112    #[error("Gather level can only be Zero")]
113    InvalidGatherLevel,
114    #[error("Gather component {0:?} doesn't exist in the image")]
115    InvalidGatherComponent(crate::SwizzleComponent),
116    #[error("Gather can't be done for image dimension {0:?}")]
117    InvalidGatherDimension(crate::ImageDimension),
118    #[error("Sample level (exact) type {0:?} has an invalid type")]
119    InvalidSampleLevelExactType(Handle<crate::Expression>),
120    #[error("Sample level (bias) type {0:?} is not a scalar float")]
121    InvalidSampleLevelBiasType(Handle<crate::Expression>),
122    #[error("Bias can't be done for image dimension {0:?}")]
123    InvalidSampleLevelBiasDimension(crate::ImageDimension),
124    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
125    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
126    #[error("Clamping sample coordinate to edge is not supported with {0}")]
127    InvalidSampleClampCoordinateToEdge(alloc::string::String),
128    #[error("Unable to cast")]
129    InvalidCastArgument,
130    #[error("Invalid argument count for {0:?}")]
131    WrongArgumentCount(crate::MathFunction),
132    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
133    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
134    #[error(
135        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
136    )]
137    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
138    #[error("Shader requires capability {0:?}")]
139    MissingCapabilities(super::Capabilities),
140    #[error(transparent)]
141    Literal(#[from] LiteralError),
142    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
143    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
144}
145
146#[derive(Clone, Debug, thiserror::Error)]
147#[cfg_attr(test, derive(PartialEq))]
148pub enum ConstExpressionError {
149    #[error("The expression is not a constant or override expression")]
150    NonConstOrOverride,
151    #[error("The expression is not a fully evaluated constant expression")]
152    NonFullyEvaluatedConst,
153    #[error(transparent)]
154    Compose(#[from] super::ComposeError),
155    #[error("Splatting {0:?} can't be done")]
156    InvalidSplatType(Handle<crate::Expression>),
157    #[error("Type resolution failed")]
158    Type(#[from] ResolveError),
159    #[error(transparent)]
160    Literal(#[from] LiteralError),
161    #[error(transparent)]
162    Width(#[from] super::r#type::WidthError),
163}
164
165#[derive(Clone, Debug, thiserror::Error)]
166#[cfg_attr(test, derive(PartialEq))]
167pub enum LiteralError {
168    #[error("Float literal is NaN")]
169    NaN,
170    #[error("Float literal is infinite")]
171    Infinity,
172    #[error(transparent)]
173    Width(#[from] super::r#type::WidthError),
174}
175
176struct ExpressionTypeResolver<'a> {
177    root: Handle<crate::Expression>,
178    types: &'a UniqueArena<crate::Type>,
179    info: &'a FunctionInfo,
180}
181
182impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
183    type Output = crate::TypeInner;
184
185    #[allow(clippy::panic)]
186    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
187        if handle < self.root {
188            self.info[handle].ty.inner_with(self.types)
189        } else {
190            // `Validator::validate_module_handles` should have caught this.
191            panic!(
192                "Depends on {:?}, which has not been processed yet",
193                self.root
194            )
195        }
196    }
197}
198
199impl super::Validator {
200    pub(super) fn validate_const_expression(
201        &self,
202        handle: Handle<crate::Expression>,
203        gctx: crate::proc::GlobalCtx,
204        mod_info: &ModuleInfo,
205        global_expr_kind: &crate::proc::ExpressionKindTracker,
206    ) -> Result<(), ConstExpressionError> {
207        use crate::Expression as E;
208
209        if !global_expr_kind.is_const_or_override(handle) {
210            return Err(ConstExpressionError::NonConstOrOverride);
211        }
212
213        match gctx.global_expressions[handle] {
214            E::Literal(literal) => {
215                self.validate_literal(literal)?;
216            }
217            E::Constant(_) | E::ZeroValue(_) => {}
218            E::Compose { ref components, ty } => {
219                validate_compose(
220                    ty,
221                    gctx,
222                    components.iter().map(|&handle| mod_info[handle].clone()),
223                )?;
224            }
225            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
226                crate::TypeInner::Scalar { .. } => {}
227                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
228            },
229            _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
230                return Err(ConstExpressionError::NonFullyEvaluatedConst)
231            }
232            // the constant evaluator will report errors about override-expressions
233            _ => {}
234        }
235
236        Ok(())
237    }
238
239    #[allow(clippy::too_many_arguments)]
240    pub(super) fn validate_expression(
241        &self,
242        root: Handle<crate::Expression>,
243        expression: &crate::Expression,
244        function: &crate::Function,
245        module: &crate::Module,
246        info: &FunctionInfo,
247        mod_info: &ModuleInfo,
248        expr_kind: &crate::proc::ExpressionKindTracker,
249    ) -> Result<ShaderStages, ExpressionError> {
250        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
251
252        let resolver = ExpressionTypeResolver {
253            root,
254            types: &module.types,
255            info,
256        };
257
258        let stages = match *expression {
259            E::Access { base, index } => {
260                let base_type = &resolver[base];
261                match *base_type {
262                    Ti::Matrix { .. }
263                    | Ti::Vector { .. }
264                    | Ti::Array { .. }
265                    | Ti::Pointer { .. }
266                    | Ti::ValuePointer { size: Some(_), .. }
267                    | Ti::BindingArray { .. } => {}
268                    ref other => {
269                        log::error!("Indexing of {other:?}");
270                        return Err(ExpressionError::InvalidBaseType(base));
271                    }
272                };
273                match resolver[index] {
274                    //TODO: only allow one of these
275                    Ti::Scalar(Sc {
276                        kind: Sk::Sint | Sk::Uint,
277                        ..
278                    }) => {}
279                    ref other => {
280                        log::error!("Indexing by {other:?}");
281                        return Err(ExpressionError::InvalidIndexType(index));
282                    }
283                }
284
285                // If index is const we can do check for non-negative index
286                match module
287                    .to_ctx()
288                    .eval_expr_to_u32_from(index, &function.expressions)
289                {
290                    Ok(value) => {
291                        let length = if self.overrides_resolved {
292                            base_type.indexable_length_resolved(module)
293                        } else {
294                            base_type.indexable_length_pending(module)
295                        }?;
296                        // If we know both the length and the index, we can do the
297                        // bounds check now.
298                        if let crate::proc::IndexableLength::Known(known_length) = length {
299                            if value >= known_length {
300                                return Err(ExpressionError::IndexOutOfBounds(base, value));
301                            }
302                        }
303                    }
304                    Err(crate::proc::U32EvalError::Negative) => {
305                        return Err(ExpressionError::NegativeIndex(base))
306                    }
307                    Err(crate::proc::U32EvalError::NonConst) => {}
308                }
309
310                ShaderStages::all()
311            }
312            E::AccessIndex { base, index } => {
313                fn resolve_index_limit(
314                    module: &crate::Module,
315                    top: Handle<crate::Expression>,
316                    ty: &crate::TypeInner,
317                    top_level: bool,
318                ) -> Result<u32, ExpressionError> {
319                    let limit = match *ty {
320                        Ti::Vector { size, .. }
321                        | Ti::ValuePointer {
322                            size: Some(size), ..
323                        } => size as u32,
324                        Ti::Matrix { columns, .. } => columns as u32,
325                        Ti::Array {
326                            size: crate::ArraySize::Constant(len),
327                            ..
328                        } => len.get(),
329                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
330                        Ti::Pointer { base, .. } if top_level => {
331                            resolve_index_limit(module, top, &module.types[base].inner, false)?
332                        }
333                        Ti::Struct { ref members, .. } => members.len() as u32,
334                        ref other => {
335                            log::error!("Indexing of {other:?}");
336                            return Err(ExpressionError::InvalidBaseType(top));
337                        }
338                    };
339                    Ok(limit)
340                }
341
342                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
343                if index >= limit {
344                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
345                }
346                ShaderStages::all()
347            }
348            E::Splat { size: _, value } => match resolver[value] {
349                Ti::Scalar { .. } => ShaderStages::all(),
350                ref other => {
351                    log::error!("Splat scalar type {other:?}");
352                    return Err(ExpressionError::InvalidSplatType(value));
353                }
354            },
355            E::Swizzle {
356                size,
357                vector,
358                pattern,
359            } => {
360                let vec_size = match resolver[vector] {
361                    Ti::Vector { size: vec_size, .. } => vec_size,
362                    ref other => {
363                        log::error!("Swizzle vector type {other:?}");
364                        return Err(ExpressionError::InvalidVectorType(vector));
365                    }
366                };
367                for &sc in pattern[..size as usize].iter() {
368                    if sc as u8 >= vec_size as u8 {
369                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
370                    }
371                }
372                ShaderStages::all()
373            }
374            E::Literal(literal) => {
375                self.validate_literal(literal)?;
376                ShaderStages::all()
377            }
378            E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
379            E::Compose { ref components, ty } => {
380                validate_compose(
381                    ty,
382                    module.to_ctx(),
383                    components.iter().map(|&handle| info[handle].ty.clone()),
384                )?;
385                ShaderStages::all()
386            }
387            E::FunctionArgument(index) => {
388                if index >= function.arguments.len() as u32 {
389                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
390                }
391                ShaderStages::all()
392            }
393            E::GlobalVariable(_handle) => ShaderStages::all(),
394            E::LocalVariable(_handle) => ShaderStages::all(),
395            E::Load { pointer } => {
396                match resolver[pointer] {
397                    Ti::Pointer { base, .. }
398                        if self.types[base.index()]
399                            .flags
400                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
401                    Ti::ValuePointer { .. } => {}
402                    ref other => {
403                        log::error!("Loading {other:?}");
404                        return Err(ExpressionError::InvalidPointerType(pointer));
405                    }
406                }
407                ShaderStages::all()
408            }
409            E::ImageSample {
410                image,
411                sampler,
412                gather,
413                coordinate,
414                array_index,
415                offset,
416                level,
417                depth_ref,
418                clamp_to_edge,
419            } => {
420                // check the validity of expressions
421                let image_ty = Self::global_var_ty(module, function, image)?;
422                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
423
424                let comparison = match module.types[sampler_ty].inner {
425                    Ti::Sampler { comparison } => comparison,
426                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
427                };
428
429                let (class, dim) = match module.types[image_ty].inner {
430                    Ti::Image {
431                        class,
432                        arrayed,
433                        dim,
434                    } => {
435                        // check the array property
436                        if arrayed != array_index.is_some() {
437                            return Err(ExpressionError::InvalidImageArrayIndex);
438                        }
439                        if let Some(expr) = array_index {
440                            match resolver[expr] {
441                                Ti::Scalar(Sc {
442                                    kind: Sk::Sint | Sk::Uint,
443                                    ..
444                                }) => {}
445                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
446                            }
447                        }
448                        (class, dim)
449                    }
450                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
451                };
452
453                // check sampling and comparison properties
454                let image_depth = match class {
455                    crate::ImageClass::Sampled {
456                        kind: crate::ScalarKind::Float,
457                        multi: false,
458                    } => false,
459                    crate::ImageClass::Sampled {
460                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
461                        multi: false,
462                    } if gather.is_some() => false,
463                    crate::ImageClass::External => false,
464                    crate::ImageClass::Depth { multi: false } => true,
465                    _ => return Err(ExpressionError::InvalidImageClass(class)),
466                };
467                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
468                    return Err(ExpressionError::ComparisonSamplingMismatch {
469                        image: class,
470                        sampler: comparison,
471                        has_ref: depth_ref.is_some(),
472                    });
473                }
474
475                // check texture coordinates type
476                let num_components = match dim {
477                    crate::ImageDimension::D1 => 1,
478                    crate::ImageDimension::D2 => 2,
479                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
480                };
481                match resolver[coordinate] {
482                    Ti::Scalar(Sc {
483                        kind: Sk::Float, ..
484                    }) if num_components == 1 => {}
485                    Ti::Vector {
486                        size,
487                        scalar:
488                            Sc {
489                                kind: Sk::Float, ..
490                            },
491                    } if size as u32 == num_components => {}
492                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
493                }
494
495                // check constant offset
496                if let Some(const_expr) = offset {
497                    if !expr_kind.is_const(const_expr) {
498                        return Err(ExpressionError::InvalidSampleOffsetExprType);
499                    }
500
501                    match resolver[const_expr] {
502                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
503                        Ti::Vector {
504                            size,
505                            scalar: Sc { kind: Sk::Sint, .. },
506                        } if size as u32 == num_components => {}
507                        _ => {
508                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
509                        }
510                    }
511                }
512
513                // check depth reference type
514                if let Some(expr) = depth_ref {
515                    match resolver[expr] {
516                        Ti::Scalar(Sc {
517                            kind: Sk::Float, ..
518                        }) => {}
519                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
520                    }
521                    match level {
522                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
523                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
524                    }
525                }
526
527                if let Some(component) = gather {
528                    match dim {
529                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
530                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
531                            return Err(ExpressionError::InvalidGatherDimension(dim))
532                        }
533                    };
534                    let max_component = match class {
535                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
536                        _ => crate::SwizzleComponent::W,
537                    };
538                    if component > max_component {
539                        return Err(ExpressionError::InvalidGatherComponent(component));
540                    }
541                    match level {
542                        crate::SampleLevel::Zero => {}
543                        _ => return Err(ExpressionError::InvalidGatherLevel),
544                    }
545                }
546
547                // Clamping coordinate to edge is only supported with 2d non-arrayed, sampled images
548                // when sampling from level Zero without any offset, gather, or depth comparison.
549                if clamp_to_edge {
550                    if !matches!(
551                        class,
552                        crate::ImageClass::Sampled {
553                            kind: crate::ScalarKind::Float,
554                            multi: false
555                        } | crate::ImageClass::External
556                    ) {
557                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
558                            alloc::format!("image class `{class:?}`"),
559                        ));
560                    }
561                    if dim != crate::ImageDimension::D2 {
562                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
563                            alloc::format!("image dimension `{dim:?}`"),
564                        ));
565                    }
566                    if gather.is_some() {
567                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
568                            "gather".into(),
569                        ));
570                    }
571                    if array_index.is_some() {
572                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
573                            "array index".into(),
574                        ));
575                    }
576                    if offset.is_some() {
577                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
578                            "offset".into(),
579                        ));
580                    }
581                    if level != crate::SampleLevel::Zero {
582                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
583                            "non-zero level".into(),
584                        ));
585                    }
586                    if depth_ref.is_some() {
587                        return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
588                            "depth comparison".into(),
589                        ));
590                    }
591                }
592
593                // External textures can only be sampled using clamp_to_edge.
594                if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
595                    return Err(ExpressionError::InvalidImageClass(class));
596                }
597
598                // check level properties
599                match level {
600                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
601                    crate::SampleLevel::Zero => ShaderStages::all(),
602                    crate::SampleLevel::Exact(expr) => {
603                        match class {
604                            crate::ImageClass::Depth { .. } => match resolver[expr] {
605                                Ti::Scalar(Sc {
606                                    kind: Sk::Sint | Sk::Uint,
607                                    ..
608                                }) => {}
609                                _ => {
610                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
611                                }
612                            },
613                            _ => match resolver[expr] {
614                                Ti::Scalar(Sc {
615                                    kind: Sk::Float, ..
616                                }) => {}
617                                _ => {
618                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
619                                }
620                            },
621                        }
622                        ShaderStages::all()
623                    }
624                    crate::SampleLevel::Bias(expr) => {
625                        match resolver[expr] {
626                            Ti::Scalar(Sc {
627                                kind: Sk::Float, ..
628                            }) => {}
629                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
630                        }
631                        match class {
632                            crate::ImageClass::Sampled {
633                                kind: Sk::Float,
634                                multi: false,
635                            } => {
636                                if dim == crate::ImageDimension::D1 {
637                                    return Err(ExpressionError::InvalidSampleLevelBiasDimension(
638                                        dim,
639                                    ));
640                                }
641                            }
642                            _ => return Err(ExpressionError::InvalidImageClass(class)),
643                        }
644                        ShaderStages::FRAGMENT
645                    }
646                    crate::SampleLevel::Gradient { x, y } => {
647                        match resolver[x] {
648                            Ti::Scalar(Sc {
649                                kind: Sk::Float, ..
650                            }) if num_components == 1 => {}
651                            Ti::Vector {
652                                size,
653                                scalar:
654                                    Sc {
655                                        kind: Sk::Float, ..
656                                    },
657                            } if size as u32 == num_components => {}
658                            _ => {
659                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
660                            }
661                        }
662                        match resolver[y] {
663                            Ti::Scalar(Sc {
664                                kind: Sk::Float, ..
665                            }) if num_components == 1 => {}
666                            Ti::Vector {
667                                size,
668                                scalar:
669                                    Sc {
670                                        kind: Sk::Float, ..
671                                    },
672                            } if size as u32 == num_components => {}
673                            _ => {
674                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
675                            }
676                        }
677                        ShaderStages::all()
678                    }
679                }
680            }
681            E::ImageLoad {
682                image,
683                coordinate,
684                array_index,
685                sample,
686                level,
687            } => {
688                let ty = Self::global_var_ty(module, function, image)?;
689                let Ti::Image {
690                    class,
691                    arrayed,
692                    dim,
693                } = module.types[ty].inner
694                else {
695                    return Err(ExpressionError::ExpectedImageType(ty));
696                };
697
698                match resolver[coordinate].image_storage_coordinates() {
699                    Some(coord_dim) if coord_dim == dim => {}
700                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
701                };
702                if arrayed != array_index.is_some() {
703                    return Err(ExpressionError::InvalidImageArrayIndex);
704                }
705                if let Some(expr) = array_index {
706                    if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
707                        return Err(ExpressionError::InvalidImageArrayIndexType(expr));
708                    }
709                }
710
711                match (sample, class.is_multisampled()) {
712                    (None, false) => {}
713                    (Some(sample), true) => {
714                        if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
715                            return Err(ExpressionError::InvalidImageOtherIndexType(sample));
716                        }
717                    }
718                    _ => {
719                        return Err(ExpressionError::InvalidImageOtherIndex);
720                    }
721                }
722
723                match (level, class.is_mipmapped()) {
724                    (None, false) => {}
725                    (Some(level), true) => match resolver[level] {
726                        Ti::Scalar(Sc {
727                            kind: Sk::Sint | Sk::Uint,
728                            width: _,
729                        }) => {}
730                        _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
731                    },
732                    _ => {
733                        return Err(ExpressionError::InvalidImageOtherIndex);
734                    }
735                }
736                ShaderStages::all()
737            }
738            E::ImageQuery { image, query } => {
739                let ty = Self::global_var_ty(module, function, image)?;
740                match module.types[ty].inner {
741                    Ti::Image { class, arrayed, .. } => {
742                        let good = match query {
743                            crate::ImageQuery::NumLayers => arrayed,
744                            crate::ImageQuery::Size { level: None } => true,
745                            crate::ImageQuery::Size { level: Some(level) } => {
746                                match resolver[level] {
747                                    Ti::Scalar(Sc::I32 | Sc::U32) => {}
748                                    _ => {
749                                        return Err(ExpressionError::InvalidImageOtherIndexType(
750                                            level,
751                                        ))
752                                    }
753                                }
754                                class.is_mipmapped()
755                            }
756                            crate::ImageQuery::NumLevels => class.is_mipmapped(),
757                            crate::ImageQuery::NumSamples => class.is_multisampled(),
758                        };
759                        if !good {
760                            return Err(ExpressionError::InvalidImageClass(class));
761                        }
762                    }
763                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
764                }
765                ShaderStages::all()
766            }
767            E::Unary { op, expr } => {
768                use crate::UnaryOperator as Uo;
769                let inner = &resolver[expr];
770                match (op, inner.scalar_kind()) {
771                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
772                    | (Uo::LogicalNot, Some(Sk::Bool))
773                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
774                    other => {
775                        log::error!("Op {op:?} kind {other:?}");
776                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
777                    }
778                }
779                ShaderStages::all()
780            }
781            E::Binary { op, left, right } => {
782                use crate::BinaryOperator as Bo;
783                let left_inner = &resolver[left];
784                let right_inner = &resolver[right];
785                let good = match op {
786                    Bo::Add | Bo::Subtract => match *left_inner {
787                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
788                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
789                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
790                        },
791                        Ti::Matrix { .. } => left_inner == right_inner,
792                        _ => false,
793                    },
794                    Bo::Divide | Bo::Modulo => match *left_inner {
795                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
796                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
797                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
798                        },
799                        _ => false,
800                    },
801                    Bo::Multiply => {
802                        let kind_allowed = match left_inner.scalar_kind() {
803                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
804                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
805                        };
806                        let types_match = match (left_inner, right_inner) {
807                            // Straight scalar and mixed scalar/vector.
808                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
809                            | (
810                                &Ti::Vector {
811                                    scalar: scalar1, ..
812                                },
813                                &Ti::Scalar(scalar2),
814                            )
815                            | (
816                                &Ti::Scalar(scalar1),
817                                &Ti::Vector {
818                                    scalar: scalar2, ..
819                                },
820                            ) => scalar1 == scalar2,
821                            // Scalar/matrix.
822                            (
823                                &Ti::Scalar(Sc {
824                                    kind: Sk::Float, ..
825                                }),
826                                &Ti::Matrix { .. },
827                            )
828                            | (
829                                &Ti::Matrix { .. },
830                                &Ti::Scalar(Sc {
831                                    kind: Sk::Float, ..
832                                }),
833                            ) => true,
834                            // Vector/vector.
835                            (
836                                &Ti::Vector {
837                                    size: size1,
838                                    scalar: scalar1,
839                                },
840                                &Ti::Vector {
841                                    size: size2,
842                                    scalar: scalar2,
843                                },
844                            ) => scalar1 == scalar2 && size1 == size2,
845                            // Matrix * vector.
846                            (
847                                &Ti::Matrix { columns, .. },
848                                &Ti::Vector {
849                                    size,
850                                    scalar:
851                                        Sc {
852                                            kind: Sk::Float, ..
853                                        },
854                                },
855                            ) => columns == size,
856                            // Vector * matrix.
857                            (
858                                &Ti::Vector {
859                                    size,
860                                    scalar:
861                                        Sc {
862                                            kind: Sk::Float, ..
863                                        },
864                                },
865                                &Ti::Matrix { rows, .. },
866                            ) => size == rows,
867                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
868                                columns == rows
869                            }
870                            _ => false,
871                        };
872                        let left_width = left_inner.scalar_width().unwrap_or(0);
873                        let right_width = right_inner.scalar_width().unwrap_or(0);
874                        kind_allowed && types_match && left_width == right_width
875                    }
876                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
877                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
878                        match *left_inner {
879                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
880                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
881                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
882                            },
883                            ref other => {
884                                log::error!("Op {op:?} left type {other:?}");
885                                false
886                            }
887                        }
888                    }
889                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
890                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
891                        | Ti::Vector {
892                            scalar: Sc { kind: Sk::Bool, .. },
893                            ..
894                        } => left_inner == right_inner,
895                        ref other => {
896                            log::error!("Op {op:?} left type {other:?}");
897                            false
898                        }
899                    },
900                    Bo::And | Bo::InclusiveOr => match *left_inner {
901                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
902                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
903                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
904                        },
905                        ref other => {
906                            log::error!("Op {op:?} left type {other:?}");
907                            false
908                        }
909                    },
910                    Bo::ExclusiveOr => match *left_inner {
911                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
912                            Sk::Sint | Sk::Uint => left_inner == right_inner,
913                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
914                        },
915                        ref other => {
916                            log::error!("Op {op:?} left type {other:?}");
917                            false
918                        }
919                    },
920                    Bo::ShiftLeft | Bo::ShiftRight => {
921                        let (base_size, base_scalar) = match *left_inner {
922                            Ti::Scalar(scalar) => (Ok(None), scalar),
923                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
924                            ref other => {
925                                log::error!("Op {op:?} base type {other:?}");
926                                (Err(()), Sc::BOOL)
927                            }
928                        };
929                        let shift_size = match *right_inner {
930                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
931                            Ti::Vector {
932                                size,
933                                scalar: Sc { kind: Sk::Uint, .. },
934                            } => Ok(Some(size)),
935                            ref other => {
936                                log::error!("Op {op:?} shift type {other:?}");
937                                Err(())
938                            }
939                        };
940                        match base_scalar.kind {
941                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
942                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
943                        }
944                    }
945                };
946                if !good {
947                    log::error!(
948                        "Left: {:?} of type {:?}",
949                        function.expressions[left],
950                        left_inner
951                    );
952                    log::error!(
953                        "Right: {:?} of type {:?}",
954                        function.expressions[right],
955                        right_inner
956                    );
957                    return Err(ExpressionError::InvalidBinaryOperandTypes {
958                        op,
959                        lhs_expr: left,
960                        lhs_type: left_inner.clone(),
961                        rhs_expr: right,
962                        rhs_type: right_inner.clone(),
963                    });
964                }
965                ShaderStages::all()
966            }
967            E::Select {
968                condition,
969                accept,
970                reject,
971            } => {
972                let accept_inner = &resolver[accept];
973                let reject_inner = &resolver[reject];
974                let condition_ty = &resolver[condition];
975                let condition_good = match *condition_ty {
976                    Ti::Scalar(Sc {
977                        kind: Sk::Bool,
978                        width: _,
979                    }) => {
980                        // When `condition` is a single boolean, `accept` and
981                        // `reject` can be vectors or scalars.
982                        match *accept_inner {
983                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
984                            _ => false,
985                        }
986                    }
987                    Ti::Vector {
988                        size,
989                        scalar:
990                            Sc {
991                                kind: Sk::Bool,
992                                width: _,
993                            },
994                    } => match *accept_inner {
995                        Ti::Vector {
996                            size: other_size, ..
997                        } => size == other_size,
998                        _ => false,
999                    },
1000                    _ => false,
1001                };
1002                if accept_inner != reject_inner {
1003                    return Err(ExpressionError::SelectValuesTypeMismatch {
1004                        accept: accept_inner.clone(),
1005                        reject: reject_inner.clone(),
1006                    });
1007                }
1008                if !condition_good {
1009                    return Err(ExpressionError::SelectConditionNotABool {
1010                        actual: condition_ty.clone(),
1011                    });
1012                }
1013                ShaderStages::all()
1014            }
1015            E::Derivative { expr, .. } => {
1016                match resolver[expr] {
1017                    Ti::Scalar(Sc {
1018                        kind: Sk::Float, ..
1019                    })
1020                    | Ti::Vector {
1021                        scalar:
1022                            Sc {
1023                                kind: Sk::Float, ..
1024                            },
1025                        ..
1026                    } => {}
1027                    _ => return Err(ExpressionError::InvalidDerivative),
1028                }
1029                ShaderStages::FRAGMENT
1030            }
1031            E::Relational { fun, argument } => {
1032                use crate::RelationalFunction as Rf;
1033                let argument_inner = &resolver[argument];
1034                match fun {
1035                    Rf::All | Rf::Any => match *argument_inner {
1036                        Ti::Vector {
1037                            scalar: Sc { kind: Sk::Bool, .. },
1038                            ..
1039                        } => {}
1040                        ref other => {
1041                            log::error!("All/Any of type {other:?}");
1042                            return Err(ExpressionError::InvalidBooleanVector(argument));
1043                        }
1044                    },
1045                    Rf::IsNan | Rf::IsInf => match *argument_inner {
1046                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1047                            if scalar.kind == Sk::Float => {}
1048                        ref other => {
1049                            log::error!("Float test of type {other:?}");
1050                            return Err(ExpressionError::InvalidFloatArgument(argument));
1051                        }
1052                    },
1053                }
1054                ShaderStages::all()
1055            }
1056            E::Math {
1057                fun,
1058                arg,
1059                arg1,
1060                arg2,
1061                arg3,
1062            } => {
1063                let actuals: &[_] = match (arg1, arg2, arg3) {
1064                    (None, None, None) => &[arg],
1065                    (Some(arg1), None, None) => &[arg, arg1],
1066                    (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1067                    (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1068                    _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1069                };
1070
1071                let resolve = |arg| &resolver[arg];
1072                let actual_types: &[_] = match *actuals {
1073                    [arg0] => &[resolve(arg0)],
1074                    [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1075                    [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1076                    [arg0, arg1, arg2, arg3] => {
1077                        &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1078                    }
1079                    _ => unreachable!(),
1080                };
1081
1082                // Start with the set of all overloads available for `fun`.
1083                let mut overloads = fun.overloads();
1084                log::debug!(
1085                    "initial overloads for {:?}: {:#?}",
1086                    fun,
1087                    overloads.for_debug(&module.types)
1088                );
1089
1090                // If any argument is not a constant expression, then no
1091                // overloads that accept abstract values should be considered.
1092                // `OverloadSet::concrete_only` is supposed to help impose this
1093                // restriction. However, no `MathFunction` accepts a mix of
1094                // abstract and concrete arguments, so we don't need to worry
1095                // about that here.
1096
1097                for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1098                    // Remove overloads that cannot accept an `i`'th
1099                    // argument arguments of type `ty`.
1100                    overloads = overloads.arg(i, ty, &module.types);
1101                    log::debug!(
1102                        "overloads after arg {i}: {:#?}",
1103                        overloads.for_debug(&module.types)
1104                    );
1105
1106                    if overloads.is_empty() {
1107                        log::debug!("all overloads eliminated");
1108                        return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1109                    }
1110                }
1111
1112                if actuals.len() < overloads.min_arguments() {
1113                    return Err(ExpressionError::WrongArgumentCount(fun));
1114                }
1115
1116                ShaderStages::all()
1117            }
1118            E::As {
1119                expr,
1120                kind,
1121                convert,
1122            } => {
1123                let mut base_scalar = match resolver[expr] {
1124                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1125                        scalar
1126                    }
1127                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1128                    _ => return Err(ExpressionError::InvalidCastArgument),
1129                };
1130                base_scalar.kind = kind;
1131                if let Some(width) = convert {
1132                    base_scalar.width = width;
1133                }
1134                if self.check_width(base_scalar).is_err() {
1135                    return Err(ExpressionError::InvalidCastArgument);
1136                }
1137                ShaderStages::all()
1138            }
1139            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1140            E::AtomicResult { .. } => {
1141                // These expressions are validated when we check the `Atomic` statement
1142                // that refers to them, because we have all the information we need at
1143                // that point. The checks driven by `Validator::needs_visit` ensure
1144                // that this expression is indeed visited by one `Atomic` statement.
1145                ShaderStages::all()
1146            }
1147            E::WorkGroupUniformLoadResult { ty } => {
1148                if self.types[ty.index()]
1149                    .flags
1150                    // Sized | Constructible is exactly the types currently supported by
1151                    // WorkGroupUniformLoad
1152                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1153                {
1154                    ShaderStages::COMPUTE
1155                } else {
1156                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1157                }
1158            }
1159            E::ArrayLength(expr) => match resolver[expr] {
1160                Ti::Pointer { base, .. } => {
1161                    let base_ty = &resolver.types[base];
1162                    if let Ti::Array {
1163                        size: crate::ArraySize::Dynamic,
1164                        ..
1165                    } = base_ty.inner
1166                    {
1167                        ShaderStages::all()
1168                    } else {
1169                        return Err(ExpressionError::InvalidArrayType(expr));
1170                    }
1171                }
1172                ref other => {
1173                    log::error!("Array length of {other:?}");
1174                    return Err(ExpressionError::InvalidArrayType(expr));
1175                }
1176            },
1177            E::RayQueryProceedResult => ShaderStages::all(),
1178            E::RayQueryGetIntersection {
1179                query,
1180                committed: _,
1181            } => match resolver[query] {
1182                Ti::Pointer {
1183                    base,
1184                    space: crate::AddressSpace::Function,
1185                } => match resolver.types[base].inner {
1186                    Ti::RayQuery { .. } => ShaderStages::all(),
1187                    ref other => {
1188                        log::error!("Intersection result of a pointer to {other:?}");
1189                        return Err(ExpressionError::InvalidRayQueryType(query));
1190                    }
1191                },
1192                ref other => {
1193                    log::error!("Intersection result of {other:?}");
1194                    return Err(ExpressionError::InvalidRayQueryType(query));
1195                }
1196            },
1197            E::RayQueryVertexPositions {
1198                query,
1199                committed: _,
1200            } => match resolver[query] {
1201                Ti::Pointer {
1202                    base,
1203                    space: crate::AddressSpace::Function,
1204                } => match resolver.types[base].inner {
1205                    Ti::RayQuery {
1206                        vertex_return: true,
1207                    } => ShaderStages::all(),
1208                    ref other => {
1209                        log::error!("Intersection result of a pointer to {other:?}");
1210                        return Err(ExpressionError::InvalidRayQueryType(query));
1211                    }
1212                },
1213                ref other => {
1214                    log::error!("Intersection result of {other:?}");
1215                    return Err(ExpressionError::InvalidRayQueryType(query));
1216                }
1217            },
1218            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1219        };
1220        Ok(stages)
1221    }
1222
1223    fn global_var_ty(
1224        module: &crate::Module,
1225        function: &crate::Function,
1226        expr: Handle<crate::Expression>,
1227    ) -> Result<Handle<crate::Type>, ExpressionError> {
1228        use crate::Expression as Ex;
1229
1230        match function.expressions[expr] {
1231            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1232            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1233            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1234                match function.expressions[base] {
1235                    Ex::GlobalVariable(var_handle) => {
1236                        let array_ty = module.global_variables[var_handle].ty;
1237
1238                        match module.types[array_ty].inner {
1239                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1240                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1241                        }
1242                    }
1243                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1244                }
1245            }
1246            _ => Err(ExpressionError::ExpectedGlobalVariable),
1247        }
1248    }
1249
1250    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1251        let _ = self.check_width(literal.scalar())?;
1252        check_literal_value(literal)?;
1253
1254        Ok(())
1255    }
1256}
1257
1258pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1259    let is_nan = match literal {
1260        crate::Literal::F64(v) => v.is_nan(),
1261        crate::Literal::F32(v) => v.is_nan(),
1262        _ => false,
1263    };
1264    if is_nan {
1265        return Err(LiteralError::NaN);
1266    }
1267
1268    let is_infinite = match literal {
1269        crate::Literal::F64(v) => v.is_infinite(),
1270        crate::Literal::F32(v) => v.is_infinite(),
1271        _ => false,
1272    };
1273    if is_infinite {
1274        return Err(LiteralError::Infinity);
1275    }
1276
1277    Ok(())
1278}
1279
/// Validate a module containing a single function whose body emits `expr`.
///
/// The validator runs with `ValidationFlags::EXPRESSIONS` and the given
/// capabilities `caps`. The validation result is returned unchanged, so
/// callers can assert either success or a specific error.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    // The body's only statement emits the whole expression arena, which
    // here contains just `expr`.
    function.body.push(
        crate::Statement::Emit(function.expressions.range_from(0)),
        Span::default(),
    );

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);

    validator.validate(&module)
}
1302
/// Validate a module whose global (constant) expression arena contains
/// only `expr`.
///
/// The validator runs with `ValidationFlags::CONSTANTS` and the given
/// capabilities `caps`. The validation result is returned unchanged, so
/// callers can assert either success or a specific error.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();
    module.global_expressions.append(expr, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

    validator.validate(&module)
}
1318
/// An F64 literal in a function's expression arena needs `Capabilities::FLOAT64`.
///
/// Without the capability, validation fails with a missing-capability width
/// error. With it, the same module validates successfully.
#[test]
fn f64_runtime_literals() {
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // With `Capabilities::default()` the literal is rejected.
    let rejected = validate_with_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        rejected,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    // Granting FLOAT64 makes the same expression acceptable.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(literal(), with_f64).is_ok());
}
1349
/// An F64 literal in a module's constant expression arena needs
/// `Capabilities::FLOAT64`.
///
/// Without the capability, validation fails with a missing-capability width
/// error. With it, the same module validates successfully.
#[test]
fn f64_const_literals() {
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // With `Capabilities::default()` the constant literal is rejected.
    let rejected = validate_with_const_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        rejected,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // Granting FLOAT64 makes the same constant expression acceptable.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(literal(), with_f64).is_ok());
}