naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19    arena::{Handle, HandleSet},
20    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21    FastHashSet,
22};
23
24//TODO: analyze the model at the same time as we validate it,
25// merge the corresponding matches over expressions and statements.
26
27use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
/// Maximum size of a type, in bytes.
///
/// Equal to 2^30 bytes, i.e. 1 GiB.
pub const MAX_TYPE_SIZE: u32 = 1 << 30;
40
41bitflags::bitflags! {
42    /// Validation flags.
43    ///
44    /// If you are working with trusted shaders, then you may be able
45    /// to save some time by skipping validation.
46    ///
47    /// If you do not perform full validation, invalid shaders may
48    /// cause Naga to panic. If you do perform full validation and
49    /// [`Validator::validate`] returns `Ok`, then Naga promises that
50    /// code generation will either succeed or return an error; it
51    /// should never panic.
52    ///
53    /// The default value for `ValidationFlags` is
54    /// `ValidationFlags::all()`.
55    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58    pub struct ValidationFlags: u8 {
59        /// Expressions.
60        const EXPRESSIONS = 0x1;
61        /// Statements and blocks of them.
62        const BLOCKS = 0x2;
63        /// Uniformity of control flow for operations that require it.
64        const CONTROL_FLOW_UNIFORMITY = 0x4;
65        /// Host-shareable structure layouts.
66        const STRUCT_LAYOUTS = 0x8;
67        /// Constants.
68        const CONSTANTS = 0x10;
69        /// Group, binding, and location attributes.
70        const BINDINGS = 0x20;
71    }
72}
73
74impl Default for ValidationFlags {
75    fn default() -> Self {
76        Self::all()
77    }
78}
79
bitflags::bitflags! {
    /// Allowed IR capabilities.
    ///
    /// When full validation is enabled, a [`Validator`] rejects any module
    /// that uses a capability not included in the set it was created with.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        /// Support for [`AddressSpace::PushConstant`][1].
        ///
        /// [1]: crate::AddressSpace::PushConstant
        const PUSH_CONSTANT = 1 << 0;
        /// Float values with width = 8.
        const FLOAT64 = 1 << 1;
        /// Support for [`BuiltIn::PrimitiveIndex`][1].
        ///
        /// [1]: crate::BuiltIn::PrimitiveIndex
        const PRIMITIVE_INDEX = 1 << 2;
        /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        /// Support for non-uniform indexing of storage texture arrays.
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        /// Support for non-uniform indexing of uniform buffer arrays.
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        /// Support for non-uniform indexing of samplers.
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        /// Support for [`BuiltIn::ClipDistance`].
        ///
        /// [`BuiltIn::ClipDistance`]: crate::BuiltIn::ClipDistance
        const CLIP_DISTANCE = 1 << 7;
        /// Support for [`BuiltIn::CullDistance`].
        ///
        /// [`BuiltIn::CullDistance`]: crate::BuiltIn::CullDistance
        const CULL_DISTANCE = 1 << 8;
        /// Support for 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Support for [`BuiltIn::ViewIndex`].
        ///
        /// [`BuiltIn::ViewIndex`]: crate::BuiltIn::ViewIndex
        const MULTIVIEW = 1 << 10;
        /// Support for `early_depth_test`.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
        ///
        /// [`BuiltIn::SampleIndex`]: crate::BuiltIn::SampleIndex
        /// [`Sampling::Sample`]: crate::Sampling::Sample
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Support for ray queries and acceleration structures.
        const RAY_QUERY = 1 << 13;
        /// Support for generating two sources for blending from fragment shaders.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Support for arrayed cube textures.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// Support for 64-bit signed and unsigned integers.
        const SHADER_INT64 = 1 << 16;
        /// Support for subgroup operations.
        /// Implies support for subgroup operations in both fragment and compute stages,
        /// but not necessarily in the vertex stage, which requires [`Capabilities::SUBGROUP_VERTEX_STAGE`].
        const SUBGROUP = 1 << 17;
        /// Support for subgroup barriers.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Support for subgroup operations in the vertex stage.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// Support for [`AtomicFunction::Min`] and [`AtomicFunction::Max`] on
        /// 64-bit integers in the [`Storage`] address space, when the return
        /// value is not used.
        ///
        /// This is the only 64-bit atomic functionality available on Metal 3.1.
        ///
        /// [`AtomicFunction::Min`]: crate::AtomicFunction::Min
        /// [`AtomicFunction::Max`]: crate::AtomicFunction::Max
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// Support for all atomic operations on 64-bit integers.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`],
        /// and [`AtomicFunction::Exchange { compare: None }`] on 32-bit floating-point numbers
        /// in the [`Storage`] address space.
        ///
        /// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
        /// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
        /// [`AtomicFunction::Exchange { compare: None }`]: crate::AtomicFunction::Exchange
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Support for atomic operations on images.
        const TEXTURE_ATOMIC = 1 << 23;
        /// Support for atomic operations on 64-bit images.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Support for ray queries returning vertex position.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// Support for 16-bit floating-point types.
        const SHADER_FLOAT16 = 1 << 26;
        /// Support for [`ImageClass::External`].
        ///
        /// [`ImageClass::External`]: crate::ImageClass::External
        const TEXTURE_EXTERNAL = 1 << 27;
        /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
        /// `f16`-precision values in `f32`s.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
    }
}
178
179impl Capabilities {
180    /// Returns the extension corresponding to this capability, if there is one.
181    ///
182    /// This is used by integration tests.
183    #[cfg(feature = "wgsl-in")]
184    #[doc(hidden)]
185    pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
186        use crate::front::wgsl::ImplementedEnableExtension as Ext;
187        match *self {
188            Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
189            // NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
190            Self::SHADER_FLOAT16 => Some(Ext::F16),
191            Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
192            _ => None,
193        }
194    }
195}
196
197impl Default for Capabilities {
198    fn default() -> Self {
199        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
200    }
201}
202
bitflags::bitflags! {
    /// Supported subgroup operations.
    ///
    /// Each flag covers a family of subgroup operations. A [`Validator`] holds
    /// one of these sets, which can be replaced via
    /// [`Validator::subgroup_operations`].
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Elect, Barrier
        const BASIC = 1 << 0;
        /// Any, All
        const VOTE = 1 << 1;
        /// reductions, scans
        const ARITHMETIC = 1 << 2;
        /// ballot, broadcast
        const BALLOT = 1 << 3;
        /// shuffle, shuffle xor
        const SHUFFLE = 1 << 4;
        /// shuffle up, down
        const SHUFFLE_RELATIVE = 1 << 5;
        // We don't support these operations yet
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        /// Quad supported (quad broadcast, quad swap)
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // NOTE(review): `1 << 8` does not fit in the `u8` backing type, so the
        // backing type must be widened before this flag can be enabled.
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
230
231impl super::SubgroupOperation {
232    const fn required_operations(&self) -> SubgroupOperationSet {
233        use SubgroupOperationSet as S;
234        match *self {
235            Self::All | Self::Any => S::VOTE,
236            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
237                S::ARITHMETIC
238            }
239        }
240    }
241}
242
243impl super::GatherMode {
244    const fn required_operations(&self) -> SubgroupOperationSet {
245        use SubgroupOperationSet as S;
246        match *self {
247            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
248            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
249            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
250            Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
251        }
252    }
253}
254
255bitflags::bitflags! {
256    /// Validation flags.
257    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
258    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
259    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
260    pub struct ShaderStages: u8 {
261        const VERTEX = 0x1;
262        const FRAGMENT = 0x2;
263        const COMPUTE = 0x4;
264    }
265}
266
/// Information about a module, returned by a successful
/// [`Validator::validate`] call.
///
/// It can be indexed with type, function, and global-expression handles.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, in type-arena order (indexed by `Handle<Type>::index()`).
    type_flags: Vec<TypeFlags>,
    /// Info for each function, in function-arena order (indexed by
    /// `Handle<Function>::index()`).
    functions: Vec<FunctionInfo>,
    /// Info for each entry point, in the module's entry point order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each global expression (indexed by
    /// `Handle<Expression>::index()`).
    const_expression_types: Box<[TypeResolution]>,
}
276
277impl ops::Index<Handle<crate::Type>> for ModuleInfo {
278    type Output = TypeFlags;
279    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
280        &self.type_flags[handle.index()]
281    }
282}
283
284impl ops::Index<Handle<crate::Function>> for ModuleInfo {
285    type Output = FunctionInfo;
286    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
287        &self.functions[handle.index()]
288    }
289}
290
291impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
292    type Output = TypeResolution;
293    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
294        &self.const_expression_types[handle.index()]
295    }
296}
297
/// Validator for Naga [`Module`]s.
///
/// Create one with [`Validator::new`], then call [`Validator::validate`]
/// (or [`Validator::validate_resolved_overrides`]). On success, the result is
/// a [`ModuleInfo`]. A `Validator` may be reused across modules; per-module
/// scratch state is reset at the start of each validation.
///
/// [`Module`]: crate::Module
#[derive(Debug)]
pub struct Validator {
    /// Which categories of validation to perform.
    flags: ValidationFlags,
    /// Capabilities the validated module is permitted to use.
    capabilities: Capabilities,
    /// Shader stages in which subgroup operations are permitted.
    subgroup_stages: ShaderStages,
    /// Subgroup operation categories that are permitted.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation results, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    /// Type layouts, recomputed for each module in `validate_impl`.
    layouter: Layouter,
    // NOTE(review): the following scratch fields are only initialized and
    // cleared in this file; presumably submodules (e.g. `interface`) use them
    // during entry point validation. Confirm before relying on this.
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): within this file the field is only written (in `new`) and
    // cleared (in `reset`), which is presumably why `dead_code` is allowed.
    // Confirm whether it is still needed.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // Scratch state for expression validation; presumably filled while
    // validating function bodies (not visible in this file).
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far in the current module, used to report
    /// [`OverrideError::DuplicateID`].
    override_ids: FastHashSet<u16>,

    /// Treat overrides whose initializers are not fully-evaluated
    /// constant expressions as errors.
    overrides_resolved: bool,

    /// A checklist of expressions that must be visited by a specific kind of
    /// statement.
    ///
    /// For example:
    ///
    /// - [`CallResult`] expressions must be visited by a [`Call`] statement.
    /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
    ///
    /// Be sure not to remove any [`Expression`] handle from this set unless
    /// you've explicitly checked that it is the right kind of expression for
    /// the visiting [`Statement`].
    ///
    /// [`CallResult`]: crate::Expression::CallResult
    /// [`Call`]: crate::Statement::Call
    /// [`AtomicResult`]: crate::Expression::AtomicResult
    /// [`Atomic`]: crate::Statement::Atomic
    /// [`Expression`]: crate::Expression
    /// [`Statement`]: crate::Statement
    needs_visit: HandleSet<crate::Expression>,
}
339
/// Reasons a module-scope [`Constant`](crate::Constant) can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The constant's initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
}
350
/// Reasons a pipeline-overridable constant ([`Override`](crate::Override))
/// can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor an ID (not produced in this file).
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is not a const- or override-expression
    /// (not produced in this file).
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The override's type is not one of the accepted scalar types
    /// (`bool`, `i32`, `u32`, `f16`, `f32`, `f64`).
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Overrides are not permitted in this context (not produced in this file).
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer, but overrides were required to be
    /// resolved ([`Validator::validate_resolved_overrides`]).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// A constant expression associated with the override is invalid
    /// (not produced in this file).
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
376
/// Top-level error returned by [`Validator::validate`], identifying which
/// part of the module failed and why.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// Handle validation (`validate_module_handles`) failed.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// The type `handle` (named `name`, or empty if unnamed) is invalid.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// The global expression `handle` is invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// An array size expression is zero or negative (not produced in this file).
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// The constant `handle` (named `name`, or empty if unnamed) is invalid.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// The override `handle` (named `name`, or empty if unnamed) is invalid.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// The global variable `handle` (named `name`, or empty if unnamed) is invalid.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// The function `handle` (named `name`, or empty if unnamed) is invalid.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// The entry point `name` for `stage` is invalid. This includes the case
    /// where another entry point has the same stage and name
    /// ([`EntryPointError::Conflict`]).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is internally inconsistent (not produced in this file).
    #[error("Module is corrupted")]
    Corrupted,
}
430
431impl crate::TypeInner {
432    const fn is_sized(&self) -> bool {
433        match *self {
434            Self::Scalar { .. }
435            | Self::Vector { .. }
436            | Self::Matrix { .. }
437            | Self::Array {
438                size: crate::ArraySize::Constant(_),
439                ..
440            }
441            | Self::Atomic { .. }
442            | Self::Pointer { .. }
443            | Self::ValuePointer { .. }
444            | Self::Struct { .. } => true,
445            Self::Array { .. }
446            | Self::Image { .. }
447            | Self::Sampler { .. }
448            | Self::AccelerationStructure { .. }
449            | Self::RayQuery { .. }
450            | Self::BindingArray { .. } => false,
451        }
452    }
453
454    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
455    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
456        match *self {
457            Self::Scalar(crate::Scalar {
458                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
459                ..
460            }) => Some(crate::ImageDimension::D1),
461            Self::Vector {
462                size: crate::VectorSize::Bi,
463                scalar:
464                    crate::Scalar {
465                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
466                        ..
467                    },
468            } => Some(crate::ImageDimension::D2),
469            Self::Vector {
470                size: crate::VectorSize::Tri,
471                scalar:
472                    crate::Scalar {
473                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
474                        ..
475                    },
476            } => Some(crate::ImageDimension::D3),
477            _ => None,
478        }
479    }
480}
481
482impl Validator {
483    /// Create a validator for Naga [`Module`]s.
484    ///
485    /// The `flags` argument indicates which stages of validation the
486    /// returned `Validator` should perform. Skipping stages can make
487    /// validation somewhat faster, but the validator may not reject some
488    /// invalid modules. Regardless of `flags`, validation always returns
489    /// a usable [`ModuleInfo`] value on success.
490    ///
491    /// If `flags` contains everything in `ValidationFlags::default()`,
492    /// then the returned Naga [`Validator`] will reject any [`Module`]
493    /// that would use capabilities not included in `capabilities`.
494    ///
495    /// [`Module`]: crate::Module
496    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
497        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
498            use SubgroupOperationSet as S;
499            S::BASIC
500                | S::VOTE
501                | S::ARITHMETIC
502                | S::BALLOT
503                | S::SHUFFLE
504                | S::SHUFFLE_RELATIVE
505                | S::QUAD_FRAGMENT_COMPUTE
506        } else {
507            SubgroupOperationSet::empty()
508        };
509        let subgroup_stages = {
510            let mut stages = ShaderStages::empty();
511            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
512                stages |= ShaderStages::VERTEX;
513            }
514            if capabilities.contains(Capabilities::SUBGROUP) {
515                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
516            }
517            stages
518        };
519
520        Validator {
521            flags,
522            capabilities,
523            subgroup_stages,
524            subgroup_operations,
525            types: Vec::new(),
526            layouter: Layouter::default(),
527            location_mask: BitSet::new(),
528            blend_src_mask: BitSet::new(),
529            ep_resource_bindings: FastHashSet::default(),
530            switch_values: FastHashSet::default(),
531            valid_expression_list: Vec::new(),
532            valid_expression_set: HandleSet::new(),
533            override_ids: FastHashSet::default(),
534            overrides_resolved: false,
535            needs_visit: HandleSet::new(),
536        }
537    }
538
539    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
540        self.subgroup_stages = stages;
541        self
542    }
543
544    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
545        self.subgroup_operations = operations;
546        self
547    }
548
549    /// Reset the validator internals
550    pub fn reset(&mut self) {
551        self.types.clear();
552        self.layouter.clear();
553        self.location_mask.clear();
554        self.blend_src_mask.clear();
555        self.ep_resource_bindings.clear();
556        self.switch_values.clear();
557        self.valid_expression_list.clear();
558        self.valid_expression_set.clear();
559        self.override_ids.clear();
560    }
561
562    fn validate_constant(
563        &self,
564        handle: Handle<crate::Constant>,
565        gctx: crate::proc::GlobalCtx,
566        mod_info: &ModuleInfo,
567        global_expr_kind: &ExpressionKindTracker,
568    ) -> Result<(), ConstantError> {
569        let con = &gctx.constants[handle];
570
571        let type_info = &self.types[con.ty.index()];
572        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
573            return Err(ConstantError::NonConstructibleType);
574        }
575
576        if !global_expr_kind.is_const(con.init) {
577            return Err(ConstantError::InitializerExprType);
578        }
579
580        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
581            return Err(ConstantError::InvalidType);
582        }
583
584        Ok(())
585    }
586
587    fn validate_override(
588        &mut self,
589        handle: Handle<crate::Override>,
590        gctx: crate::proc::GlobalCtx,
591        mod_info: &ModuleInfo,
592    ) -> Result<(), OverrideError> {
593        let o = &gctx.overrides[handle];
594
595        if let Some(id) = o.id {
596            if !self.override_ids.insert(id) {
597                return Err(OverrideError::DuplicateID);
598            }
599        }
600
601        let type_info = &self.types[o.ty.index()];
602        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
603            return Err(OverrideError::NonConstructibleType);
604        }
605
606        match gctx.types[o.ty].inner {
607            crate::TypeInner::Scalar(
608                crate::Scalar::BOOL
609                | crate::Scalar::I32
610                | crate::Scalar::U32
611                | crate::Scalar::F16
612                | crate::Scalar::F32
613                | crate::Scalar::F64,
614            ) => {}
615            _ => return Err(OverrideError::TypeNotScalar),
616        }
617
618        if let Some(init) = o.init {
619            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
620                return Err(OverrideError::InvalidType);
621            }
622        } else if self.overrides_resolved {
623            return Err(OverrideError::UninitializedOverride);
624        }
625
626        Ok(())
627    }
628
629    /// Check the given module to be valid.
630    pub fn validate(
631        &mut self,
632        module: &crate::Module,
633    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
634        self.overrides_resolved = false;
635        self.validate_impl(module)
636    }
637
638    /// Check the given module to be valid, requiring overrides to be resolved.
639    ///
640    /// This is the same as [`validate`], except that any override
641    /// whose value is not a fully-evaluated constant expression is
642    /// treated as an error.
643    ///
644    /// [`validate`]: Validator::validate
645    pub fn validate_resolved_overrides(
646        &mut self,
647        module: &crate::Module,
648    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
649        self.overrides_resolved = true;
650        self.validate_impl(module)
651    }
652
653    fn validate_impl(
654        &mut self,
655        module: &crate::Module,
656    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
657        self.reset();
658        self.reset_types(module.types.len());
659
660        Self::validate_module_handles(module).map_err(|e| e.with_span())?;
661
662        self.layouter.update(module.to_ctx()).map_err(|e| {
663            let handle = e.ty;
664            ValidationError::from(e).with_span_handle(handle, &module.types)
665        })?;
666
667        // These should all get overwritten.
668        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
669            kind: crate::ScalarKind::Bool,
670            width: 0,
671        }));
672
673        let mut mod_info = ModuleInfo {
674            type_flags: Vec::with_capacity(module.types.len()),
675            functions: Vec::with_capacity(module.functions.len()),
676            entry_points: Vec::with_capacity(module.entry_points.len()),
677            const_expression_types: vec![placeholder; module.global_expressions.len()]
678                .into_boxed_slice(),
679        };
680
681        for (handle, ty) in module.types.iter() {
682            let ty_info = self
683                .validate_type(handle, module.to_ctx())
684                .map_err(|source| {
685                    ValidationError::Type {
686                        handle,
687                        name: ty.name.clone().unwrap_or_default(),
688                        source,
689                    }
690                    .with_span_handle(handle, &module.types)
691                })?;
692            mod_info.type_flags.push(ty_info.flags);
693            self.types[handle.index()] = ty_info;
694        }
695
696        {
697            let t = crate::Arena::new();
698            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
699            for (handle, _) in module.global_expressions.iter() {
700                mod_info
701                    .process_const_expression(handle, &resolve_context, module.to_ctx())
702                    .map_err(|source| {
703                        ValidationError::ConstExpression { handle, source }
704                            .with_span_handle(handle, &module.global_expressions)
705                    })?
706            }
707        }
708
709        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
710
711        if self.flags.contains(ValidationFlags::CONSTANTS) {
712            for (handle, _) in module.global_expressions.iter() {
713                self.validate_const_expression(
714                    handle,
715                    module.to_ctx(),
716                    &mod_info,
717                    &global_expr_kind,
718                )
719                .map_err(|source| {
720                    ValidationError::ConstExpression { handle, source }
721                        .with_span_handle(handle, &module.global_expressions)
722                })?
723            }
724
725            for (handle, constant) in module.constants.iter() {
726                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
727                    .map_err(|source| {
728                        ValidationError::Constant {
729                            handle,
730                            name: constant.name.clone().unwrap_or_default(),
731                            source,
732                        }
733                        .with_span_handle(handle, &module.constants)
734                    })?
735            }
736
737            for (handle, r#override) in module.overrides.iter() {
738                self.validate_override(handle, module.to_ctx(), &mod_info)
739                    .map_err(|source| {
740                        ValidationError::Override {
741                            handle,
742                            name: r#override.name.clone().unwrap_or_default(),
743                            source,
744                        }
745                        .with_span_handle(handle, &module.overrides)
746                    })?;
747            }
748        }
749
750        for (var_handle, var) in module.global_variables.iter() {
751            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
752                .map_err(|source| {
753                    ValidationError::GlobalVariable {
754                        handle: var_handle,
755                        name: var.name.clone().unwrap_or_default(),
756                        source,
757                    }
758                    .with_span_handle(var_handle, &module.global_variables)
759                })?;
760        }
761
762        for (handle, fun) in module.functions.iter() {
763            match self.validate_function(fun, module, &mod_info, false) {
764                Ok(info) => mod_info.functions.push(info),
765                Err(error) => {
766                    return Err(error.and_then(|source| {
767                        ValidationError::Function {
768                            handle,
769                            name: fun.name.clone().unwrap_or_default(),
770                            source,
771                        }
772                        .with_span_handle(handle, &module.functions)
773                    }))
774                }
775            }
776        }
777
778        let mut ep_map = FastHashSet::default();
779        for ep in module.entry_points.iter() {
780            if !ep_map.insert((ep.stage, &ep.name)) {
781                return Err(ValidationError::EntryPoint {
782                    stage: ep.stage,
783                    name: ep.name.clone(),
784                    source: EntryPointError::Conflict,
785                }
786                .with_span()); // TODO: keep some EP span information?
787            }
788
789            match self.validate_entry_point(ep, module, &mod_info) {
790                Ok(info) => mod_info.entry_points.push(info),
791                Err(error) => {
792                    return Err(error.and_then(|source| {
793                        ValidationError::EntryPoint {
794                            stage: ep.stage,
795                            name: ep.name.clone(),
796                            source,
797                        }
798                        .with_span()
799                    }));
800                }
801            }
802        }
803
804        Ok(mod_info)
805    }
806}
807
808fn validate_atomic_compare_exchange_struct(
809    types: &crate::UniqueArena<crate::Type>,
810    members: &[crate::StructMember],
811    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
812) -> bool {
813    members.len() == 2
814        && members[0].name.as_deref() == Some("old_value")
815        && scalar_predicate(&types[members[0].ty].inner)
816        && members[1].name.as_deref() == Some("exchanged")
817        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
818}