naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19    arena::{Handle, HandleSet},
20    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21    FastHashSet,
22};
23
24//TODO: analyze the model at the same time as we validate it,
25// merge the corresponding matches over expressions and statements.
26
27use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::builtin::ZeroValueError;
31pub use expression::{check_literal_value, LiteralError};
32pub use expression::{ConstExpressionError, ExpressionError};
33pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
34pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
35pub use r#type::{Disalignment, ImmediateError, TypeError, TypeFlags, WidthError};
36
37use self::handles::InvalidHandleError;
38
/// Maximum size of a type, in bytes.
///
/// This is 2^30 bytes, i.e. 1 GiB.
pub const MAX_TYPE_SIZE: u32 = 1 << 30;
41
bitflags::bitflags! {
    /// Validation flags.
    ///
    /// Each flag enables one stage of the validation performed by
    /// [`Validator`]; leaving a flag out skips that stage.
    ///
    /// If you are working with trusted shaders, then you may be able
    /// to save some time by skipping validation.
    ///
    /// If you do not perform full validation, invalid shaders may
    /// cause Naga to panic. If you do perform full validation and
    /// [`Validator::validate`] returns `Ok`, then Naga promises that
    /// code generation will either succeed or return an error; it
    /// should never panic.
    ///
    /// The default value for `ValidationFlags` is
    /// `ValidationFlags::all()`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ValidationFlags: u8 {
        /// Expressions.
        const EXPRESSIONS = 0x1;
        /// Statements and blocks of them.
        const BLOCKS = 0x2;
        /// Uniformity of control flow for operations that require it.
        const CONTROL_FLOW_UNIFORMITY = 0x4;
        /// Host-shareable structure layouts.
        const STRUCT_LAYOUTS = 0x8;
        /// Constants: global const-expressions, `const` declarations, and
        /// overrides.
        const CONSTANTS = 0x10;
        /// Group, binding, and location attributes.
        const BINDINGS = 0x20;
    }
}
74
75impl Default for ValidationFlags {
76    fn default() -> Self {
77        Self::all()
78    }
79}
80
bitflags::bitflags! {
    /// Allowed IR capabilities.
    ///
    /// Each flag permits some feature of the IR. When full validation is
    /// enabled, a [`Validator`] rejects any module that uses a capability it
    /// was not created with (see [`Validator::new`]).
    ///
    /// The default value is `MULTISAMPLED_SHADING | CUBE_ARRAY_TEXTURES`.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u64 {
        /// Support for [`AddressSpace::Immediate`][1].
        ///
        /// [1]: crate::AddressSpace::Immediate
        const IMMEDIATES = 1 << 0;
        /// Float values with width = 8.
        const FLOAT64 = 1 << 1;
        /// Support for [`BuiltIn::PrimitiveIndex`][1].
        ///
        /// [1]: crate::BuiltIn::PrimitiveIndex
        const PRIMITIVE_INDEX = 1 << 2;
        /// Support for binding arrays of sampled textures and samplers.
        const TEXTURE_AND_SAMPLER_BINDING_ARRAY = 1 << 3;
        /// Support for binding arrays of uniform buffers.
        const BUFFER_BINDING_ARRAY = 1 << 4;
        /// Support for binding arrays of storage textures.
        const STORAGE_TEXTURE_BINDING_ARRAY = 1 << 5;
        /// Support for binding arrays of storage buffers.
        const STORAGE_BUFFER_BINDING_ARRAY = 1 << 6;
        /// Support for [`BuiltIn::ClipDistance`].
        ///
        /// [`BuiltIn::ClipDistance`]: crate::BuiltIn::ClipDistance
        const CLIP_DISTANCE = 1 << 7;
        /// Support for [`BuiltIn::CullDistance`].
        ///
        /// [`BuiltIn::CullDistance`]: crate::BuiltIn::CullDistance
        const CULL_DISTANCE = 1 << 8;
        /// Support for 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Support for [`BuiltIn::ViewIndex`].
        ///
        /// [`BuiltIn::ViewIndex`]: crate::BuiltIn::ViewIndex
        const MULTIVIEW = 1 << 10;
        /// Support for `early_depth_test`.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
        ///
        /// [`BuiltIn::SampleIndex`]: crate::BuiltIn::SampleIndex
        /// [`Sampling::Sample`]: crate::Sampling::Sample
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Support for ray queries and acceleration structures.
        const RAY_QUERY = 1 << 13;
        /// Support for generating two sources for blending from fragment shaders.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Support for arrayed cube textures.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// Support for 64-bit signed and unsigned integers.
        const SHADER_INT64 = 1 << 16;
        /// Support for subgroup operations (except barriers) in fragment and compute shaders.
        ///
        /// Subgroup operations in the vertex stage require
        /// [`Capabilities::SUBGROUP_VERTEX_STAGE`] in addition to `Capabilities::SUBGROUP`.
        /// (But note that `create_validator` automatically sets
        /// `Capabilities::SUBGROUP` whenever `Features::SUBGROUP_VERTEX` is
        /// available.)
        ///
        /// Subgroup barriers require [`Capabilities::SUBGROUP_BARRIER`] in addition to
        /// `Capabilities::SUBGROUP`.
        const SUBGROUP = 1 << 17;
        /// Support for subgroup barriers in compute shaders.
        ///
        /// Requires [`Capabilities::SUBGROUP`]. Without it, enables nothing.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Support for subgroup operations (not including barriers) in the vertex stage.
        ///
        /// Without [`Capabilities::SUBGROUP`], enables nothing. (But note that
        /// `create_validator` automatically sets `Capabilities::SUBGROUP`
        /// whenever `Features::SUBGROUP_VERTEX` is available.)
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// Support for [`AtomicFunction::Min`] and [`AtomicFunction::Max`] on
        /// 64-bit integers in the [`Storage`] address space, when the return
        /// value is not used.
        ///
        /// This is the only 64-bit atomic functionality available on Metal 3.1.
        ///
        /// [`AtomicFunction::Min`]: crate::AtomicFunction::Min
        /// [`AtomicFunction::Max`]: crate::AtomicFunction::Max
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// Support for all atomic operations on 64-bit integers.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`],
        /// and [`AtomicFunction::Exchange { compare: None }`] on 32-bit floating-point numbers
        /// in the [`Storage`] address space.
        ///
        /// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
        /// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
        /// [`AtomicFunction::Exchange { compare: None }`]: crate::AtomicFunction::Exchange
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Support for atomic operations on images.
        const TEXTURE_ATOMIC = 1 << 23;
        /// Support for atomic operations on 64-bit images.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Support for ray queries returning vertex position.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// Support for 16-bit floating-point types.
        const SHADER_FLOAT16 = 1 << 26;
        /// Support for [`ImageClass::External`].
        ///
        /// [`ImageClass::External`]: crate::ImageClass::External
        const TEXTURE_EXTERNAL = 1 << 27;
        /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
        /// `f16`-precision values in `f32`s.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
        /// Support for fragment shader barycentric coordinates.
        const SHADER_BARYCENTRICS = 1 << 29;
        /// Support for task shaders, mesh shaders, and per-primitive fragment inputs.
        const MESH_SHADER = 1 << 30;
        /// Support for mesh shaders which output points.
        const MESH_SHADER_POINT_TOPOLOGY = 1 << 31;
        /// Support for non-uniform indexing of binding arrays of sampled textures and samplers.
        const TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 32;
        /// Support for non-uniform indexing of binding arrays of uniform buffers.
        const BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 33;
        /// Support for non-uniform indexing of binding arrays of storage textures.
        const STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 34;
        /// Support for non-uniform indexing of binding arrays of storage buffers.
        const STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 35;
        /// Support for cooperative matrix types and operations.
        const COOPERATIVE_MATRIX = 1 << 36;
        /// Support for per-vertex fragment input.
        const PER_VERTEX = 1 << 37;
    }
}
210
211impl Capabilities {
212    /// Returns the extension corresponding to this capability, if there is one.
213    ///
214    /// This is used by integration tests.
215    #[cfg(feature = "wgsl-in")]
216    #[doc(hidden)]
217    pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
218        use crate::front::wgsl::ImplementedEnableExtension as Ext;
219        match *self {
220            Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
221            // NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
222            Self::SHADER_FLOAT16 => Some(Ext::F16),
223            Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
224            Self::MESH_SHADER => Some(Ext::WgpuMeshShader),
225            Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
226            Self::RAY_HIT_VERTEX_POSITION => Some(Ext::WgpuRayQueryVertexReturn),
227            Self::COOPERATIVE_MATRIX => Some(Ext::WgpuCooperativeMatrix),
228            _ => None,
229        }
230    }
231}
232
233impl Default for Capabilities {
234    fn default() -> Self {
235        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
236    }
237}
238
bitflags::bitflags! {
    /// Supported subgroup operations.
    ///
    /// Each flag names a class of subgroup operations. A [`Validator`] holds
    /// the set of classes it permits (see [`Validator::subgroup_operations`]).
    /// The derived default value is the empty set.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Barriers
        // Possibly elections, when that is supported.
        // https://github.com/gfx-rs/wgpu/issues/6042#issuecomment-3272603431
        // Contrary to what the name "basic" suggests, HLSL/DX12 support the
        // other subgroup operations, but do not support subgroup barriers.
        const BASIC = 1 << 0;
        /// Any, All
        const VOTE = 1 << 1;
        /// reductions, scans (add, mul, min, max, and, or, xor)
        const ARITHMETIC = 1 << 2;
        /// ballot, broadcast (including broadcast-first)
        const BALLOT = 1 << 3;
        /// shuffle, shuffle xor
        const SHUFFLE = 1 << 4;
        /// shuffle up, down
        const SHUFFLE_RELATIVE = 1 << 5;
        // We don't support these operations yet
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        /// Quad operations (quad broadcast, quad swap)
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
270
271impl super::SubgroupOperation {
272    const fn required_operations(&self) -> SubgroupOperationSet {
273        use SubgroupOperationSet as S;
274        match *self {
275            Self::All | Self::Any => S::VOTE,
276            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
277                S::ARITHMETIC
278            }
279        }
280    }
281}
282
283impl super::GatherMode {
284    const fn required_operations(&self) -> SubgroupOperationSet {
285        use SubgroupOperationSet as S;
286        match *self {
287            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
288            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
289            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
290            Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
291        }
292    }
293}
294
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// For example, a [`Validator`] uses this to record the stages in which
    /// subgroup operations are permitted.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        /// The vertex stage.
        const VERTEX = 0x1;
        /// The fragment stage.
        const FRAGMENT = 0x2;
        /// The compute stage.
        const COMPUTE = 0x4;
        /// The mesh stage.
        const MESH = 0x8;
        /// The task stage.
        const TASK = 0x10;
        /// The union of `COMPUTE`, `TASK`, and `MESH`.
        const COMPUTE_LIKE = Self::COMPUTE.bits() | Self::TASK.bits() | Self::MESH.bits();
    }
}
309
/// Information about a module, produced by successful validation.
///
/// Index it with a type, function, or global-expression handle to get the
/// corresponding [`TypeFlags`], [`FunctionInfo`], or [`TypeResolution`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, indexed by the type handle's index.
    type_flags: Vec<TypeFlags>,
    /// [`FunctionInfo`] for each function, in module order.
    functions: Vec<FunctionInfo>,
    /// [`FunctionInfo`] for each entry point, in module order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each global expression, indexed by the expression
    /// handle's index.
    const_expression_types: Box<[TypeResolution]>,
}
319
320impl ops::Index<Handle<crate::Type>> for ModuleInfo {
321    type Output = TypeFlags;
322    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
323        &self.type_flags[handle.index()]
324    }
325}
326
327impl ops::Index<Handle<crate::Function>> for ModuleInfo {
328    type Output = FunctionInfo;
329    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
330        &self.functions[handle.index()]
331    }
332}
333
334impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
335    type Output = TypeResolution;
336    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
337        &self.const_expression_types[handle.index()]
338    }
339}
340
/// Validates Naga modules.
///
/// Create one with [`Validator::new`], then call [`Validator::validate`] or
/// [`Validator::validate_resolved_overrides`]. A validator can be reused for
/// several modules: per-module state is reset at the start of each call.
#[derive(Debug)]
pub struct Validator {
    /// Which validation stages to perform.
    flags: ValidationFlags,
    /// Capabilities that validated modules are allowed to use.
    capabilities: Capabilities,
    /// Shader stages in which subgroup operations are permitted.
    subgroup_stages: ShaderStages,
    /// Classes of subgroup operations that are permitted.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation results, indexed by the type handle's index.
    types: Vec<r#type::TypeInfo>,
    /// Type layout calculator, updated for each module being validated.
    layouter: Layouter,
    // The next three fields are scratch state cleared by `reset`.
    // NOTE(review): presumably they track the locations, blend sources, and
    // resource bindings used by an entry point; confirm in `interface.rs`.
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // Scratch state cleared by `reset`.
    // NOTE(review): the reason for `allow(dead_code)` is not visible in this
    // file; document it (or remove the attribute) once confirmed.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // Scratch state cleared by `reset`; presumably used while validating
    // function bodies. Confirm in `function.rs`.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far in the current module, used to reject
    /// duplicates.
    override_ids: FastHashSet<u16>,

