naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10pub(crate) mod immediates;
11mod interface;
12mod r#type;
13
14use alloc::{boxed::Box, string::String, vec, vec::Vec};
15use core::ops;
16
17use bit_set::BitSet;
18
19use crate::{
20    arena::{Handle, HandleSet},
21    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
22    FastHashSet,
23};
24
25//TODO: analyze the model at the same time as we validate it,
26// merge the corresponding matches over expressions and statements.
27
28use crate::span::{AddSpan as _, WithSpan};
29pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
30pub use compose::ComposeError;
31pub use expression::{check_literal_value, LiteralError};
32pub use expression::{ConstExpressionError, ExpressionError};
33pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
34pub use immediates::ImmediateSlots;
35pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
36pub use r#type::{Disalignment, ImmediateError, TypeError, TypeFlags, WidthError};
37
38use self::handles::InvalidHandleError;
39
/// Upper bound, in bytes, on the size of any single type.
///
/// The value is `i32::MAX` (2^31 - 1) expressed as a `u32`.
pub const MAX_TYPE_SIZE: u32 = i32::MAX.unsigned_abs();
42
bitflags::bitflags! {
    /// Validation flags.
    ///
    /// Each flag enables one category of checks; see the individual
    /// constants below.
    ///
    /// If you are working with trusted shaders, then you may be able
    /// to save some time by skipping validation.
    ///
    /// If you do not perform full validation, invalid shaders may
    /// cause Naga to panic. If you do perform full validation and
    /// [`Validator::validate`] returns `Ok`, then Naga promises that
    /// code generation will either succeed or return an error; it
    /// should never panic.
    ///
    /// The default value for `ValidationFlags` is
    /// `ValidationFlags::all()`.
    // The serde derives are only applied when the matching `serialize` /
    // `deserialize` crate feature is enabled.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ValidationFlags: u8 {
        /// Expressions.
        const EXPRESSIONS = 0x1;
        /// Statements and blocks of them.
        const BLOCKS = 0x2;
        /// Uniformity of control flow for operations that require it.
        const CONTROL_FLOW_UNIFORMITY = 0x4;
        /// Host-shareable structure layouts.
        const STRUCT_LAYOUTS = 0x8;
        /// Constants.
        const CONSTANTS = 0x10;
        /// Group, binding, and location attributes.
        const BINDINGS = 0x20;
    }
}
75
76impl Default for ValidationFlags {
77    fn default() -> Self {
78        Self::all()
79    }
80}
81
bitflags::bitflags! {
    /// Allowed IR capabilities.
    ///
    /// Each flag occupies one bit of the `u64` backing value; bits are
    /// assigned sequentially starting from bit 0.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u64 {
        /// Support for [`AddressSpace::Immediate`][1].
        ///
        /// [1]: crate::AddressSpace::Immediate
        const IMMEDIATES = 1 << 0;
        /// Float values with width = 8.
        const FLOAT64 = 1 << 1;
        /// Support for [`BuiltIn::PrimitiveIndex`][1].
        ///
        /// [1]: crate::BuiltIn::PrimitiveIndex
        const PRIMITIVE_INDEX = 1 << 2;
        /// Support for binding arrays of sampled textures and samplers.
        const TEXTURE_AND_SAMPLER_BINDING_ARRAY = 1 << 3;
        /// Support for binding arrays of uniform buffers.
        const BUFFER_BINDING_ARRAY = 1 << 4;
        /// Support for binding arrays of storage textures.
        const STORAGE_TEXTURE_BINDING_ARRAY = 1 << 5;
        /// Support for binding arrays of storage buffers.
        const STORAGE_BUFFER_BINDING_ARRAY = 1 << 6;
        /// Support for [`BuiltIn::ClipDistances`].
        ///
        /// [`BuiltIn::ClipDistances`]: crate::BuiltIn::ClipDistances
        const CLIP_DISTANCES = 1 << 7;
        /// Support for [`BuiltIn::CullDistance`].
        ///
        /// [`BuiltIn::CullDistance`]: crate::BuiltIn::CullDistance
        const CULL_DISTANCE = 1 << 8;
        /// Support for 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Support for [`BuiltIn::ViewIndex`].
        ///
        /// [`BuiltIn::ViewIndex`]: crate::BuiltIn::ViewIndex
        const MULTIVIEW = 1 << 10;
        /// Support for `early_depth_test`.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
        ///
        /// [`BuiltIn::SampleIndex`]: crate::BuiltIn::SampleIndex
        /// [`Sampling::Sample`]: crate::Sampling::Sample
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Support for ray queries and acceleration structures.
        const RAY_QUERY = 1 << 13;
        /// Support for generating two sources for blending from fragment shaders.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Support for arrayed cube textures.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// Support for 64-bit signed and unsigned integers.
        const SHADER_INT64 = 1 << 16;
        /// Support for subgroup operations (except barriers) in fragment and compute shaders.
        ///
        /// Subgroup operations in the vertex stage require
        /// [`Capabilities::SUBGROUP_VERTEX_STAGE`] in addition to `Capabilities::SUBGROUP`.
        /// (But note that `create_validator` automatically sets
        /// `Capabilities::SUBGROUP` whenever `Features::SUBGROUP_VERTEX` is
        /// available.)
        ///
        /// Subgroup barriers require [`Capabilities::SUBGROUP_BARRIER`] in addition to
        /// `Capabilities::SUBGROUP`.
        const SUBGROUP = 1 << 17;
        /// Support for subgroup barriers in compute shaders.
        ///
        /// Requires [`Capabilities::SUBGROUP`]. Without it, enables nothing.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Support for subgroup operations (not including barriers) in the vertex stage.
        ///
        /// Without [`Capabilities::SUBGROUP`], enables nothing. (But note that
        /// `create_validator` automatically sets `Capabilities::SUBGROUP`
        /// whenever `Features::SUBGROUP_VERTEX` is available.)
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// Support for [`AtomicFunction::Min`] and [`AtomicFunction::Max`] on
        /// 64-bit integers in the [`Storage`] address space, when the return
        /// value is not used.
        ///
        /// This is the only 64-bit atomic functionality available on Metal 3.1.
        ///
        /// [`AtomicFunction::Min`]: crate::AtomicFunction::Min
        /// [`AtomicFunction::Max`]: crate::AtomicFunction::Max
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// Support for all atomic operations on 64-bit integers.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`],
        /// and [`AtomicFunction::Exchange { compare: None }`] on 32-bit floating-point numbers
        /// in the [`Storage`] address space.
        ///
        /// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
        /// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
        /// [`AtomicFunction::Exchange { compare: None }`]: crate::AtomicFunction::Exchange
        /// [`Storage`]: crate::AddressSpace::Storage
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Support for atomic operations on images.
        const TEXTURE_ATOMIC = 1 << 23;
        /// Support for atomic operations on 64-bit images.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Support for ray queries returning vertex position.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// Support for 16-bit floating-point types.
        const SHADER_FLOAT16 = 1 << 26;
        /// Support for [`ImageClass::External`].
        ///
        /// [`ImageClass::External`]: crate::ImageClass::External
        const TEXTURE_EXTERNAL = 1 << 27;
        /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
        /// `f16`-precision values in `f32`s.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
        /// Support for fragment shader barycentric coordinates.
        const SHADER_BARYCENTRICS = 1 << 29;
        /// Support for task shaders, mesh shaders, and per-primitive fragment inputs.
        const MESH_SHADER = 1 << 30;
        /// Support for mesh shaders which output points.
        const MESH_SHADER_POINT_TOPOLOGY = 1 << 31;
        /// Support for non-uniform indexing of binding arrays of sampled textures and samplers.
        const TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 32;
        /// Support for non-uniform indexing of binding arrays of uniform buffers.
        const BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 33;
        /// Support for non-uniform indexing of binding arrays of storage textures.
        const STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 34;
        /// Support for non-uniform indexing of binding arrays of storage buffers.
        const STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 35;
        /// Support for cooperative matrix types and operations.
        const COOPERATIVE_MATRIX = 1 << 36;
        /// Support for per-vertex fragment input.
        const PER_VERTEX = 1 << 37;
        /// Support for ray generation, any hit, closest hit, and miss shaders.
        const RAY_TRACING_PIPELINE = 1 << 38;
        /// Support for the draw index builtin.
        const DRAW_INDEX = 1 << 39;
        /// Support for binding arrays of acceleration structures.
        const ACCELERATION_STRUCTURE_BINDING_ARRAY = 1 << 40;
        /// Support for the `@coherent` memory decoration on storage buffers.
        const MEMORY_DECORATION_COHERENT = 1 << 41;
        /// Support for the `@volatile` memory decoration on storage buffers.
        const MEMORY_DECORATION_VOLATILE = 1 << 42;
        /// Support for 16-bit integer types.
        const SHADER_INT16 = 1 << 43;
    }
}
223
224impl Capabilities {
225    /// Returns the extension corresponding to this capability, if there is one.
226    ///
227    /// This is used by integration tests.
228    #[cfg(feature = "wgsl-in")]
229    #[doc(hidden)]
230    pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
231        use crate::front::wgsl::ImplementedEnableExtension as Ext;
232        match *self {
233            Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
234            // NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
235            Self::SHADER_FLOAT16 => Some(Ext::F16),
236            Self::SHADER_INT16 => Some(Ext::WgpuInt16),
237            Self::CLIP_DISTANCES => Some(Ext::ClipDistances),
238            Self::MESH_SHADER => Some(Ext::WgpuMeshShader),
239            Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
240            Self::RAY_HIT_VERTEX_POSITION => Some(Ext::WgpuRayQueryVertexReturn),
241            Self::COOPERATIVE_MATRIX => Some(Ext::WgpuCooperativeMatrix),
242            Self::RAY_TRACING_PIPELINE => Some(Ext::WgpuRayTracingPipeline),
243            Self::PER_VERTEX => Some(Ext::WgpuPerVertex),
244            Self::BUFFER_BINDING_ARRAY
245            | Self::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
246            | Self::STORAGE_BUFFER_BINDING_ARRAY
247            | Self::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
248            | Self::STORAGE_TEXTURE_BINDING_ARRAY
249            | Self::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
250            | Self::TEXTURE_AND_SAMPLER_BINDING_ARRAY
251            | Self::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING => {
252                Some(Ext::WgpuBindingArray)
253            }
254            _ => None,
255        }
256    }
257}
258
259impl Default for Capabilities {
260    fn default() -> Self {
261        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
262    }
263}
264
bitflags::bitflags! {
    /// Supported subgroup operations.
    ///
    /// Each flag names one class of subgroup operations.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Barriers
        // Possibly elections, when that is supported.
        // https://github.com/gfx-rs/wgpu/issues/6042#issuecomment-3272603431
        // Contrary to what the name "basic" suggests, HLSL/DX12 support the
        // other subgroup operations, but do not support subgroup barriers.
        const BASIC = 1 << 0;
        /// Any, All
        const VOTE = 1 << 1;
        /// Reductions, scans
        const ARITHMETIC = 1 << 2;
        /// Ballot, broadcast
        const BALLOT = 1 << 3;
        /// Shuffle, shuffle xor
        const SHUFFLE = 1 << 4;
        /// Shuffle up, down
        const SHUFFLE_RELATIVE = 1 << 5;
        // We don't support these operations yet
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        /// Quad operations, in fragment and compute stages
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // NOTE: `1 << 8` does not fit in the `u8` backing type; the type
        // would have to be widened before this flag could be enabled.
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
296
297impl super::SubgroupOperation {
298    const fn required_operations(&self) -> SubgroupOperationSet {
299        use SubgroupOperationSet as S;
300        match *self {
301            Self::All | Self::Any => S::VOTE,
302            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
303                S::ARITHMETIC
304            }
305        }
306    }
307}
308
309impl super::GatherMode {
310    const fn required_operations(&self) -> SubgroupOperationSet {
311        use SubgroupOperationSet as S;
312        match *self {
313            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
314            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
315            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
316            Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
317        }
318    }
319}
320
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// One bit per stage, plus the convenience combination
    /// [`ShaderStages::COMPUTE_LIKE`].
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u16 {
        const VERTEX = 0x1;
        const FRAGMENT = 0x2;
        const COMPUTE = 0x4;
        const MESH = 0x8;
        const TASK = 0x10;
        const RAY_GENERATION = 0x20;
        const ANY_HIT = 0x40;
        const CLOSEST_HIT = 0x80;
        const MISS = 0x100;
        /// The union of `COMPUTE`, `TASK`, and `MESH`.
        const COMPUTE_LIKE = Self::COMPUTE.bits() | Self::TASK.bits() | Self::MESH.bits();
    }
}
339
/// Information gathered while validating a module.
///
/// This is the success value of [`Validator::validate`]. It can be indexed
/// by type, function, and expression handles.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags of each type, indexed by the type handle's index.
    type_flags: Vec<TypeFlags>,
    /// Per-function analysis results, indexed by the function handle's index.
    functions: Vec<FunctionInfo>,
    /// Analysis results for entry points.
    // NOTE(review): presumably in the same order as `Module::entry_points` — confirm.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each expression, indexed by the expression handle's
    /// index.
    // NOTE(review): `validate_impl` sizes this by `module.global_expressions`,
    // so these appear to be the module's global expressions only.
    const_expression_types: Box<[TypeResolution]>,
}
349
350impl ops::Index<Handle<crate::Type>> for ModuleInfo {
351    type Output = TypeFlags;
352    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
353        &self.type_flags[handle.index()]
354    }
355}
356
357impl ops::Index<Handle<crate::Function>> for ModuleInfo {
358    type Output = FunctionInfo;
359    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
360        &self.functions[handle.index()]
361    }
362}
363
364impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
365    type Output = TypeResolution;
366    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
367        &self.const_expression_types[handle.index()]
368    }
369}
370
/// Validator for Naga modules.
///
/// Holds the validation configuration together with scratch state that
/// [`Validator::reset`] clears between runs.
#[derive(Debug)]
pub struct Validator {
    /// Which validation stages to perform.
    flags: ValidationFlags,
    /// The capabilities a module is allowed to use.
    capabilities: Capabilities,
    /// Shader stages associated with subgroup operations; initialized from
    /// `capabilities` in [`Validator::new`].
    subgroup_stages: ShaderStages,
    /// Classes of subgroup operations; initialized from `capabilities` in
    /// [`Validator::new`].
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation results, indexed by the type handle's index.
    types: Vec<r#type::TypeInfo>,
    layouter: Layouter,
    location_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    switch_values: FastHashSet<crate::SwitchValue>,
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far; used to detect duplicate IDs.
    override_ids: FastHashSet<u16>,

