naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19    arena::{Handle, HandleSet},
20    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21    FastHashSet,
22};
23
24//TODO: analyze the model at the same time as we validate it,
25// merge the corresponding matches over expressions and statements.
26
27use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
/// Maximum size of a type, in bytes.
///
/// `0x4000_0000` is 2^30 bytes, i.e. 1 GiB.
pub const MAX_TYPE_SIZE: u32 = 0x4000_0000; // 1 GiB
40
41bitflags::bitflags! {
42    /// Validation flags.
43    ///
44    /// If you are working with trusted shaders, then you may be able
45    /// to save some time by skipping validation.
46    ///
47    /// If you do not perform full validation, invalid shaders may
48    /// cause Naga to panic. If you do perform full validation and
49    /// [`Validator::validate`] returns `Ok`, then Naga promises that
50    /// code generation will either succeed or return an error; it
51    /// should never panic.
52    ///
53    /// The default value for `ValidationFlags` is
54    /// `ValidationFlags::all()`.
55    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58    pub struct ValidationFlags: u8 {
59        /// Expressions.
60        const EXPRESSIONS = 0x1;
61        /// Statements and blocks of them.
62        const BLOCKS = 0x2;
63        /// Uniformity of control flow for operations that require it.
64        const CONTROL_FLOW_UNIFORMITY = 0x4;
65        /// Host-shareable structure layouts.
66        const STRUCT_LAYOUTS = 0x8;
67        /// Constants.
68        const CONSTANTS = 0x10;
69        /// Group, binding, and location attributes.
70        const BINDINGS = 0x20;
71    }
72}
73
74impl Default for ValidationFlags {
75    fn default() -> Self {
76        Self::all()
77    }
78}
79
80bitflags::bitflags! {
81    /// Allowed IR capabilities.
82    #[must_use]
83    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
84    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
85    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
86    pub struct Capabilities: u32 {
87        /// Support for [`AddressSpace::PushConstant`][1].
88        ///
89        /// [1]: crate::AddressSpace::PushConstant
90        const PUSH_CONSTANT = 1 << 0;
91        /// Float values with width = 8.
92        const FLOAT64 = 1 << 1;
93        /// Support for [`BuiltIn::PrimitiveIndex`][1].
94        ///
95        /// [1]: crate::BuiltIn::PrimitiveIndex
96        const PRIMITIVE_INDEX = 1 << 2;
97        /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
98        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
99        /// Support for non-uniform indexing of storage texture arrays.
100        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
101        /// Support for non-uniform indexing of uniform buffer arrays.
102        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
103        /// Support for non-uniform indexing of samplers.
104        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
105        /// Support for [`BuiltIn::ClipDistance`].
106        ///
107        /// [`BuiltIn::ClipDistance`]: crate::BuiltIn::ClipDistance
108        const CLIP_DISTANCE = 1 << 7;
109        /// Support for [`BuiltIn::CullDistance`].
110        ///
111        /// [`BuiltIn::CullDistance`]: crate::BuiltIn::CullDistance
112        const CULL_DISTANCE = 1 << 8;
113        /// Support for 16-bit normalized storage texture formats.
114        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
115        /// Support for [`BuiltIn::ViewIndex`].
116        ///
117        /// [`BuiltIn::ViewIndex`]: crate::BuiltIn::ViewIndex
118        const MULTIVIEW = 1 << 10;
119        /// Support for `early_depth_test`.
120        const EARLY_DEPTH_TEST = 1 << 11;
121        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
122        ///
123        /// [`BuiltIn::SampleIndex`]: crate::BuiltIn::SampleIndex
124        /// [`Sampling::Sample`]: crate::Sampling::Sample
125        const MULTISAMPLED_SHADING = 1 << 12;
126        /// Support for ray queries and acceleration structures.
127        const RAY_QUERY = 1 << 13;
128        /// Support for generating two sources for blending from fragment shaders.
129        const DUAL_SOURCE_BLENDING = 1 << 14;
130        /// Support for arrayed cube textures.
131        const CUBE_ARRAY_TEXTURES = 1 << 15;
132        /// Support for 64-bit signed and unsigned integers.
133        const SHADER_INT64 = 1 << 16;
134        /// Support for subgroup operations (except barriers) in fragment and compute shaders.
135        ///
136        /// Subgroup operations in the vertex stage require
137        /// [`Capabilities::SUBGROUP_VERTEX_STAGE`] in addition to `Capabilities::SUBGROUP`.
138        /// (But note that `create_validator` automatically sets
139        /// `Capabilities::SUBGROUP` whenever `Features::SUBGROUP_VERTEX` is
140        /// available.)
141        ///
142        /// Subgroup barriers require [`Capabilities::SUBGROUP_BARRIER`] in addition to
143        /// `Capabilities::SUBGROUP`.
144        const SUBGROUP = 1 << 17;
145        /// Support for subgroup barriers in compute shaders.
146        ///
147        /// Requires [`Capabilities::SUBGROUP`]. Without it, enables nothing.
148        const SUBGROUP_BARRIER = 1 << 18;
149        /// Support for subgroup operations (not including barriers) in the vertex stage.
150        ///
151        /// Without [`Capabilities::SUBGROUP`], enables nothing. (But note that
152        /// `create_validator` automatically sets `Capabilities::SUBGROUP`
153        /// whenever `Features::SUBGROUP_VERTEX` is available.)
154        const SUBGROUP_VERTEX_STAGE = 1 << 19;
155        /// Support for [`AtomicFunction::Min`] and [`AtomicFunction::Max`] on
156        /// 64-bit integers in the [`Storage`] address space, when the return
157        /// value is not used.
158        ///
159        /// This is the only 64-bit atomic functionality available on Metal 3.1.
160        ///
161        /// [`AtomicFunction::Min`]: crate::AtomicFunction::Min
162        /// [`AtomicFunction::Max`]: crate::AtomicFunction::Max
163        /// [`Storage`]: crate::AddressSpace::Storage
164        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
165        /// Support for all atomic operations on 64-bit integers.
166        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
167        /// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`],
168        /// and [`AtomicFunction::Exchange { compare: None }`] on 32-bit floating-point numbers
169        /// in the [`Storage`] address space.
170        ///
171        /// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
172        /// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
173        /// [`AtomicFunction::Exchange { compare: None }`]: crate::AtomicFunction::Exchange
174        /// [`Storage`]: crate::AddressSpace::Storage
175        const SHADER_FLOAT32_ATOMIC = 1 << 22;
176        /// Support for atomic operations on images.
177        const TEXTURE_ATOMIC = 1 << 23;
178        /// Support for atomic operations on 64-bit images.
179        const TEXTURE_INT64_ATOMIC = 1 << 24;
180        /// Support for ray queries returning vertex position
181        const RAY_HIT_VERTEX_POSITION = 1 << 25;
182        /// Support for 16-bit floating-point types.
183        const SHADER_FLOAT16 = 1 << 26;
184        /// Support for [`ImageClass::External`]
185        const TEXTURE_EXTERNAL = 1 << 27;
186        /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
187        /// `f16`-precision values in `f32`s.
188        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
189        /// Support for fragment shader barycentric coordinates.
190        const SHADER_BARYCENTRICS = 1 << 29;
191        /// Support for task shaders, mesh shaders, and per-primitive fragment inputs
192        const MESH_SHADER = 1 << 30;
193        /// Support for mesh shaders which output points.
194        const MESH_SHADER_POINT_TOPOLOGY = 1 << 30;
195    }
196}
197
198impl Capabilities {
199    /// Returns the extension corresponding to this capability, if there is one.
200    ///
201    /// This is used by integration tests.
202    #[cfg(feature = "wgsl-in")]
203    #[doc(hidden)]
204    pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
205        use crate::front::wgsl::ImplementedEnableExtension as Ext;
206        match *self {
207            Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
208            // NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
209            Self::SHADER_FLOAT16 => Some(Ext::F16),
210            Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
211            Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
212            Self::RAY_HIT_VERTEX_POSITION => Some(Ext::WgpuRayQueryVertexReturn),
213            _ => None,
214        }
215    }
216}
217
218impl Default for Capabilities {
219    fn default() -> Self {
220        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
221    }
222}
223
bitflags::bitflags! {
    /// Supported subgroup operations
    ///
    /// Each flag names one family of subgroup operations. The
    /// `required_operations` methods on `SubgroupOperation` and `GatherMode`
    /// map an individual operation to the flag it needs.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Barriers
        // Possibly elections, when that is supported.
        // https://github.com/gfx-rs/wgpu/issues/6042#issuecomment-3272603431
        // Contrary to what the name "basic" suggests, HLSL/DX12 support the
        // other subgroup operations, but do not support subgroup barriers.
        const BASIC = 1 << 0;
        /// Any, All
        const VOTE = 1 << 1;
        /// reductions, scans
        const ARITHMETIC = 1 << 2;
        /// ballot, broadcast
        const BALLOT = 1 << 3;
        /// shuffle, shuffle xor
        const SHUFFLE = 1 << 4;
        /// shuffle up, down
        const SHUFFLE_RELATIVE = 1 << 5;
        // We don't support these operations yet
        // (bit 6 is left unused for them)
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        /// Quad supported
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // NOTE(review): `1 << 8` does not fit in the `u8` backing type, so
        // enabling the flag below would also require widening this set.
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
255
256impl super::SubgroupOperation {
257    const fn required_operations(&self) -> SubgroupOperationSet {
258        use SubgroupOperationSet as S;
259        match *self {
260            Self::All | Self::Any => S::VOTE,
261            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
262                S::ARITHMETIC
263            }
264        }
265    }
266}
267
268impl super::GatherMode {
269    const fn required_operations(&self) -> SubgroupOperationSet {
270        use SubgroupOperationSet as S;
271        match *self {
272            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
273            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
274            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
275            Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
276        }
277    }
278}
279
280bitflags::bitflags! {
281    /// Validation flags.
282    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
283    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
284    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
285    pub struct ShaderStages: u8 {
286        const VERTEX = 0x1;
287        const FRAGMENT = 0x2;
288        const COMPUTE = 0x4;
289        const MESH = 0x8;
290        const TASK = 0x10;
291    }
292}
293
/// Information gathered about a module while validating it.
///
/// This is the value returned by [`Validator::validate`] on success. It can
/// be indexed by type, function, and global expression handles.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags of each type, indexed by the type handle's index.
    type_flags: Vec<TypeFlags>,
    /// Info for each function, indexed by the function handle's index.
    functions: Vec<FunctionInfo>,
    /// Info for each entry point, in the order of the module's `entry_points`.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each global expression, indexed by the expression
    /// handle's index.
    const_expression_types: Box<[TypeResolution]>,
}
303
304impl ops::Index<Handle<crate::Type>> for ModuleInfo {
305    type Output = TypeFlags;
306    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
307        &self.type_flags[handle.index()]
308    }
309}
310
311impl ops::Index<Handle<crate::Function>> for ModuleInfo {
312    type Output = FunctionInfo;
313    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
314        &self.functions[handle.index()]
315    }
316}
317
318impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
319    type Output = TypeResolution;
320    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
321        &self.const_expression_types[handle.index()]
322    }
323}
324
/// Shader module validator.
///
/// Created with [`Validator::new`]. The configuration fields are fixed at
/// construction (or adjusted through the `subgroup_*` setters); most of the
/// remaining fields are scratch state that [`Validator::reset`] clears.
#[derive(Debug)]
pub struct Validator {
    /// Which validation stages to perform, as passed to [`Validator::new`].
    flags: ValidationFlags,
    /// Capabilities modules are allowed to use, as passed to [`Validator::new`].
    capabilities: Capabilities,
    /// Shader stages for subgroup operations. Derived from `capabilities` in
    /// [`Validator::new`]; replaceable via [`Validator::subgroup_stages`].
    subgroup_stages: ShaderStages,
    /// Subgroup operation families. Derived from `capabilities` in
    /// [`Validator::new`]; replaceable via [`Validator::subgroup_operations`].
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation info, indexed by the type handle's index.
    types: Vec<r#type::TypeInfo>,
    /// Type layouts; updated from the module at the start of validation.
    layouter: Layouter,
    // NOTE(review): the next three fields are only cleared in this part of
    // the file; presumably they are scratch state for interface / entry
    // point validation - confirm against `interface.rs`.
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // NOTE(review): scratch state cleared by `reset`; presumably filled in
    // by expression validation - confirm against the sibling modules.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far, used to reject duplicate IDs.
    override_ids: FastHashSet<u16>,