    /// Treat overrides whose initializers are not fully-evaluated
    /// constant expressions as errors.
    overrides_resolved: bool,

    /// A checklist of expressions that must be visited by a specific kind of
    /// statement.
    ///
    /// For example:
    ///
    /// - [`CallResult`] expressions must be visited by a [`Call`] statement.
    /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
    ///
    /// Be sure not to remove any [`Expression`] handle from this set unless
    /// you've explicitly checked that it is the right kind of expression for
    /// the visiting [`Statement`].
    ///
    /// [`CallResult`]: crate::Expression::CallResult
    /// [`Call`]: crate::Statement::Call
    /// [`AtomicResult`]: crate::Expression::AtomicResult
    /// [`Atomic`]: crate::Statement::Atomic
    /// [`Expression`]: crate::Expression
    /// [`Statement`]: crate::Statement
    needs_visit: HandleSet<crate::Expression>,
}
382
/// Reasons a module-scope `const` declaration can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the constant's declared
    /// type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
393
/// Reasons a pipeline-overridable constant (`override`) can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    // NOTE(review): not produced by `Validator::validate_override` in this
    // file; presumably raised elsewhere.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    // NOTE(review): not produced by `Validator::validate_override` in this
    // file; presumably raised elsewhere.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the override's declared
    /// type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The type is not one of `bool`, `i32`, `u32`, `f16`, `f32`, or `f64`.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    // NOTE(review): not produced in this file; presumably raised by callers
    // that forbid overrides entirely.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer, but overrides are required to be
    /// resolved (see [`Validator::validate_resolved_overrides`]).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    // NOTE(review): not produced in this file; presumably raised when an
    // override's initializer expression is itself invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
