naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19    arena::{Handle, HandleSet},
20    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21    FastHashSet,
22};
23
24//TODO: analyze the model at the same time as we validate it,
25// merge the corresponding matches over expressions and statements.
26
27use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, ImmediateError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
/// Maximum size of a type, in bytes.
///
/// This is 2^30 bytes, i.e. 1 GiB.
pub const MAX_TYPE_SIZE: u32 = 1 << 30;
40
41bitflags::bitflags! {
42    /// Validation flags.
43    ///
44    /// If you are working with trusted shaders, then you may be able
45    /// to save some time by skipping validation.
46    ///
47    /// If you do not perform full validation, invalid shaders may
48    /// cause Naga to panic. If you do perform full validation and
49    /// [`Validator::validate`] returns `Ok`, then Naga promises that
50    /// code generation will either succeed or return an error; it
51    /// should never panic.
52    ///
53    /// The default value for `ValidationFlags` is
54    /// `ValidationFlags::all()`.
55    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58    pub struct ValidationFlags: u8 {
59        /// Expressions.
60        const EXPRESSIONS = 0x1;
61        /// Statements and blocks of them.
62        const BLOCKS = 0x2;
63        /// Uniformity of control flow for operations that require it.
64        const CONTROL_FLOW_UNIFORMITY = 0x4;
65        /// Host-shareable structure layouts.
66        const STRUCT_LAYOUTS = 0x8;
67        /// Constants.
68        const CONSTANTS = 0x10;
69        /// Group, binding, and location attributes.
70        const BINDINGS = 0x20;
71    }
72}
73
74impl Default for ValidationFlags {
75    fn default() -> Self {
76        Self::all()
77    }
78}
79
80bitflags::bitflags! {
81    /// Allowed IR capabilities.
82    #[must_use]
83    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
84    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
85    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
86    pub struct Capabilities: u64 {
87        /// Support for [`AddressSpace::Immediate`][1].
88        ///
89        /// [1]: crate::AddressSpace::Immediate
90        const IMMEDIATES = 1 << 0;
91        /// Float values with width = 8.
92        const FLOAT64 = 1 << 1;
93        /// Support for [`BuiltIn::PrimitiveIndex`][1].
94        ///
95        /// [1]: crate::BuiltIn::PrimitiveIndex
96        const PRIMITIVE_INDEX = 1 << 2;
97        /// Support for binding arrays of sampled textures and samplers.
98        const TEXTURE_AND_SAMPLER_BINDING_ARRAY = 1 << 3;
99        /// Support for binding arrays of uniform buffers.
100        const BUFFER_BINDING_ARRAY = 1 << 4;
101        /// Support for binding arrays of storage textures.
102        const STORAGE_TEXTURE_BINDING_ARRAY = 1 << 5;
103        /// Support for binding arrays of storage buffers.
104        const STORAGE_BUFFER_BINDING_ARRAY = 1 << 6;
105        /// Support for [`BuiltIn::ClipDistances`].
106        ///
107        /// [`BuiltIn::ClipDistances`]: crate::BuiltIn::ClipDistances
108        const CLIP_DISTANCES = 1 << 7;
109        /// Support for [`BuiltIn::CullDistance`].
110        ///
111        /// [`BuiltIn::CullDistance`]: crate::BuiltIn::CullDistance
112        const CULL_DISTANCE = 1 << 8;
113        /// Support for 16-bit normalized storage texture formats.
114        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
115        /// Support for [`BuiltIn::ViewIndex`].
116        ///
117        /// [`BuiltIn::ViewIndex`]: crate::BuiltIn::ViewIndex
118        const MULTIVIEW = 1 << 10;
119        /// Support for `early_depth_test`.
120        const EARLY_DEPTH_TEST = 1 << 11;
121        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
122        ///
123        /// [`BuiltIn::SampleIndex`]: crate::BuiltIn::SampleIndex
124        /// [`Sampling::Sample`]: crate::Sampling::Sample
125        const MULTISAMPLED_SHADING = 1 << 12;
126        /// Support for ray queries and acceleration structures.
127        const RAY_QUERY = 1 << 13;
128        /// Support for generating two sources for blending from fragment shaders.
129        const DUAL_SOURCE_BLENDING = 1 << 14;
130        /// Support for arrayed cube textures.
131        const CUBE_ARRAY_TEXTURES = 1 << 15;
132        /// Support for 64-bit signed and unsigned integers.
133        const SHADER_INT64 = 1 << 16;
134        /// Support for subgroup operations (except barriers) in fragment and compute shaders.
135        ///
136        /// Subgroup operations in the vertex stage require
137        /// [`Capabilities::SUBGROUP_VERTEX_STAGE`] in addition to `Capabilities::SUBGROUP`.
138        /// (But note that `create_validator` automatically sets
139        /// `Capabilities::SUBGROUP` whenever `Features::SUBGROUP_VERTEX` is
140        /// available.)
141        ///
142        /// Subgroup barriers require [`Capabilities::SUBGROUP_BARRIER`] in addition to
143        /// `Capabilities::SUBGROUP`.
144        const SUBGROUP = 1 << 17;
145        /// Support for subgroup barriers in compute shaders.
146        ///
147        /// Requires [`Capabilities::SUBGROUP`]. Without it, enables nothing.
148        const SUBGROUP_BARRIER = 1 << 18;
149        /// Support for subgroup operations (not including barriers) in the vertex stage.
150        ///
151        /// Without [`Capabilities::SUBGROUP`], enables nothing. (But note that
152        /// `create_validator` automatically sets `Capabilities::SUBGROUP`
153        /// whenever `Features::SUBGROUP_VERTEX` is available.)
154        const SUBGROUP_VERTEX_STAGE = 1 << 19;
155        /// Support for [`AtomicFunction::Min`] and [`AtomicFunction::Max`] on
156        /// 64-bit integers in the [`Storage`] address space, when the return
157        /// value is not used.
158        ///
159        /// This is the only 64-bit atomic functionality available on Metal 3.1.
160        ///
161        /// [`AtomicFunction::Min`]: crate::AtomicFunction::Min
162        /// [`AtomicFunction::Max`]: crate::AtomicFunction::Max
163        /// [`Storage`]: crate::AddressSpace::Storage
164        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
165        /// Support for all atomic operations on 64-bit integers.
166        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
167        /// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`],
168        /// and [`AtomicFunction::Exchange { compare: None }`] on 32-bit floating-point numbers
169        /// in the [`Storage`] address space.
170        ///
171        /// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
172        /// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
173        /// [`AtomicFunction::Exchange { compare: None }`]: crate::AtomicFunction::Exchange
174        /// [`Storage`]: crate::AddressSpace::Storage
175        const SHADER_FLOAT32_ATOMIC = 1 << 22;
176        /// Support for atomic operations on images.
177        const TEXTURE_ATOMIC = 1 << 23;
178        /// Support for atomic operations on 64-bit images.
179        const TEXTURE_INT64_ATOMIC = 1 << 24;
180        /// Support for ray queries returning vertex position
181        const RAY_HIT_VERTEX_POSITION = 1 << 25;
182        /// Support for 16-bit floating-point types.
183        const SHADER_FLOAT16 = 1 << 26;
184        /// Support for [`ImageClass::External`]
185        const TEXTURE_EXTERNAL = 1 << 27;
186        /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
187        /// `f16`-precision values in `f32`s.
188        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
189        /// Support for fragment shader barycentric coordinates.
190        const SHADER_BARYCENTRICS = 1 << 29;
191        /// Support for task shaders, mesh shaders, and per-primitive fragment inputs
192        const MESH_SHADER = 1 << 30;
193        /// Support for mesh shaders which output points.
194        const MESH_SHADER_POINT_TOPOLOGY = 1 << 31;
195        /// Support for non-uniform indexing of binding arrays of sampled textures and samplers.
196        const TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 32;
197        /// Support for non-uniform indexing of binding arrays of uniform buffers.
198        const BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 33;
199        /// Support for non-uniform indexing of binding arrays of storage textures.
200        const STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 34;
201        /// Support for non-uniform indexing of binding arrays of storage buffers.
202        const STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING = 1 << 35;
203        /// Support for cooperative matrix types and operations
204        const COOPERATIVE_MATRIX = 1 << 36;
205        /// Support for per-vertex fragment input.
206        const PER_VERTEX = 1 << 37;
207        /// Support for ray generation, any hit, closest hit, and miss shaders.
208        const RAY_TRACING_PIPELINE = 1 << 38;
209        /// Support for draw index builtin
210        const DRAW_INDEX = 1 << 39;
211        /// Support for binding arrays of acceleration structures.
212        const ACCELERATION_STRUCTURE_BINDING_ARRAY = 1 << 40;
213        /// Support for the `@coherent` memory decoration on storage buffers.
214        const MEMORY_DECORATION_COHERENT = 1 << 41;
215        /// Support for the `@volatile` memory decoration on storage buffers.
216        const MEMORY_DECORATION_VOLATILE = 1 << 42;
217    }
218}
219
220impl Capabilities {
221    /// Returns the extension corresponding to this capability, if there is one.
222    ///
223    /// This is used by integration tests.
224    #[cfg(feature = "wgsl-in")]
225    #[doc(hidden)]
226    pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
227        use crate::front::wgsl::ImplementedEnableExtension as Ext;
228        match *self {
229            Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
230            // NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
231            Self::SHADER_FLOAT16 => Some(Ext::F16),
232            Self::CLIP_DISTANCES => Some(Ext::ClipDistances),
233            Self::MESH_SHADER => Some(Ext::WgpuMeshShader),
234            Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
235            Self::RAY_HIT_VERTEX_POSITION => Some(Ext::WgpuRayQueryVertexReturn),
236            Self::COOPERATIVE_MATRIX => Some(Ext::WgpuCooperativeMatrix),
237            Self::RAY_TRACING_PIPELINE => Some(Ext::WgpuRayTracingPipeline),
238            _ => None,
239        }
240    }
241}
242
243impl Default for Capabilities {
244    fn default() -> Self {
245        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
246    }
247}
248
bitflags::bitflags! {
    /// A set of subgroup operation categories.
    ///
    /// A [`Validator`] records which categories it permits. Each subgroup
    /// operation and gather mode maps to one category through the
    /// `required_operations` functions in this module.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Barriers.
        // Possibly elections, when that is supported.
        // https://github.com/gfx-rs/wgpu/issues/6042#issuecomment-3272603431
        // Contrary to what the name "basic" suggests, HLSL/DX12 support the
        // other subgroup operations, but do not support subgroup barriers.
        const BASIC = 1 << 0;
        /// Votes: `Any` and `All`.
        const VOTE = 1 << 1;
        /// Reductions and scans: `Add`, `Mul`, `Min`, `Max`, `And`, `Or`, `Xor`.
        const ARITHMETIC = 1 << 2;
        /// Ballot and broadcast (including broadcast-first).
        const BALLOT = 1 << 3;
        /// Shuffle and shuffle-xor.
        const SHUFFLE = 1 << 4;
        /// Relative shuffles: shuffle up and shuffle down.
        const SHUFFLE_RELATIVE = 1 << 5;
        // Bit 6 is reserved for clustered operations, which we don't
        // support yet:
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        /// Quad operations (quad broadcast and quad swap). Going by the
        /// name, presumably limited to the fragment and compute stages.
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // Bit 8 is reserved for quad operations in all stages, which we
        // don't support yet either:
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
280
281impl super::SubgroupOperation {
282    const fn required_operations(&self) -> SubgroupOperationSet {
283        use SubgroupOperationSet as S;
284        match *self {
285            Self::All | Self::Any => S::VOTE,
286            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
287                S::ARITHMETIC
288            }
289        }
290    }
291}
292
293impl super::GatherMode {
294    const fn required_operations(&self) -> SubgroupOperationSet {
295        use SubgroupOperationSet as S;
296        match *self {
297            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
298            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
299            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
300            Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
301        }
302    }
303}
304
305bitflags::bitflags! {
306    /// Validation flags.
307    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
308    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
309    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
310    pub struct ShaderStages: u16 {
311        const VERTEX = 0x1;
312        const FRAGMENT = 0x2;
313        const COMPUTE = 0x4;
314        const MESH = 0x8;
315        const TASK = 0x10;
316        const RAY_GENERATION = 0x20;
317        const ANY_HIT = 0x40;
318        const CLOSEST_HIT = 0x80;
319        const MISS = 0x100;
320        const COMPUTE_LIKE = Self::COMPUTE.bits() | Self::TASK.bits() | Self::MESH.bits();
321    }
322}
323
/// Information gathered while validating a module.
///
/// Returned by [`Validator::validate`] on success. It can be indexed by a
/// type handle (type flags), a function handle (function info), or a
/// global-expression handle (resolved type).
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, indexed by the type handle's index.
    type_flags: Vec<TypeFlags>,
    /// Info for each function, indexed by the function handle's index.
    functions: Vec<FunctionInfo>,
    /// Info for each entry point; presumably in `Module::entry_points`
    /// order (TODO confirm).
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each expression in `Module::global_expressions`,
    /// indexed by the expression handle's index.
    const_expression_types: Box<[TypeResolution]>,
}
333
334impl ops::Index<Handle<crate::Type>> for ModuleInfo {
335    type Output = TypeFlags;
336    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
337        &self.type_flags[handle.index()]
338    }
339}
340
341impl ops::Index<Handle<crate::Function>> for ModuleInfo {
342    type Output = FunctionInfo;
343    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
344        &self.functions[handle.index()]
345    }
346}
347
348impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
349    type Output = TypeResolution;
350    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
351        &self.const_expression_types[handle.index()]
352    }
353}
354
355#[derive(Debug)]
356pub struct Validator {
357    flags: ValidationFlags,
358    capabilities: Capabilities,
359    subgroup_stages: ShaderStages,
360    subgroup_operations: SubgroupOperationSet,
361    types: Vec<r#type::TypeInfo>,
362    layouter: Layouter,
363    location_mask: BitSet,
364    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
365    switch_values: FastHashSet<crate::SwitchValue>,
366    valid_expression_list: Vec<Handle<crate::Expression>>,
367    valid_expression_set: HandleSet<crate::Expression>,
368    override_ids: FastHashSet<u16>,
369
370    /// Treat overrides whose initializers are not fully-evaluated
371    /// constant expressions as errors.
372    overrides_resolved: bool,
373
374    /// A checklist of expressions that must be visited by a specific kind of
375    /// statement.
376    ///
377    /// For example:
378    ///
379    /// - [`CallResult`] expressions must be visited by a [`Call`] statement.
380    /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
381    ///
382    /// Be sure not to remove any [`Expression`] handle from this set unless
383    /// you've explicitly checked that it is the right kind of expression for
384    /// the visiting [`Statement`].
385    ///
386    /// [`CallResult`]: crate::Expression::CallResult
387    /// [`Call`]: crate::Statement::Call
388    /// [`AtomicResult`]: crate::Expression::AtomicResult
389    /// [`Atomic`]: crate::Statement::Atomic
390    /// [`Expression`]: crate::Expression
391    /// [`Statement`]: crate::Statement
392    needs_visit: HandleSet<crate::Expression>,
393
394    /// Whether any trace rays call is called, and whether all have vertex return.
395    /// If one call doesn't use vertex ruturn, builtins for triangle vertex positions
396    /// (not yet implemented) are not allowed.
397    trace_rays_vertex_return: TraceRayVertexReturnState,
398
399    /// The type of the ray payload, this must always be the same type in a particular
400    /// entrypoint
401    trace_rays_payload_type: Option<Handle<crate::Type>>,
402}
403
/// Tracks whether the trace-rays calls found so far all use acceleration
/// structures with vertex return enabled.
///
/// This is the type of `Validator::trace_rays_vertex_return`.
#[derive(Debug)]
enum TraceRayVertexReturnState {
    /// No trace-rays calls have been found yet.
    NoTraceRays,
    /// Trace-rays calls have been found, and at least one of them uses
    /// an acceleration structure that does not have the flag enabling
    /// vertex return.
    ///
    /// NOTE(review): the span presumably locates that call; confirm.
    #[expect(
        unused,
        reason = "Don't yet have vertex return builtins to return this error for."
    )]
    NoVertexReturn(crate::Span),
    /// Trace-rays calls have been found, and all of their acceleration
    /// structures have the flag enabling vertex return.
    VertexReturn,
}
421
/// Reasons a module-scope [`Constant`](crate::Constant) can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The constant's initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's declared type is not constructible.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
432
/// Reasons an [`Override`](crate::Override) can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor an ID.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// An earlier override already uses this override's ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is neither a const-expression nor an override-expression.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's declared type is not constructible.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of the permitted scalars:
    /// `bool`, `i32`, `u32`, `f16`, `f32`, or `f64`.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Override declarations are not permitted in this context.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer, but overrides were required to be
    /// resolved (see `Validator::validate_resolved_overrides`).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// A constant expression associated with the override is invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