    /// Treat overrides whose initializers are not fully-evaluated
    /// constant expressions as errors.
    overrides_resolved: bool,

    /// A checklist of expressions that must be visited by a specific kind of
    /// statement.
    ///
    /// For example:
    ///
    /// - [`CallResult`] expressions must be visited by a [`Call`] statement.
    /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
    ///
    /// Be sure not to remove any [`Expression`] handle from this set unless
    /// you've explicitly checked that it is the right kind of expression for
    /// the visiting [`Statement`].
    ///
    /// [`CallResult`]: crate::Expression::CallResult
    /// [`Call`]: crate::Statement::Call
    /// [`AtomicResult`]: crate::Expression::AtomicResult
    /// [`Atomic`]: crate::Statement::Atomic
    /// [`Expression`]: crate::Expression
    /// [`Statement`]: crate::Statement
    needs_visit: HandleSet<crate::Expression>,
}
366
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The constant's initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
377
/// Reasons a pipeline-overridable constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor an ID.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the module already uses the same ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is neither a const-expression nor an override-expression.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The override's type is not one of the accepted scalar types.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Override declarations are not permitted.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer although overrides were required to
    /// be resolved.
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// A constant expression associated with the override is invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
403
/// Top-level error returned by module validation.
///
/// Most variants identify the offending module item (by handle and name)
/// and carry the more specific error as `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// The module contains an invalid handle.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global (constant) expression failed validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// An array size expression is not strictly positive.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or clashes with another entry
    /// point of the same stage and name.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is corrupted.
    #[error("Module is corrupted")]
    Corrupted,
}
457
458impl crate::TypeInner {
459    const fn is_sized(&self) -> bool {
460        match *self {
461            Self::Scalar { .. }
462            | Self::Vector { .. }
463            | Self::Matrix { .. }
464            | Self::Array {
465                size: crate::ArraySize::Constant(_),
466                ..
467            }
468            | Self::Atomic { .. }
469            | Self::Pointer { .. }
470            | Self::ValuePointer { .. }
471            | Self::Struct { .. } => true,
472            Self::Array { .. }
473            | Self::Image { .. }
474            | Self::Sampler { .. }
475            | Self::AccelerationStructure { .. }
476            | Self::RayQuery { .. }
477            | Self::BindingArray { .. } => false,
478        }
479    }
480
481    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
482    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
483        match *self {
484            Self::Scalar(crate::Scalar {
485                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
486                ..
487            }) => Some(crate::ImageDimension::D1),
488            Self::Vector {
489                size: crate::VectorSize::Bi,
490                scalar:
491                    crate::Scalar {
492                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
493                        ..
494                    },
495            } => Some(crate::ImageDimension::D2),
496            Self::Vector {
497                size: crate::VectorSize::Tri,
498                scalar:
499                    crate::Scalar {
500                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
501                        ..
502                    },
503            } => Some(crate::ImageDimension::D3),
504            _ => None,
505        }
506    }
507}
508
509impl Validator {
510    /// Create a validator for Naga [`Module`]s.
511    ///
512    /// The `flags` argument indicates which stages of validation the
513    /// returned `Validator` should perform. Skipping stages can make
514    /// validation somewhat faster, but the validator may not reject some
515    /// invalid modules. Regardless of `flags`, validation always returns
516    /// a usable [`ModuleInfo`] value on success.
517    ///
518    /// If `flags` contains everything in `ValidationFlags::default()`,
519    /// then the returned Naga [`Validator`] will reject any [`Module`]
520    /// that would use capabilities not included in `capabilities`.
521    ///
522    /// [`Module`]: crate::Module
523    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
524        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
525            use SubgroupOperationSet as S;
526            S::BASIC
527                | S::VOTE
528                | S::ARITHMETIC
529                | S::BALLOT
530                | S::SHUFFLE
531                | S::SHUFFLE_RELATIVE
532                | S::QUAD_FRAGMENT_COMPUTE
533        } else {
534            SubgroupOperationSet::empty()
535        };
536        let subgroup_stages = {
537            let mut stages = ShaderStages::empty();
538            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
539                stages |= ShaderStages::VERTEX;
540            }
541            if capabilities.contains(Capabilities::SUBGROUP) {
542                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
543            }
544            stages
545        };
546
547        Validator {
548            flags,
549            capabilities,
550            subgroup_stages,
551            subgroup_operations,
552            types: Vec::new(),
553            layouter: Layouter::default(),
554            location_mask: BitSet::new(),
555            blend_src_mask: BitSet::new(),
556            ep_resource_bindings: FastHashSet::default(),
557            switch_values: FastHashSet::default(),
558            valid_expression_list: Vec::new(),
559            valid_expression_set: HandleSet::new(),
560            override_ids: FastHashSet::default(),
561            overrides_resolved: false,
562            needs_visit: HandleSet::new(),
563        }
564    }
565
    /// Replace the set of shader stages used for subgroup operations,
    /// overriding the value [`Validator::new`] derived from the capabilities.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
        self.subgroup_stages = stages;
        self
    }