419
/// Top-level error returned by [`Validator::validate`], naming the module
/// item that failed and wrapping the specific cause.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// Handle validation of the module failed.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed type resolution or const-expression
    /// validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    // NOTE(review): not produced in this file; presumably raised during type
    // validation of array sizes.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation. This includes an entry point that
    /// has the same stage and name as an earlier one
    /// (`EntryPointError::Conflict`).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    // NOTE(review): not produced in this file; confirm where it is raised.
    #[error("Module is corrupted")]
    Corrupted,
}
473
474impl crate::TypeInner {
475    const fn is_sized(&self) -> bool {
476        match *self {
477            Self::Scalar { .. }
478            | Self::Vector { .. }
479            | Self::Matrix { .. }
480            | Self::CooperativeMatrix { .. }
481            | Self::Array {
482                size: crate::ArraySize::Constant(_),
483                ..
484            }
485            | Self::Atomic { .. }
486            | Self::Pointer { .. }
487            | Self::ValuePointer { .. }
488            | Self::Struct { .. } => true,
489            Self::Array { .. }
490            | Self::Image { .. }
491            | Self::Sampler { .. }
492            | Self::AccelerationStructure { .. }
493            | Self::RayQuery { .. }
494            | Self::BindingArray { .. } => false,
495        }
496    }
497
498    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
499    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
500        match *self {
501            Self::Scalar(crate::Scalar {
502                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
503                ..
504            }) => Some(crate::ImageDimension::D1),
505            Self::Vector {
506                size: crate::VectorSize::Bi,
507                scalar:
508                    crate::Scalar {
509                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
510                        ..
511                    },
512            } => Some(crate::ImageDimension::D2),
513            Self::Vector {
514                size: crate::VectorSize::Tri,
515                scalar:
516                    crate::Scalar {
517                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
518                        ..
519                    },
520            } => Some(crate::ImageDimension::D3),
521            _ => None,
522        }
523    }
524}
525
526impl Validator {
527    /// Create a validator for Naga [`Module`]s.
528    ///
529    /// The `flags` argument indicates which stages of validation the
530    /// returned `Validator` should perform. Skipping stages can make
531    /// validation somewhat faster, but the validator may not reject some
532    /// invalid modules. Regardless of `flags`, validation always returns
533    /// a usable [`ModuleInfo`] value on success.
534    ///
535    /// If `flags` contains everything in `ValidationFlags::default()`,
536    /// then the returned Naga [`Validator`] will reject any [`Module`]
537    /// that would use capabilities not included in `capabilities`.
538    ///
539    /// [`Module`]: crate::Module
540    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
541        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
542            use SubgroupOperationSet as S;
543            S::BASIC
544                | S::VOTE
545                | S::ARITHMETIC
546                | S::BALLOT
547                | S::SHUFFLE
548                | S::SHUFFLE_RELATIVE
549                | S::QUAD_FRAGMENT_COMPUTE
550        } else {
551            SubgroupOperationSet::empty()
552        };
553        let subgroup_stages = {
554            let mut stages = ShaderStages::empty();
555            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
556                stages |= ShaderStages::VERTEX;
557            }
558            if capabilities.contains(Capabilities::SUBGROUP) {
559                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE_LIKE;
560            }
561            stages
562        };
563
564        Validator {
565            flags,
566            capabilities,
567            subgroup_stages,
568            subgroup_operations,
569            types: Vec::new(),
570            layouter: Layouter::default(),
571            location_mask: BitSet::new(),
572            blend_src_mask: BitSet::new(),
573            ep_resource_bindings: FastHashSet::default(),
574            switch_values: FastHashSet::default(),
575            valid_expression_list: Vec::new(),
576            valid_expression_set: HandleSet::new(),
577            override_ids: FastHashSet::default(),
578            overrides_resolved: false,
579            needs_visit: HandleSet::new(),
580        }
581    }
582
583    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
584    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
585        self.subgroup_stages = stages;
586        self
587    }
588
589    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
590    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
591        self.subgroup_operations = operations;
592        self
593    }
594
595    /// Reset the validator internals
596    pub fn reset(&mut self) {
597        self.types.clear();
598        self.layouter.clear();
599        self.location_mask.clear();
600        self.blend_src_mask.clear();
601        self.ep_resource_bindings.clear();
602        self.switch_values.clear();
603        self.valid_expression_list.clear();
604        self.valid_expression_set.clear();
605        self.override_ids.clear();
606    }
607
608    fn validate_constant(
609        &self,
610        handle: Handle<crate::Constant>,
611        gctx: crate::proc::GlobalCtx,
612        mod_info: &ModuleInfo,
613        global_expr_kind: &ExpressionKindTracker,
614    ) -> Result<(), ConstantError> {
615        let con = &gctx.constants[handle];
616
617        let type_info = &self.types[con.ty.index()];
618        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
619            return Err(ConstantError::NonConstructibleType);
620        }
621
622        if !global_expr_kind.is_const(con.init) {
623            return Err(ConstantError::InitializerExprType);
624        }
625
626        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
627            return Err(ConstantError::InvalidType);
628        }
629
630        Ok(())
631    }
632
633    fn validate_override(
634        &mut self,
635        handle: Handle<crate::Override>,
636        gctx: crate::proc::GlobalCtx,
637        mod_info: &ModuleInfo,
638    ) -> Result<(), OverrideError> {
639        let o = &gctx.overrides[handle];
640
641        if let Some(id) = o.id {
642            if !self.override_ids.insert(id) {
643                return Err(OverrideError::DuplicateID);
644            }
645        }
646
647        let type_info = &self.types[o.ty.index()];
648        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
649            return Err(OverrideError::NonConstructibleType);
650        }
651
652        match gctx.types[o.ty].inner {
653            crate::TypeInner::Scalar(
654                crate::Scalar::BOOL
655                | crate::Scalar::I32
656                | crate::Scalar::U32
657                | crate::Scalar::F16
658                | crate::Scalar::F32
659                | crate::Scalar::F64,
660            ) => {}
661            _ => return Err(OverrideError::TypeNotScalar),
662        }
663
664        if let Some(init) = o.init {
665            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
666                return Err(OverrideError::InvalidType);
667            }
668        } else if self.overrides_resolved {
669            return Err(OverrideError::UninitializedOverride);
670        }
671
672        Ok(())
673    }
674
675    /// Check the given module to be valid.
676    pub fn validate(
677        &mut self,
678        module: &crate::Module,
679    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
680        self.overrides_resolved = false;
681        self.validate_impl(module)
682    }
683
684    /// Check the given module to be valid, requiring overrides to be resolved.
685    ///
686    /// This is the same as [`validate`], except that any override
687    /// whose value is not a fully-evaluated constant expression is
688    /// treated as an error.
689    ///
690    /// [`validate`]: Validator::validate
691    pub fn validate_resolved_overrides(
692        &mut self,
693        module: &crate::Module,
694    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
695        self.overrides_resolved = true;
696        self.validate_impl(module)
697    }
698
699    fn validate_impl(
700        &mut self,
701        module: &crate::Module,
702    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
703        self.reset();
704        self.reset_types(module.types.len());
705
706        Self::validate_module_handles(module).map_err(|e| e.with_span())?;
707
708        self.layouter.update(module.to_ctx()).map_err(|e| {
709            let handle = e.ty;
710            ValidationError::from(e).with_span_handle(handle, &module.types)
711        })?;
712
713        // These should all get overwritten.
714        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
715            kind: crate::ScalarKind::Bool,
716            width: 0,
717        }));
718
719        let mut mod_info = ModuleInfo {
720            type_flags: Vec::with_capacity(module.types.len()),
721            functions: Vec::with_capacity(module.functions.len()),
722            entry_points: Vec::with_capacity(module.entry_points.len()),
723            const_expression_types: vec![placeholder; module.global_expressions.len()]
724                .into_boxed_slice(),
725        };
726
727        for (handle, ty) in module.types.iter() {
728            let ty_info = self
729                .validate_type(handle, module.to_ctx())
730                .map_err(|source| {
731                    ValidationError::Type {
732                        handle,
733                        name: ty.name.clone().unwrap_or_default(),
734                        source,
735                    }
736                    .with_span_handle(handle, &module.types)
737                })?;
738            mod_info.type_flags.push(ty_info.flags);
739            self.types[handle.index()] = ty_info;
740        }
741
742        {
743            let t = crate::Arena::new();
744            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
745            for (handle, _) in module.global_expressions.iter() {
746                mod_info
747                    .process_const_expression(handle, &resolve_context, module.to_ctx())
748                    .map_err(|source| {
749                        ValidationError::ConstExpression { handle, source }
750                            .with_span_handle(handle, &module.global_expressions)
751                    })?
752            }
753        }
754
755        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
756
757        if self.flags.contains(ValidationFlags::CONSTANTS) {
758            for (handle, _) in module.global_expressions.iter() {
759                self.validate_const_expression(
760                    handle,
761                    module.to_ctx(),
762                    &mod_info,
763                    &global_expr_kind,
764                )
765                .map_err(|source| {
766                    ValidationError::ConstExpression { handle, source }
767                        .with_span_handle(handle, &module.global_expressions)
768                })?
769            }
770
771            for (handle, constant) in module.constants.iter() {
772                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
773                    .map_err(|source| {
774                        ValidationError::Constant {
775                            handle,
776                            name: constant.name.clone().unwrap_or_default(),
777                            source,
778                        }
779                        .with_span_handle(handle, &module.constants)
780                    })?
781            }
782
783            for (handle, r#override) in module.overrides.iter() {
784                self.validate_override(handle, module.to_ctx(), &mod_info)
785                    .map_err(|source| {
786                        ValidationError::Override {
787                            handle,
788                            name: r#override.name.clone().unwrap_or_default(),
789                            source,
790                        }
791                        .with_span_handle(handle, &module.overrides)
792                    })?;
793            }
794        }
795
796        for (var_handle, var) in module.global_variables.iter() {
797            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
798                .map_err(|source| {
799                    ValidationError::GlobalVariable {
800                        handle: var_handle,
801                        name: var.name.clone().unwrap_or_default(),
802                        source,
803                    }
804                    .with_span_handle(var_handle, &module.global_variables)
805                })?;
806        }
807
808        for (handle, fun) in module.functions.iter() {
809            match self.validate_function(fun, module, &mod_info, false) {
810                Ok(info) => mod_info.functions.push(info),
811                Err(error) => {
812                    return Err(error.and_then(|source| {
813                        ValidationError::Function {
814                            handle,
815                            name: fun.name.clone().unwrap_or_default(),
816                            source,
817                        }
818                        .with_span_handle(handle, &module.functions)
819                    }))
820                }
821            }
822        }
823
824        let mut ep_map = FastHashSet::default();
825        for ep in module.entry_points.iter() {
826            if !ep_map.insert((ep.stage, &ep.name)) {
827                return Err(ValidationError::EntryPoint {
828                    stage: ep.stage,
829                    name: ep.name.clone(),
830                    source: EntryPointError::Conflict,
831                }
832                .with_span()); // TODO: keep some EP span information?
833            }
834
835            match self.validate_entry_point(ep, module, &mod_info) {
836                Ok(info) => mod_info.entry_points.push(info),
837                Err(error) => {
838                    return Err(error.and_then(|source| {
839                        ValidationError::EntryPoint {
840                            stage: ep.stage,
841                            name: ep.name.clone(),
842                            source,
843                        }
844                        .with_span()
845                    }));
846                }
847            }
848        }
849
850        Ok(mod_info)
851    }
852}
853
854fn validate_atomic_compare_exchange_struct(
855    types: &crate::UniqueArena<crate::Type>,
856    members: &[crate::StructMember],
857    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
858) -> bool {
859    members.len() == 2
860        && members[0].name.as_deref() == Some("old_value")
861        && scalar_predicate(&types[members[0].ty].inner)
862        && members[1].name.as_deref() == Some("exchanged")
863        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
864}