458
/// An error reported by [`Validator::validate`] for an invalid module.
///
/// Most variants identify the offending module item by handle (plus its
/// name, where it has one) and carry the more specific error as `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// The module contains an invalid handle.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing the layouts of the module's types failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// The type `handle` failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        /// The type's name, or an empty string if it is unnamed.
        name: String,
        source: TypeError,
    },
    /// The global expression `handle` failed validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// The array size expression `handle` is not strictly positive.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// The constant `handle` failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        /// The constant's name, or an empty string if it is unnamed.
        name: String,
        source: ConstantError,
    },
    /// The override `handle` failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// The global variable `handle` failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// The function `handle` failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// The entry point `name` for shader stage `stage` failed validation.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is corrupted. The conditions that produce this error
    /// are not visible in this file.
    #[error("Module is corrupted")]
    Corrupted,
}
512
513impl crate::TypeInner {
514    const fn is_sized(&self) -> bool {
515        match *self {
516            Self::Scalar { .. }
517            | Self::Vector { .. }
518            | Self::Matrix { .. }
519            | Self::CooperativeMatrix { .. }
520            | Self::Array {
521                size: crate::ArraySize::Constant(_),
522                ..
523            }
524            | Self::Atomic { .. }
525            | Self::Pointer { .. }
526            | Self::ValuePointer { .. }
527            | Self::Struct { .. } => true,
528            Self::Array { .. }
529            | Self::Image { .. }
530            | Self::Sampler { .. }
531            | Self::AccelerationStructure { .. }
532            | Self::RayQuery { .. }
533            | Self::BindingArray { .. } => false,
534        }
535    }
536
537    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
538    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
539        match *self {
540            Self::Scalar(crate::Scalar {
541                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
542                ..
543            }) => Some(crate::ImageDimension::D1),
544            Self::Vector {
545                size: crate::VectorSize::Bi,
546                scalar:
547                    crate::Scalar {
548                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
549                        ..
550                    },
551            } => Some(crate::ImageDimension::D2),
552            Self::Vector {
553                size: crate::VectorSize::Tri,
554                scalar:
555                    crate::Scalar {
556                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
557                        ..
558                    },
559            } => Some(crate::ImageDimension::D3),
560            _ => None,
561        }
562    }
563}
564
565impl Validator {
566    /// Create a validator for Naga [`Module`]s.
567    ///
568    /// The `flags` argument indicates which stages of validation the
569    /// returned `Validator` should perform. Skipping stages can make
570    /// validation somewhat faster, but the validator may not reject some
571    /// invalid modules. Regardless of `flags`, validation always returns
572    /// a usable [`ModuleInfo`] value on success.
573    ///
574    /// If `flags` contains everything in `ValidationFlags::default()`,
575    /// then the returned Naga [`Validator`] will reject any [`Module`]
576    /// that would use capabilities not included in `capabilities`.
577    ///
578    /// [`Module`]: crate::Module
579    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
580        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
581            use SubgroupOperationSet as S;
582            S::BASIC
583                | S::VOTE
584                | S::ARITHMETIC
585                | S::BALLOT
586                | S::SHUFFLE
587                | S::SHUFFLE_RELATIVE
588                | S::QUAD_FRAGMENT_COMPUTE
589        } else {
590            SubgroupOperationSet::empty()
591        };
592        let subgroup_stages = {
593            let mut stages = ShaderStages::empty();
594            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
595                stages |= ShaderStages::VERTEX;
596            }
597            if capabilities.contains(Capabilities::SUBGROUP) {
598                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE_LIKE;
599            }
600            stages
601        };
602
603        Validator {
604            flags,
605            capabilities,
606            subgroup_stages,
607            subgroup_operations,
608            types: Vec::new(),
609            layouter: Layouter::default(),
610            location_mask: BitSet::new(),
611            ep_resource_bindings: FastHashSet::default(),
612            switch_values: FastHashSet::default(),
613            valid_expression_list: Vec::new(),
614            valid_expression_set: HandleSet::new(),
615            override_ids: FastHashSet::default(),
616            overrides_resolved: false,
617            needs_visit: HandleSet::new(),
618            trace_rays_vertex_return: TraceRayVertexReturnState::NoTraceRays,
619            trace_rays_payload_type: None,
620        }
621    }
622
623    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
624    pub const fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
625        self.subgroup_stages = stages;
626        self
627    }
628
629    // TODO(https://github.com/gfx-rs/wgpu/issues/8207): Consider removing this
630    pub const fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
631        self.subgroup_operations = operations;
632        self
633    }
634
635    /// Reset the validator internals
636    pub fn reset(&mut self) {
637        self.types.clear();
638        self.layouter.clear();
639        self.location_mask.make_empty();
640        self.ep_resource_bindings.clear();
641        self.switch_values.clear();
642        self.valid_expression_list.clear();
643        self.valid_expression_set.clear();
644        self.override_ids.clear();
645    }
646
647    fn validate_constant(
648        &self,
649        handle: Handle<crate::Constant>,
650        gctx: crate::proc::GlobalCtx,
651        mod_info: &ModuleInfo,
652        global_expr_kind: &ExpressionKindTracker,
653    ) -> Result<(), ConstantError> {
654        let con = &gctx.constants[handle];
655
656        let type_info = &self.types[con.ty.index()];
657        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
658            return Err(ConstantError::NonConstructibleType);
659        }
660
661        if !global_expr_kind.is_const(con.init) {
662            return Err(ConstantError::InitializerExprType);
663        }
664
665        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
666            return Err(ConstantError::InvalidType);
667        }
668
669        Ok(())
670    }
671
672    fn validate_override(
673        &mut self,
674        handle: Handle<crate::Override>,
675        gctx: crate::proc::GlobalCtx,
676        mod_info: &ModuleInfo,
677    ) -> Result<(), OverrideError> {
678        let o = &gctx.overrides[handle];
679
680        if let Some(id) = o.id {
681            if !self.override_ids.insert(id) {
682                return Err(OverrideError::DuplicateID);
683            }
684        }
685
686        let type_info = &self.types[o.ty.index()];
687        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
688            return Err(OverrideError::NonConstructibleType);
689        }
690
691        match gctx.types[o.ty].inner {
692            crate::TypeInner::Scalar(
693                crate::Scalar::BOOL
694                | crate::Scalar::I32
695                | crate::Scalar::U32
696                | crate::Scalar::F16
697                | crate::Scalar::F32
698                | crate::Scalar::F64,
699            ) => {}
700            _ => return Err(OverrideError::TypeNotScalar),
701        }
702
703        if let Some(init) = o.init {
704            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
705                return Err(OverrideError::InvalidType);
706            }
707        } else if self.overrides_resolved {
708            return Err(OverrideError::UninitializedOverride);
709        }
710
711        Ok(())
712    }
713
714    /// Check the given module to be valid.
715    pub fn validate(
716        &mut self,
717        module: &crate::Module,
718    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
719        self.overrides_resolved = false;
720        self.validate_impl(module)
721    }
722
723    /// Check the given module to be valid, requiring overrides to be resolved.
724    ///
725    /// This is the same as [`validate`], except that any override
726    /// whose value is not a fully-evaluated constant expression is
727    /// treated as an error.
728    ///
729    /// [`validate`]: Validator::validate
730    pub fn validate_resolved_overrides(
731        &mut self,
732        module: &crate::Module,
733    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
734        self.overrides_resolved = true;
735        self.validate_impl(module)
736    }
737
738    fn validate_impl(
739        &mut self,
740        module: &crate::Module,
741    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
742        self.reset();
743        self.reset_types(module.types.len());
744
745        Self::validate_module_handles(module).map_err(|e| e.with_span())?;
746
747        self.layouter.update(module.to_ctx()).map_err(|e| {
748            let handle = e.ty;
749            ValidationError::from(e).with_span_handle(handle, &module.types)
750        })?;
751
752        // These should all get overwritten.
753        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
754            kind: crate::ScalarKind::Bool,
755            width: 0,
756        }));
757
758        let mut mod_info = ModuleInfo {
759            type_flags: Vec::with_capacity(module.types.len()),
760            functions: Vec::with_capacity(module.functions.len()),
761            entry_points: Vec::with_capacity(module.entry_points.len()),
762            const_expression_types: vec![placeholder; module.global_expressions.len()]
763                .into_boxed_slice(),
764        };
765
766        for (handle, ty) in module.types.iter() {
767            let ty_info = self
768                .validate_type(handle, module.to_ctx())
769                .map_err(|source| {
770                    ValidationError::Type {
771                        handle,
772                        name: ty.name.clone().unwrap_or_default(),
773                        source,
774                    }
775                    .with_span_handle(handle, &module.types)
776                })?;
777            debug_assert!(
778                ty_info.flags.contains(TypeFlags::CONSTRUCTIBLE)
779                    == module.types[handle].inner.is_constructible(&module.types)
780            );
781            mod_info.type_flags.push(ty_info.flags);
782            self.types[handle.index()] = ty_info;
783        }
784
785        {
786            let t = crate::Arena::new();
787            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
788            for (handle, _) in module.global_expressions.iter() {
789                mod_info
790                    .process_const_expression(handle, &resolve_context, module.to_ctx())
791                    .map_err(|source| {
792                        ValidationError::ConstExpression { handle, source }
793                            .with_span_handle(handle, &module.global_expressions)
794                    })?
795            }
796        }
797
798        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
799
800        if self.flags.contains(ValidationFlags::CONSTANTS) {
801            for (handle, _) in module.global_expressions.iter() {
802                self.validate_const_expression(
803                    handle,
804                    module.to_ctx(),
805                    &mod_info,
806                    &global_expr_kind,
807                )
808                .map_err(|source| {
809                    ValidationError::ConstExpression { handle, source }
810                        .with_span_handle(handle, &module.global_expressions)
811                })?
812            }
813
814            for (handle, constant) in module.constants.iter() {
815                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
816                    .map_err(|source| {
817                        ValidationError::Constant {
818                            handle,
819                            name: constant.name.clone().unwrap_or_default(),
820                            source,
821                        }
822                        .with_span_handle(handle, &module.constants)
823                    })?
824            }
825
826            for (handle, r#override) in module.overrides.iter() {
827                self.validate_override(handle, module.to_ctx(), &mod_info)
828                    .map_err(|source| {
829                        ValidationError::Override {
830                            handle,
831                            name: r#override.name.clone().unwrap_or_default(),
832                            source,
833                        }
834                        .with_span_handle(handle, &module.overrides)
835                    })?;
836            }
837        }
838
839        for (var_handle, var) in module.global_variables.iter() {
840            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
841                .map_err(|source| {
842                    ValidationError::GlobalVariable {
843                        handle: var_handle,
844                        name: var.name.clone().unwrap_or_default(),
845                        source,
846                    }
847                    .with_span_handle(var_handle, &module.global_variables)
848                })?;
849        }
850
851        for (handle, fun) in module.functions.iter() {
852            match self.validate_function(fun, module, &mod_info, false) {
853                Ok(info) => mod_info.functions.push(info),
854                Err(error) => {
855                    return Err(error.and_then(|source| {
856                        ValidationError::Function {
857                            handle,
858                            name: fun.name.clone().unwrap_or_default(),
859                            source,
860                        }
861                        .with_span_handle(handle, &module.functions)
862                    }))
863                }
864            }
865        }
866
867        let mut ep_map = FastHashSet::default();
868        for ep in module.entry_points.iter() {
869            if !ep_map.insert((ep.stage, &ep.name)) {
870                return Err(ValidationError::EntryPoint {
871                    stage: ep.stage,
872                    name: ep.name.clone(),
873                    source: EntryPointError::Conflict,
874                }
875                .with_span()); // TODO: keep some EP span information?
876            }
877
878            match self.validate_entry_point(ep, module, &mod_info) {
879                Ok(info) => mod_info.entry_points.push(info),
880                Err(error) => {
881                    return Err(error.and_then(|source| {
882                        ValidationError::EntryPoint {
883                            stage: ep.stage,
884                            name: ep.name.clone(),
885                            source,
886                        }
887                        .with_span()
888                    }));
889                }
890            }
891        }
892
893        Ok(mod_info)
894    }
895}
896
897fn validate_atomic_compare_exchange_struct(
898    types: &crate::UniqueArena<crate::Type>,
899    members: &[crate::StructMember],
900    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
901) -> bool {
902    members.len() == 2
903        && members[0].name.as_deref() == Some("old_value")
904        && scalar_predicate(&types[members[0].ty].inner)
905        && members[1].name.as_deref() == Some("exchanged")
906        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
907}