570
    /// Replace the set of subgroup operation families, overriding the value
    /// [`Validator::new`] derived from the capabilities.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
        self.subgroup_operations = operations;
        self
    }
575
576    /// Reset the validator internals
577    pub fn reset(&mut self) {
578        self.types.clear();
579        self.layouter.clear();
580        self.location_mask.clear();
581        self.blend_src_mask.clear();
582        self.ep_resource_bindings.clear();
583        self.switch_values.clear();
584        self.valid_expression_list.clear();
585        self.valid_expression_set.clear();
586        self.override_ids.clear();
587    }
588
589    fn validate_constant(
590        &self,
591        handle: Handle<crate::Constant>,
592        gctx: crate::proc::GlobalCtx,
593        mod_info: &ModuleInfo,
594        global_expr_kind: &ExpressionKindTracker,
595    ) -> Result<(), ConstantError> {
596        let con = &gctx.constants[handle];
597
598        let type_info = &self.types[con.ty.index()];
599        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
600            return Err(ConstantError::NonConstructibleType);
601        }
602
603        if !global_expr_kind.is_const(con.init) {
604            return Err(ConstantError::InitializerExprType);
605        }
606
607        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
608            return Err(ConstantError::InvalidType);
609        }
610
611        Ok(())
612    }
613
614    fn validate_override(
615        &mut self,
616        handle: Handle<crate::Override>,
617        gctx: crate::proc::GlobalCtx,
618        mod_info: &ModuleInfo,
619    ) -> Result<(), OverrideError> {
620        let o = &gctx.overrides[handle];
621
622        if let Some(id) = o.id {
623            if !self.override_ids.insert(id) {
624                return Err(OverrideError::DuplicateID);
625            }
626        }
627
628        let type_info = &self.types[o.ty.index()];
629        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
630            return Err(OverrideError::NonConstructibleType);
631        }
632
633        match gctx.types[o.ty].inner {
634            crate::TypeInner::Scalar(
635                crate::Scalar::BOOL
636                | crate::Scalar::I32
637                | crate::Scalar::U32
638                | crate::Scalar::F16
639                | crate::Scalar::F32
640                | crate::Scalar::F64,
641            ) => {}
642            _ => return Err(OverrideError::TypeNotScalar),
643        }
644
645        if let Some(init) = o.init {
646            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
647                return Err(OverrideError::InvalidType);
648            }
649        } else if self.overrides_resolved {
650            return Err(OverrideError::UninitializedOverride);
651        }
652
653        Ok(())
654    }
655
    /// Check the given module to be valid.
    ///
    /// Overrides without an initializer are accepted; use
    /// [`validate_resolved_overrides`] to reject them.
    ///
    /// On success, returns the [`ModuleInfo`] gathered during validation.
    /// On failure, returns the first [`ValidationError`] encountered.
    ///
    /// [`validate_resolved_overrides`]: Validator::validate_resolved_overrides
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        // Uninitialized overrides are tolerated in this mode.
        self.overrides_resolved = false;
        self.validate_impl(module)
    }