    /// Treat overrides whose initializers are not fully-evaluated
    /// constant expressions as errors.
    overrides_resolved: bool,

    /// A checklist of expressions that must be visited by a specific kind of
    /// statement.
    ///
    /// For example:
    ///
    /// - [`CallResult`] expressions must be visited by a [`Call`] statement.
    /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
    ///
    /// Be sure not to remove any [`Expression`] handle from this set unless
    /// you've explicitly checked that it is the right kind of expression for
    /// the visiting [`Statement`].
    ///
    /// [`CallResult`]: crate::Expression::CallResult
    /// [`Call`]: crate::Statement::Call
    /// [`AtomicResult`]: crate::Expression::AtomicResult
    /// [`Atomic`]: crate::Statement::Atomic
    /// [`Expression`]: crate::Expression
    /// [`Statement`]: crate::Statement
    needs_visit: HandleSet<crate::Expression>,

    /// Whether any trace rays call is called, and whether all have vertex return.
    /// If one call doesn't use vertex return, builtins for triangle vertex positions
    /// (not yet implemented) are not allowed.
    trace_rays_vertex_return: TraceRayVertexReturnState,

    /// The type of the ray payload, this must always be the same type in a particular
    /// entrypoint
    trace_rays_payload_type: Option<Handle<crate::Type>>,
}
419
/// What the validator has observed so far about trace ray calls and
/// whether their acceleration structures enable vertex return.
#[derive(Debug)]
enum TraceRayVertexReturnState {
    /// No trace ray calls yet have been found.
    NoTraceRays,
    /// Trace ray calls have been found, at least
    /// one uses an acceleration structure that
    /// does not have the flag enabling vertex return.
    ///
    /// Carries a [`crate::Span`].
    // NOTE(review): presumably the span of the offending call — confirm.
    #[expect(
        unused,
        reason = "Don't yet have vertex return builtins to return this error for."
    )]
    NoVertexReturn(crate::Span),
    /// Trace ray calls have been found, all
    /// acceleration structures have the flag enabling
    /// vertex return.
    VertexReturn,
}
437
/// Reasons a constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
448
/// Reasons an override can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override already uses the same ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The override's type is not one of the accepted scalar types.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer although overrides were required to
    /// be resolved.
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// The expression `handle` failed const-expression validation; `source`
    /// holds the underlying error.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