664
    /// Check the given module to be valid, requiring overrides to be resolved.
    ///
    /// This is the same as [`validate`], except that any override
    /// whose value is not a fully-evaluated constant expression is
    /// treated as an error.
    ///
    /// [`validate`]: Validator::validate
    pub fn validate_resolved_overrides(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        // Makes `validate_override` reject overrides without an initializer.
        self.overrides_resolved = true;
        self.validate_impl(module)
    }
679
    /// Shared implementation of [`Validator::validate`] and
    /// [`Validator::validate_resolved_overrides`].
    ///
    /// Validates `module` phase by phase (handles, layouts, types, global
    /// expressions, constants and overrides, global variables, functions,
    /// entry points) and stops at the first error, which is returned with
    /// span information attached where a handle is available.
    ///
    /// On success, returns the [`ModuleInfo`] built up along the way.
    fn validate_impl(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        // Start from a clean slate: drop state left over from a previous run.
        self.reset();
        self.reset_types(module.types.len());

        // Handle validity is checked before anything else is looked at.
        Self::validate_module_handles(module).map_err(|e| e.with_span())?;

        self.layouter.update(module.to_ctx()).map_err(|e| {
            let handle = e.ty;
            ValidationError::from(e).with_span_handle(handle, &module.types)
        })?;

        // These should all get overwritten.
        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
            kind: crate::ScalarKind::Bool,
            width: 0,
        }));

        let mut mod_info = ModuleInfo {
            type_flags: Vec::with_capacity(module.types.len()),
            functions: Vec::with_capacity(module.functions.len()),
            entry_points: Vec::with_capacity(module.entry_points.len()),
            const_expression_types: vec![placeholder; module.global_expressions.len()]
                .into_boxed_slice(),
        };

        // Types: record the flags in `mod_info` and the full info in `self.types`.
        for (handle, ty) in module.types.iter() {
            let ty_info = self
                .validate_type(handle, module.to_ctx())
                .map_err(|source| {
                    ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(handle, &module.types)
                })?;
            mod_info.type_flags.push(ty_info.flags);
            self.types[handle.index()] = ty_info;
        }

        // Global expressions are processed with an empty arena of function
        // expressions and no local types. This runs regardless of `self.flags`.
        {
            let t = crate::Arena::new();
            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
            for (handle, _) in module.global_expressions.iter() {
                mod_info
                    .process_const_expression(handle, &resolve_context, module.to_ctx())
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.global_expressions)
                    })?
            }
        }

        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);

        // Global expressions, constants and overrides are only validated
        // when `ValidationFlags::CONSTANTS` is set.
        if self.flags.contains(ValidationFlags::CONSTANTS) {
            for (handle, _) in module.global_expressions.iter() {
                self.validate_const_expression(
                    handle,
                    module.to_ctx(),
                    &mod_info,
                    &global_expr_kind,
                )
                .map_err(|source| {
                    ValidationError::ConstExpression { handle, source }
                        .with_span_handle(handle, &module.global_expressions)
                })?
            }

            for (handle, constant) in module.constants.iter() {
                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
                    .map_err(|source| {
                        ValidationError::Constant {
                            handle,
                            name: constant.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.constants)
                    })?
            }

            for (handle, r#override) in module.overrides.iter() {
                self.validate_override(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::Override {
                            handle,
                            name: r#override.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.overrides)
                    })?;
            }
        }

        // Global variables.
        for (var_handle, var) in module.global_variables.iter() {
            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
                .map_err(|source| {
                    ValidationError::GlobalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var_handle, &module.global_variables)
                })?;
        }

        // Functions: each function's info is pushed in iteration order, so
        // `mod_info` already holds the info of all earlier functions when a
        // later one is validated.
        for (handle, fun) in module.functions.iter() {
            match self.validate_function(fun, module, &mod_info, false) {
                Ok(info) => mod_info.functions.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::Function {
                            handle,
                            name: fun.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.functions)
                    }))
                }
            }
        }

        // Entry points: a (stage, name) pair may appear only once.
        let mut ep_map = FastHashSet::default();
        for ep in module.entry_points.iter() {
            if !ep_map.insert((ep.stage, &ep.name)) {
                return Err(ValidationError::EntryPoint {
                    stage: ep.stage,
                    name: ep.name.clone(),
                    source: EntryPointError::Conflict,
                }
                .with_span()); // TODO: keep some EP span information?
            }

            match self.validate_entry_point(ep, module, &mod_info) {
                Ok(info) => mod_info.entry_points.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::EntryPoint {
                            stage: ep.stage,
                            name: ep.name.clone(),
                            source,
                        }
                        .with_span()
                    }));
                }
            }
        }

        Ok(mod_info)
    }
833}
834
835fn validate_atomic_compare_exchange_struct(
836    types: &crate::UniqueArena<crate::Type>,
837    members: &[crate::StructMember],
838    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
839) -> bool {
840    members.len() == 2
841        && members[0].name.as_deref() == Some("old_value")
842        && scalar_predicate(&types[members[0].ty].inner)
843        && members[1].name.as_deref() == Some("exchanged")
844        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
845}