474
/// Top-level error produced by module validation.
///
/// [`Validator::validate`] returns this wrapped in a [`WithSpan`]. Most
/// variants identify the offending item by handle and name, and carry the
/// more specific error in `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// A handle in the module is invalid.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A constant expression failed validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// An array size expression is not strictly positive.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    #[error("Module is corrupted")]
    Corrupted,
}
528
529impl crate::TypeInner {
530    const fn is_sized(&self) -> bool {
531        match *self {
532            Self::Scalar { .. }
533            | Self::Vector { .. }
534            | Self::Matrix { .. }
535            | Self::CooperativeMatrix { .. }
536            | Self::Array {
537                size: crate::ArraySize::Constant(_),
538                ..
539            }
540            | Self::Atomic { .. }
541            | Self::Pointer { .. }
542            | Self::ValuePointer { .. }
543            | Self::Struct { .. } => true,
544            Self::Array { .. }
545            | Self::Image { .. }
546            | Self::Sampler { .. }
547            | Self::AccelerationStructure { .. }
548            | Self::RayQuery { .. }
549            | Self::BindingArray { .. } => false,
550        }
551    }
552
553    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
554    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
555        match *self {
556            Self::Scalar(crate::Scalar {
557                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
558                ..
559            }) => Some(crate::ImageDimension::D1),
560            Self::Vector {
561                size: crate::VectorSize::Bi,
562                scalar:
563                    crate::Scalar {
564                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
565                        ..
566                    },
567            } => Some(crate::ImageDimension::D2),
568            Self::Vector {
569                size: crate::VectorSize::Tri,
570                scalar:
571                    crate::Scalar {
572                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
573                        ..
574                    },
575            } => Some(crate::ImageDimension::D3),
576            _ => None,
577        }
578    }
579}
580
581impl Validator {
582    /// Create a validator for Naga [`Module`]s.
583    ///
584    /// The `flags` argument indicates which stages of validation the
585    /// returned `Validator` should perform. Skipping stages can make
586    /// validation somewhat faster, but the validator may not reject some
587    /// invalid modules. Regardless of `flags`, validation always returns
588    /// a usable [`ModuleInfo`] value on success.
589    ///
590    /// If `flags` contains everything in `ValidationFlags::default()`,
591    /// then the returned Naga [`Validator`] will reject any [`Module`]
592    /// that would use capabilities not included in `capabilities`.
593    ///
594    /// [`Module`]: crate::Module
595    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
596        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
597            use SubgroupOperationSet as S;
598            S::BASIC
599                | S::VOTE
600                | S::ARITHMETIC
601                | S::BALLOT
602                | S::SHUFFLE
603                | S::SHUFFLE_RELATIVE
604                | S::QUAD_FRAGMENT_COMPUTE
605        } else {
606            SubgroupOperationSet::empty()
607        };
608        let subgroup_stages = {
609            let mut stages = ShaderStages::empty();
610            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
611                stages |= ShaderStages::VERTEX;
612            }
613            if capabilities.contains(Capabilities::SUBGROUP) {
614                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE_LIKE;
615            }
616            stages
617        };
618
619        Validator {
620            flags,
621            capabilities,
622            subgroup_stages,
623            subgroup_operations,
624            types: Vec::new(),
625            layouter: Layouter::default(),
626            location_mask: BitSet::new(),
627            ep_resource_bindings: FastHashSet::default(),
628            switch_values: FastHashSet::default(),
629            valid_expression_list: Vec::new(),
630            valid_expression_set: HandleSet::new(),
631            override_ids: FastHashSet::default(),
632            overrides_resolved: false,
633            needs_visit: HandleSet::new(),
634            trace_rays_vertex_return: TraceRayVertexReturnState::NoTraceRays,
635            trace_rays_payload_type: None,
636        }
637    }
638
639    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
640    pub const fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
641        self.subgroup_stages = stages;
642        self
643    }
644
645    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
646    pub const fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
647        self.subgroup_operations = operations;
648        self
649    }
650
651    /// Reset the validator internals
652    pub fn reset(&mut self) {
653        self.types.clear();
654        self.layouter.clear();
655        self.location_mask.make_empty();
656        self.ep_resource_bindings.clear();
657        self.switch_values.clear();
658        self.valid_expression_list.clear();
659        self.valid_expression_set.clear();
660        self.override_ids.clear();
661    }
662
663    fn validate_constant(
664        &self,
665        handle: Handle<crate::Constant>,
666        gctx: crate::proc::GlobalCtx,
667        mod_info: &ModuleInfo,
668        global_expr_kind: &ExpressionKindTracker,
669    ) -> Result<(), ConstantError> {
670        let con = &gctx.constants[handle];
671
672        let type_info = &self.types[con.ty.index()];
673        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
674            return Err(ConstantError::NonConstructibleType);
675        }
676
677        if !global_expr_kind.is_const(con.init) {
678            return Err(ConstantError::InitializerExprType);
679        }
680
681        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
682            return Err(ConstantError::InvalidType);
683        }
684
685        Ok(())
686    }
687
688    fn validate_override(
689        &mut self,
690        handle: Handle<crate::Override>,
691        gctx: crate::proc::GlobalCtx,
692        mod_info: &ModuleInfo,
693    ) -> Result<(), OverrideError> {
694        let o = &gctx.overrides[handle];
695
696        if let Some(id) = o.id {
697            if !self.override_ids.insert(id) {
698                return Err(OverrideError::DuplicateID);
699            }
700        }
701
702        let type_info = &self.types[o.ty.index()];
703        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
704            return Err(OverrideError::NonConstructibleType);
705        }
706
707        match gctx.types[o.ty].inner {
708            crate::TypeInner::Scalar(
709                crate::Scalar::BOOL
710                | crate::Scalar::I16
711                | crate::Scalar::U16
712                | crate::Scalar::I32
713                | crate::Scalar::U32
714                | crate::Scalar::F16
715                | crate::Scalar::F32
716                | crate::Scalar::F64,
717            ) => {}
718            _ => return Err(OverrideError::TypeNotScalar),
719        }
720
721        if let Some(init) = o.init {
722            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
723                return Err(OverrideError::InvalidType);
724            }
725        } else if self.overrides_resolved {
726            return Err(OverrideError::UninitializedOverride);
727        }
728
729        Ok(())
730    }
731
732    /// Check the given module to be valid.
733    pub fn validate(
734        &mut self,
735        module: &crate::Module,
736    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
737        self.overrides_resolved = false;
738        self.validate_impl(module)
739    }
740
741    /// Check the given module to be valid, requiring overrides to be resolved.
742    ///
743    /// This is the same as [`validate`], except that any override
744    /// whose value is not a fully-evaluated constant expression is
745    /// treated as an error.
746    ///
747    /// [`validate`]: Validator::validate
748    pub fn validate_resolved_overrides(
749        &mut self,
750        module: &crate::Module,
751    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
752        self.overrides_resolved = true;
753        self.validate_impl(module)
754    }
755
756    fn validate_impl(
757        &mut self,
758        module: &crate::Module,
759    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
760        self.reset();
761        self.reset_types(module.types.len());
762
763        Self::validate_module_handles(module).map_err(|e| e.with_span())?;
764
765        self.layouter.update(module.to_ctx()).map_err(|e| {
766            let handle = e.ty;
767            ValidationError::from(e).with_span_handle(handle, &module.types)
768        })?;
769
770        // These should all get overwritten.
771        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
772            kind: crate::ScalarKind::Bool,
773            width: 0,
774        }));
775
776        let mut mod_info = ModuleInfo {
777            type_flags: Vec::with_capacity(module.types.len()),
778            functions: Vec::with_capacity(module.functions.len()),
779            entry_points: Vec::with_capacity(module.entry_points.len()),
780            const_expression_types: vec![placeholder; module.global_expressions.len()]
781                .into_boxed_slice(),
782        };
783
784        for (handle, ty) in module.types.iter() {
785            let ty_info = self
786                .validate_type(handle, module.to_ctx())
787                .map_err(|source| {
788                    ValidationError::Type {
789                        handle,
790                        name: ty.name.clone().unwrap_or_default(),
791                        source,
792                    }
793                    .with_span_handle(handle, &module.types)
794                })?;
795            debug_assert!(
796                ty_info.flags.contains(TypeFlags::CONSTRUCTIBLE)
797                    == module.types[handle].inner.is_constructible(&module.types)
798            );
799            mod_info.type_flags.push(ty_info.flags);
800            self.types[handle.index()] = ty_info;
801        }
802
803        {
804            let t = crate::Arena::new();
805            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
806            for (handle, _) in module.global_expressions.iter() {
807                mod_info
808                    .process_const_expression(handle, &resolve_context, module.to_ctx())
809                    .map_err(|source| {
810                        ValidationError::ConstExpression { handle, source }
811                            .with_span_handle(handle, &module.global_expressions)
812                    })?
813            }
814        }
815
816        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
817
818        if self.flags.contains(ValidationFlags::CONSTANTS) {
819            for (handle, _) in module.global_expressions.iter() {
820                self.validate_const_expression(
821                    handle,
822                    module.to_ctx(),
823                    &mod_info,
824                    &global_expr_kind,
825                )
826                .map_err(|source| {
827                    ValidationError::ConstExpression { handle, source }
828                        .with_span_handle(handle, &module.global_expressions)
829                })?
830            }
831
832            for (handle, constant) in module.constants.iter() {
833                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
834                    .map_err(|source| {
835                        ValidationError::Constant {
836                            handle,
837                            name: constant.name.clone().unwrap_or_default(),
838                            source,
839                        }
840                        .with_span_handle(handle, &module.constants)
841                    })?
842            }
843
844            for (handle, r#override) in module.overrides.iter() {
845                self.validate_override(handle, module.to_ctx(), &mod_info)
846                    .map_err(|source| {
847                        ValidationError::Override {
848                            handle,
849                            name: r#override.name.clone().unwrap_or_default(),
850                            source,
851                        }
852                        .with_span_handle(handle, &module.overrides)
853                    })?;
854            }
855        }
856
857        for (var_handle, var) in module.global_variables.iter() {
858            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
859                .map_err(|source| {
860                    ValidationError::GlobalVariable {
861                        handle: var_handle,
862                        name: var.name.clone().unwrap_or_default(),
863                        source,
864                    }
865                    .with_span_handle(var_handle, &module.global_variables)
866                })?;
867        }
868
869        for (handle, fun) in module.functions.iter() {
870            match self.validate_function(fun, module, &mod_info, false) {
871                Ok(info) => mod_info.functions.push(info),
872                Err(error) => {
873                    return Err(error.and_then(|source| {
874                        ValidationError::Function {
875                            handle,
876                            name: fun.name.clone().unwrap_or_default(),
877                            source,
878                        }
879                        .with_span_handle(handle, &module.functions)
880                    }))
881                }
882            }
883        }
884
885        let mut ep_map = FastHashSet::default();
886        for ep in module.entry_points.iter() {
887            if !ep_map.insert((ep.stage, &ep.name)) {
888                return Err(ValidationError::EntryPoint {
889                    stage: ep.stage,
890                    name: ep.name.clone(),
891                    source: EntryPointError::Conflict,
892                }
893                .with_span()); // TODO: keep some EP span information?
894            }
895
896            match self.validate_entry_point(ep, module, &mod_info) {
897                Ok(info) => mod_info.entry_points.push(info),
898                Err(error) => {
899                    return Err(error.and_then(|source| {
900                        ValidationError::EntryPoint {
901                            stage: ep.stage,
902                            name: ep.name.clone(),
903                            source,
904                        }
905                        .with_span()
906                    }));
907                }
908            }
909        }
910
911        Ok(mod_info)
912    }
913}
914
915fn validate_atomic_compare_exchange_struct(
916    types: &crate::UniqueArena<crate::Type>,
917    members: &[crate::StructMember],
918    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
919) -> bool {
920    members.len() == 2
921        && members[0].name.as_deref() == Some("old_value")
922        && scalar_predicate(&types[members[0].ty].inner)
923        && members[1].name.as_deref() == Some("exchanged")
924        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
925}