naga/valid/
interface.rs

1use alloc::vec::Vec;
2
3use bit_set::BitSet;
4
5use super::{
6    analyzer::{FunctionInfo, GlobalUse},
7    Capabilities, Disalignment, FunctionError, ModuleInfo, PushConstantError,
8};
9use crate::arena::{Handle, UniqueArena};
10use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
11
/// Workgroup-size limit, `2^14` (= 16384 = `0x4000`).
///
/// NOTE(review): not referenced in this part of the file; presumably it bounds
/// the workgroup size / invocation count checked for compute-like entry
/// points — confirm at the use site.
const MAX_WORKGROUP_SIZE: u32 = 1 << 14;
13
14#[derive(Clone, Debug, thiserror::Error)]
15#[cfg_attr(test, derive(PartialEq))]
16pub enum GlobalVariableError {
17    #[error("Usage isn't compatible with address space {0:?}")]
18    InvalidUsage(crate::AddressSpace),
19    #[error("Type isn't compatible with address space {0:?}")]
20    InvalidType(crate::AddressSpace),
21    #[error("Type flags {seen:?} do not meet the required {required:?}")]
22    MissingTypeFlags {
23        required: super::TypeFlags,
24        seen: super::TypeFlags,
25    },
26    #[error("Capability {0:?} is not supported")]
27    UnsupportedCapability(Capabilities),
28    #[error("Binding decoration is missing or not applicable")]
29    InvalidBinding,
30    #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
31    Alignment(
32        crate::AddressSpace,
33        Handle<crate::Type>,
34        #[source] Disalignment,
35    ),
36    #[error("Initializer must be an override-expression")]
37    InitializerExprType,
38    #[error("Initializer doesn't match the variable type")]
39    InitializerType,
40    #[error("Initializer can't be used with address space {0:?}")]
41    InitializerNotAllowed(crate::AddressSpace),
42    #[error("Storage address space doesn't support write-only access")]
43    StorageAddressSpaceWriteOnlyNotSupported,
44    #[error("Type is not valid for use as a push constant")]
45    InvalidPushConstantType(#[source] PushConstantError),
46    #[error("Task payload must not be zero-sized")]
47    ZeroSizedTaskPayload,
48}
49
50#[derive(Clone, Debug, thiserror::Error)]
51#[cfg_attr(test, derive(PartialEq))]
52pub enum VaryingError {
53    #[error("The type {0:?} does not match the varying")]
54    InvalidType(Handle<crate::Type>),
55    #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
56    NotIOShareableType(Handle<crate::Type>),
57    #[error("Interpolation is not valid")]
58    InvalidInterpolation,
59    #[error("Cannot combine {interpolation:?} interpolation with the {sampling:?} sample type")]
60    InvalidInterpolationSamplingCombination {
61        interpolation: crate::Interpolation,
62        sampling: crate::Sampling,
63    },
64    #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
65    MissingInterpolation,
66    #[error("Built-in {0:?} is not available at this stage")]
67    InvalidBuiltInStage(crate::BuiltIn),
68    #[error("Built-in type for {0:?} is invalid")]
69    InvalidBuiltInType(crate::BuiltIn),
70    #[error("Entry point arguments and return values must all have bindings")]
71    MissingBinding,
72    #[error("Struct member {0} is missing a binding")]
73    MemberMissingBinding(u32),
74    #[error("Multiple bindings at location {location} are present")]
75    BindingCollision { location: u32 },
76    #[error("Multiple bindings use the same `blend_src` {blend_src}")]
77    BindingCollisionBlendSrc { blend_src: u32 },
78    #[error("Built-in {0:?} is present more than once")]
79    DuplicateBuiltIn(crate::BuiltIn),
80    #[error("Capability {0:?} is not supported")]
81    UnsupportedCapability(Capabilities),
82    #[error("The attribute {0:?} is only valid as an output for stage {1:?}")]
83    InvalidInputAttributeInStage(&'static str, crate::ShaderStage),
84    #[error("The attribute {0:?} is not valid for stage {1:?}")]
85    InvalidAttributeInStage(&'static str, crate::ShaderStage),
86    #[error("The `blend_src` attribute can only be used on location 0, only indices 0 and 1 are valid. Location was {location}, index was {blend_src}.")]
87    InvalidBlendSrcIndex { location: u32, blend_src: u32 },
88    #[error("If `blend_src` is used, there must be exactly two outputs both with location 0, one with `blend_src(0)` and the other with `blend_src(1)`.")]
89    IncompleteBlendSrcUsage,
90    #[error("If `blend_src` is used, both outputs must have the same type. `blend_src(0)` has type {blend_src_0_type:?} and `blend_src(1)` has type {blend_src_1_type:?}.")]
91    BlendSrcOutputTypeMismatch {
92        blend_src_0_type: Handle<crate::Type>,
93        blend_src_1_type: Handle<crate::Type>,
94    },
95    #[error("Workgroup size is multi dimensional, `@builtin(subgroup_id)` and `@builtin(subgroup_invocation_id)` are not supported.")]
96    InvalidMultiDimensionalSubgroupBuiltIn,
97    #[error("The `@per_primitive` attribute can only be used in fragment shader inputs or mesh shader primitive outputs")]
98    InvalidPerPrimitive,
99    #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")]
100    MissingPerPrimitive,
101}
102
103#[derive(Clone, Debug, thiserror::Error)]
104#[cfg_attr(test, derive(PartialEq))]
105pub enum EntryPointError {
106    #[error("Multiple conflicting entry points")]
107    Conflict,
108    #[error("Vertex shaders must return a `@builtin(position)` output value")]
109    MissingVertexOutputPosition,
110    #[error("Early depth test is not applicable")]
111    UnexpectedEarlyDepthTest,
112    #[error("Workgroup size is not applicable")]
113    UnexpectedWorkgroupSize,
114    #[error("Workgroup size is out of range")]
115    OutOfRangeWorkgroupSize,
116    #[error("Uses operations forbidden at this stage")]
117    ForbiddenStageOperations,
118    #[error("Global variable {0:?} is used incorrectly as {1:?}")]
119    InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
120    #[error("More than 1 push constant variable is used")]
121    MoreThanOnePushConstantUsed,
122    #[error("Bindings for {0:?} conflict with other resource")]
123    BindingCollision(Handle<crate::GlobalVariable>),
124    #[error("Argument {0} varying error")]
125    Argument(u32, #[source] VaryingError),
126    #[error(transparent)]
127    Result(#[from] VaryingError),
128    #[error("Location {location} interpolation of an integer has to be flat")]
129    InvalidIntegerInterpolation { location: u32 },
130    #[error(transparent)]
131    Function(#[from] FunctionError),
132    #[error("Capability {0:?} is not supported")]
133    UnsupportedCapability(Capabilities),
134
135    #[error("mesh shader entry point missing mesh shader attributes")]
136    ExpectedMeshShaderAttributes,
137    #[error("Non mesh shader entry point cannot have mesh shader attributes")]
138    UnexpectedMeshShaderAttributes,
139    #[error("Non mesh/task shader entry point cannot have task payload attribute")]
140    UnexpectedTaskPayload,
141    #[error("Task payload must be declared with `var<task_payload>`")]
142    TaskPayloadWrongAddressSpace,
143    #[error("For a task payload to be used, it must be declared with @payload")]
144    WrongTaskPayloadUsed,
145    #[error("Task shader entry point must return @builtin(mesh_task_size) vec3<u32>")]
146    WrongTaskShaderEntryResult,
147    #[error("Task shaders must declare a task payload output")]
148    ExpectedTaskPayload,
149    #[error(
150        "Mesh shader output variable must be a struct with fields that are all allowed builtins"
151    )]
152    BadMeshOutputVariableType,
153    #[error("Mesh shader output variable fields must have types that are in accordance with the mesh shader spec")]
154    BadMeshOutputVariableField,
155    #[error("Mesh shader entry point cannot have a return type")]
156    UnexpectedMeshShaderEntryResult,
157    #[error(
158        "Mesh output type must be a user-defined struct with fields in alignment with the mesh shader spec"
159    )]
160    InvalidMeshOutputType,
161    #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")]
162    InvalidMeshPrimitiveOutputType,
163    #[error("Mesh output global variable must live in the workgroup address space")]
164    WrongMeshOutputAddressSpace,
165    #[error("Task payload must be at least 4 bytes, but is {0} bytes")]
166    TaskPayloadTooSmall(u32),
167}
168
169fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
170    let mut storage_usage = GlobalUse::QUERY;
171    if access.contains(crate::StorageAccess::LOAD) {
172        storage_usage |= GlobalUse::READ;
173    }
174    if access.contains(crate::StorageAccess::STORE) {
175        storage_usage |= GlobalUse::WRITE;
176    }
177    if access.contains(crate::StorageAccess::ATOMIC) {
178        storage_usage |= GlobalUse::ATOMIC;
179    }
180    storage_usage
181}
182
/// Which mesh-shader output struct, if any, is being validated.
///
/// `VaryingContext::validate_impl` only ever compares this against
/// `PrimitiveOutput`. That value:
/// - makes the primitive built-ins (`cull_primitive`, `point_index`,
///   `line_indices`, `triangle_indices`, and `primitive_index` as a mesh
///   output) visible;
/// - requires every `@location` member to be `@per_primitive`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum MeshOutputType {
    /// Presumably: not validating a mesh output struct (confirm at construction sites).
    None,
    /// Presumably: a mesh shader's per-vertex output struct.
    VertexOutput,
    /// A mesh shader's per-primitive output struct.
    PrimitiveOutput,
}
189
190struct VaryingContext<'a> {
191    stage: crate::ShaderStage,
192    output: bool,
193    types: &'a UniqueArena<crate::Type>,
194    type_info: &'a Vec<super::r#type::TypeInfo>,
195    location_mask: &'a mut BitSet,
196    blend_src_mask: &'a mut BitSet,
197    built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
198    capabilities: Capabilities,
199    flags: super::ValidationFlags,
200    mesh_output_type: MeshOutputType,
201    has_task_payload: bool,
202}
203
204impl VaryingContext<'_> {
205    fn validate_impl(
206        &mut self,
207        ep: &crate::EntryPoint,
208        ty: Handle<crate::Type>,
209        binding: &crate::Binding,
210    ) -> Result<(), VaryingError> {
211        use crate::{BuiltIn as Bi, ShaderStage as St, TypeInner as Ti, VectorSize as Vs};
212
213        let ty_inner = &self.types[ty].inner;
214        match *binding {
215            crate::Binding::BuiltIn(built_in) => {
216                // Ignore the `invariant` field for the sake of duplicate checks,
217                // but use the original in error messages.
218                let canonical = if let crate::BuiltIn::Position { .. } = built_in {
219                    crate::BuiltIn::Position { invariant: false }
220                } else {
221                    built_in
222                };
223
224                if self.built_ins.contains(&canonical) {
225                    return Err(VaryingError::DuplicateBuiltIn(built_in));
226                }
227                self.built_ins.insert(canonical);
228
229                let required = match built_in {
230                    Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
231                    Bi::CullDistance => Capabilities::CULL_DISTANCE,
232                    Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
233                    Bi::Barycentric => Capabilities::SHADER_BARYCENTRICS,
234                    Bi::ViewIndex => Capabilities::MULTIVIEW,
235                    Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
236                    Bi::NumSubgroups
237                    | Bi::SubgroupId
238                    | Bi::SubgroupSize
239                    | Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
240                    _ => Capabilities::empty(),
241                };
242                if !self.capabilities.contains(required) {
243                    return Err(VaryingError::UnsupportedCapability(required));
244                }
245
246                if matches!(
247                    built_in,
248                    crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
249                ) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
250                {
251                    return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
252                }
253
254                let (visible, type_good) = match built_in {
255                    Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
256                        self.stage == St::Vertex && !self.output,
257                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
258                    ),
259                    Bi::DrawID => (
260                        // Always allowed in task/vertex stage. Allowed in mesh stage if there is no task stage in the pipeline.
261                        (self.stage == St::Vertex
262                            || self.stage == St::Task
263                            || (self.stage == St::Mesh && !self.has_task_payload))
264                            && !self.output,
265                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
266                    ),
267                    Bi::ClipDistance | Bi::CullDistance => (
268                        (self.stage == St::Vertex || self.stage == St::Mesh) && self.output,
269                        match *ty_inner {
270                            Ti::Array { base, size, .. } => {
271                                self.types[base].inner == Ti::Scalar(crate::Scalar::F32)
272                                    && match size {
273                                        crate::ArraySize::Constant(non_zero) => non_zero.get() <= 8,
274                                        _ => false,
275                                    }
276                            }
277                            _ => false,
278                        },
279                    ),
280                    Bi::PointSize => (
281                        (self.stage == St::Vertex || self.stage == St::Mesh) && self.output,
282                        *ty_inner == Ti::Scalar(crate::Scalar::F32),
283                    ),
284                    Bi::PointCoord => (
285                        self.stage == St::Fragment && !self.output,
286                        *ty_inner
287                            == Ti::Vector {
288                                size: Vs::Bi,
289                                scalar: crate::Scalar::F32,
290                            },
291                    ),
292                    Bi::Position { .. } => (
293                        match self.stage {
294                            St::Vertex | St::Mesh => self.output,
295                            St::Fragment => !self.output,
296                            St::Compute | St::Task => false,
297                        },
298                        *ty_inner
299                            == Ti::Vector {
300                                size: Vs::Quad,
301                                scalar: crate::Scalar::F32,
302                            },
303                    ),
304                    Bi::ViewIndex => (
305                        match self.stage {
306                            St::Vertex | St::Fragment | St::Task | St::Mesh => !self.output,
307                            St::Compute => false,
308                        },
309                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
310                    ),
311                    Bi::FragDepth => (
312                        self.stage == St::Fragment && self.output,
313                        *ty_inner == Ti::Scalar(crate::Scalar::F32),
314                    ),
315                    Bi::FrontFacing => (
316                        self.stage == St::Fragment && !self.output,
317                        *ty_inner == Ti::Scalar(crate::Scalar::BOOL),
318                    ),
319                    Bi::PrimitiveIndex => (
320                        (self.stage == St::Fragment && !self.output)
321                            || (self.stage == St::Mesh
322                                && self.output
323                                && self.mesh_output_type == MeshOutputType::PrimitiveOutput),
324                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
325                    ),
326                    Bi::Barycentric => (
327                        self.stage == St::Fragment && !self.output,
328                        *ty_inner
329                            == Ti::Vector {
330                                size: Vs::Tri,
331                                scalar: crate::Scalar::F32,
332                            },
333                    ),
334                    Bi::SampleIndex => (
335                        self.stage == St::Fragment && !self.output,
336                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
337                    ),
338                    Bi::SampleMask => (
339                        self.stage == St::Fragment,
340                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
341                    ),
342                    Bi::LocalInvocationIndex => (
343                        self.stage.compute_like() && !self.output,
344                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
345                    ),
346                    Bi::GlobalInvocationId
347                    | Bi::LocalInvocationId
348                    | Bi::WorkGroupId
349                    | Bi::WorkGroupSize
350                    | Bi::NumWorkGroups => (
351                        self.stage.compute_like() && !self.output,
352                        *ty_inner
353                            == Ti::Vector {
354                                size: Vs::Tri,
355                                scalar: crate::Scalar::U32,
356                            },
357                    ),
358                    Bi::NumSubgroups | Bi::SubgroupId => (
359                        self.stage.compute_like() && !self.output,
360                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
361                    ),
362                    Bi::SubgroupSize | Bi::SubgroupInvocationId => (
363                        match self.stage {
364                            St::Compute | St::Fragment | St::Task | St::Mesh => !self.output,
365                            St::Vertex => false,
366                        },
367                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
368                    ),
369                    Bi::CullPrimitive => (
370                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
371                        *ty_inner == Ti::Scalar(crate::Scalar::BOOL),
372                    ),
373                    Bi::PointIndex => (
374                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
375                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
376                    ),
377                    Bi::LineIndices => (
378                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
379                        *ty_inner
380                            == Ti::Vector {
381                                size: Vs::Bi,
382                                scalar: crate::Scalar::U32,
383                            },
384                    ),
385                    Bi::TriangleIndices => (
386                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
387                        *ty_inner
388                            == Ti::Vector {
389                                size: Vs::Tri,
390                                scalar: crate::Scalar::U32,
391                            },
392                    ),
393                    Bi::MeshTaskSize => (
394                        self.stage == St::Task && self.output,
395                        *ty_inner
396                            == Ti::Vector {
397                                size: Vs::Tri,
398                                scalar: crate::Scalar::U32,
399                            },
400                    ),
401                    // Validated elsewhere, shouldn't be here
402                    Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices | Bi::Primitives => {
403                        (false, true)
404                    }
405                };
406                match built_in {
407                    Bi::CullPrimitive
408                    | Bi::PointIndex
409                    | Bi::LineIndices
410                    | Bi::TriangleIndices
411                    | Bi::MeshTaskSize
412                    | Bi::VertexCount
413                    | Bi::PrimitiveCount
414                    | Bi::Vertices
415                    | Bi::Primitives => {
416                        if !self.capabilities.contains(Capabilities::MESH_SHADER) {
417                            return Err(VaryingError::UnsupportedCapability(
418                                Capabilities::MESH_SHADER,
419                            ));
420                        }
421                    }
422                    _ => (),
423                }
424
425                if !visible {
426                    return Err(VaryingError::InvalidBuiltInStage(built_in));
427                }
428                if !type_good {
429                    log::warn!("Wrong builtin type: {ty_inner:?}");
430                    return Err(VaryingError::InvalidBuiltInType(built_in));
431                }
432            }
433            crate::Binding::Location {
434                location,
435                interpolation,
436                sampling,
437                blend_src,
438                per_primitive,
439            } => {
440                if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) {
441                    return Err(VaryingError::UnsupportedCapability(
442                        Capabilities::MESH_SHADER,
443                    ));
444                }
445                // Only IO-shareable types may be stored in locations.
446                if !self.type_info[ty.index()]
447                    .flags
448                    .contains(super::TypeFlags::IO_SHAREABLE)
449                {
450                    return Err(VaryingError::NotIOShareableType(ty));
451                }
452
453                // Check whether `per_primitive` is appropriate for this stage and direction.
454                if self.mesh_output_type == MeshOutputType::PrimitiveOutput {
455                    // All mesh shader `Location` outputs must be `per_primitive`.
456                    if !per_primitive {
457                        return Err(VaryingError::MissingPerPrimitive);
458                    }
459                } else if self.stage == crate::ShaderStage::Fragment && !self.output {
460                    // Fragment stage inputs may be `per_primitive`. We'll only
461                    // know if these are correct when the whole mesh pipeline is
462                    // created and we're paired with a specific mesh or vertex
463                    // shader.
464                } else if per_primitive {
465                    // All other `Location` bindings must not be `per_primitive`.
466                    return Err(VaryingError::InvalidPerPrimitive);
467                }
468
469                if let Some(blend_src) = blend_src {
470                    // `blend_src` is only valid if dual source blending was explicitly enabled,
471                    // see https://www.w3.org/TR/WGSL/#extension-dual_source_blending
472                    if !self
473                        .capabilities
474                        .contains(Capabilities::DUAL_SOURCE_BLENDING)
475                    {
476                        return Err(VaryingError::UnsupportedCapability(
477                            Capabilities::DUAL_SOURCE_BLENDING,
478                        ));
479                    }
480                    if self.stage != crate::ShaderStage::Fragment {
481                        return Err(VaryingError::InvalidAttributeInStage(
482                            "blend_src",
483                            self.stage,
484                        ));
485                    }
486                    if !self.output {
487                        return Err(VaryingError::InvalidInputAttributeInStage(
488                            "blend_src",
489                            self.stage,
490                        ));
491                    }
492                    if (blend_src != 0 && blend_src != 1) || location != 0 {
493                        return Err(VaryingError::InvalidBlendSrcIndex {
494                            location,
495                            blend_src,
496                        });
497                    }
498                    if !self.blend_src_mask.insert(blend_src as usize) {
499                        return Err(VaryingError::BindingCollisionBlendSrc { blend_src });
500                    }
501                } else if !self.location_mask.insert(location as usize)
502                    && self.flags.contains(super::ValidationFlags::BINDINGS)
503                {
504                    return Err(VaryingError::BindingCollision { location });
505                }
506
507                if let Some(interpolation) = interpolation {
508                    let invalid_sampling = match (interpolation, sampling) {
509                        (_, None)
510                        | (
511                            crate::Interpolation::Perspective | crate::Interpolation::Linear,
512                            Some(
513                                crate::Sampling::Center
514                                | crate::Sampling::Centroid
515                                | crate::Sampling::Sample,
516                            ),
517                        )
518                        | (
519                            crate::Interpolation::Flat,
520                            Some(crate::Sampling::First | crate::Sampling::Either),
521                        ) => None,
522                        (_, Some(invalid_sampling)) => Some(invalid_sampling),
523                    };
524                    if let Some(sampling) = invalid_sampling {
525                        return Err(VaryingError::InvalidInterpolationSamplingCombination {
526                            interpolation,
527                            sampling,
528                        });
529                    }
530                }
531
532                let needs_interpolation = match self.stage {
533                    crate::ShaderStage::Vertex => self.output,
534                    crate::ShaderStage::Fragment => !self.output && !per_primitive,
535                    crate::ShaderStage::Compute | crate::ShaderStage::Task => false,
536                    crate::ShaderStage::Mesh => self.output,
537                };
538
539                // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
540                // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
541                // qualifiers, so we won't complain about that here.
542                let _ = sampling;
543
544                let required = match sampling {
545                    Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
546                    _ => Capabilities::empty(),
547                };
548                if !self.capabilities.contains(required) {
549                    return Err(VaryingError::UnsupportedCapability(required));
550                }
551
552                match ty_inner.scalar_kind() {
553                    Some(crate::ScalarKind::Float) => {
554                        if needs_interpolation && interpolation.is_none() {
555                            return Err(VaryingError::MissingInterpolation);
556                        }
557                    }
558                    Some(_) => {
559                        if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
560                        {
561                            return Err(VaryingError::InvalidInterpolation);
562                        }
563                    }
564                    None => return Err(VaryingError::InvalidType(ty)),
565                }
566            }
567        }
568
569        Ok(())
570    }
571
572    fn validate(
573        &mut self,
574        ep: &crate::EntryPoint,
575        ty: Handle<crate::Type>,
576        binding: Option<&crate::Binding>,
577    ) -> Result<(), WithSpan<VaryingError>> {
578        let span_context = self.types.get_span_context(ty);
579        match binding {
580            Some(binding) => self
581                .validate_impl(ep, ty, binding)
582                .map_err(|e| e.with_span_context(span_context)),
583            None => {
584                let crate::TypeInner::Struct { ref members, .. } = self.types[ty].inner else {
585                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
586                        return Err(VaryingError::MissingBinding.with_span());
587                    } else {
588                        return Ok(());
589                    }
590                };
591
592                for (index, member) in members.iter().enumerate() {
593                    let span_context = self.types.get_span_context(ty);
594                    match member.binding {
595                        None => {
596                            if self.flags.contains(super::ValidationFlags::BINDINGS) {
597                                return Err(VaryingError::MemberMissingBinding(index as u32)
598                                    .with_span_context(span_context));
599                            }
600                        }
601                        Some(ref binding) => self
602                            .validate_impl(ep, member.ty, binding)
603                            .map_err(|e| e.with_span_context(span_context))?,
604                    }
605                }
606
607                if !self.blend_src_mask.is_empty() {
608                    let span_context = self.types.get_span_context(ty);
609
610                    // If there's any blend_src usage, it must apply to all members of which there must be exactly 2.
611                    if members.len() != 2 || self.blend_src_mask.len() != 2 {
612                        return Err(
613                            VaryingError::IncompleteBlendSrcUsage.with_span_context(span_context)
614                        );
615                    }
616                    // Also, all members must have the same type.
617                    if members[0].ty != members[1].ty {
618                        return Err(VaryingError::BlendSrcOutputTypeMismatch {
619                            blend_src_0_type: members[0].ty,
620                            blend_src_1_type: members[1].ty,
621                        }
622                        .with_span_context(span_context));
623                    }
624                }
625                Ok(())
626            }
627        }
628    }
629}
630
631impl super::Validator {
632    pub(super) fn validate_global_var(
633        &self,
634        var: &crate::GlobalVariable,
635        gctx: crate::proc::GlobalCtx,
636        mod_info: &ModuleInfo,
637        global_expr_kind: &crate::proc::ExpressionKindTracker,
638    ) -> Result<(), GlobalVariableError> {
639        use super::TypeFlags;
640
641        log::debug!("var {var:?}");
642        let inner_ty = match gctx.types[var.ty].inner {
643            // A binding array is (mostly) supposed to behave the same as a
644            // series of individually bound resources, so we can (mostly)
645            // validate a `binding_array<T>` as if it were just a plain `T`.
646            crate::TypeInner::BindingArray { base, .. } => match var.space {
647                crate::AddressSpace::Storage { .. }
648                | crate::AddressSpace::Uniform
649                | crate::AddressSpace::Handle => base,
650                _ => return Err(GlobalVariableError::InvalidUsage(var.space)),
651            },
652            _ => var.ty,
653        };
654        let type_info = &self.types[inner_ty.index()];
655
656        let (required_type_flags, is_resource) = match var.space {
657            crate::AddressSpace::Function => {
658                return Err(GlobalVariableError::InvalidUsage(var.space))
659            }
660            crate::AddressSpace::Storage { access } => {
661                if let Err((ty_handle, disalignment)) = type_info.storage_layout {
662                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
663                        return Err(GlobalVariableError::Alignment(
664                            var.space,
665                            ty_handle,
666                            disalignment,
667                        ));
668                    }
669                }
670                if access == crate::StorageAccess::STORE {
671                    return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported);
672                }
673                (
674                    TypeFlags::DATA | TypeFlags::HOST_SHAREABLE | TypeFlags::CREATION_RESOLVED,
675                    true,
676                )
677            }
678            crate::AddressSpace::Uniform => {
679                if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
680                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
681                        return Err(GlobalVariableError::Alignment(
682                            var.space,
683                            ty_handle,
684                            disalignment,
685                        ));
686                    }
687                }
688                (
689                    TypeFlags::DATA
690                        | TypeFlags::COPY
691                        | TypeFlags::SIZED
692                        | TypeFlags::HOST_SHAREABLE
693                        | TypeFlags::CREATION_RESOLVED,
694                    true,
695                )
696            }
697            crate::AddressSpace::Handle => {
698                match gctx.types[inner_ty].inner {
699                    crate::TypeInner::Image { class, .. } => match class {
700                        crate::ImageClass::Storage {
701                            format:
702                                crate::StorageFormat::R16Unorm
703                                | crate::StorageFormat::R16Snorm
704                                | crate::StorageFormat::Rg16Unorm
705                                | crate::StorageFormat::Rg16Snorm
706                                | crate::StorageFormat::Rgba16Unorm
707                                | crate::StorageFormat::Rgba16Snorm,
708                            ..
709                        } => {
710                            if !self
711                                .capabilities
712                                .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
713                            {
714                                return Err(GlobalVariableError::UnsupportedCapability(
715                                    Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
716                                ));
717                            }
718                        }
719                        _ => {}
720                    },
721                    crate::TypeInner::Sampler { .. }
722                    | crate::TypeInner::AccelerationStructure { .. }
723                    | crate::TypeInner::RayQuery { .. } => {}
724                    _ => {
725                        return Err(GlobalVariableError::InvalidType(var.space));
726                    }
727                }
728
729                (TypeFlags::empty(), true)
730            }
731            crate::AddressSpace::Private => (
732                TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED,
733                false,
734            ),
735            crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false),
736            crate::AddressSpace::TaskPayload => {
737                if !self.capabilities.contains(Capabilities::MESH_SHADER) {
738                    return Err(GlobalVariableError::UnsupportedCapability(
739                        Capabilities::MESH_SHADER,
740                    ));
741                }
742                (TypeFlags::DATA | TypeFlags::SIZED, false)
743            }
744            crate::AddressSpace::PushConstant => {
745                if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
746                    return Err(GlobalVariableError::UnsupportedCapability(
747                        Capabilities::PUSH_CONSTANT,
748                    ));
749                }
750                if let Err(ref err) = type_info.push_constant_compatibility {
751                    return Err(GlobalVariableError::InvalidPushConstantType(err.clone()));
752                }
753                (
754                    TypeFlags::DATA
755                        | TypeFlags::COPY
756                        | TypeFlags::HOST_SHAREABLE
757                        | TypeFlags::SIZED,
758                    false,
759                )
760            }
761        };
762
763        if !type_info.flags.contains(required_type_flags) {
764            return Err(GlobalVariableError::MissingTypeFlags {
765                seen: type_info.flags,
766                required: required_type_flags,
767            });
768        }
769
770        if is_resource != var.binding.is_some() {
771            if self.flags.contains(super::ValidationFlags::BINDINGS) {
772                return Err(GlobalVariableError::InvalidBinding);
773            }
774        }
775
776        if var.space == crate::AddressSpace::TaskPayload {
777            let ty = &gctx.types[var.ty].inner;
778            // HLSL doesn't allow zero sized payloads.
779            if ty.try_size(gctx) == Some(0) {
780                return Err(GlobalVariableError::ZeroSizedTaskPayload);
781            }
782        }
783
784        if let Some(init) = var.init {
785            match var.space {
786                crate::AddressSpace::Private | crate::AddressSpace::Function => {}
787                _ => {
788                    return Err(GlobalVariableError::InitializerNotAllowed(var.space));
789                }
790            }
791
792            if !global_expr_kind.is_const_or_override(init) {
793                return Err(GlobalVariableError::InitializerExprType);
794            }
795
796            if !gctx.compare_types(
797                &crate::proc::TypeResolution::Handle(var.ty),
798                &mod_info[init],
799            ) {
800                return Err(GlobalVariableError::InitializerType);
801            }
802        }
803
804        Ok(())
805    }
806
807    /// Validate the mesh shader output type `ty`, used as `mesh_output_type`.
808    fn validate_mesh_output_type(
809        &mut self,
810        ep: &crate::EntryPoint,
811        module: &crate::Module,
812        ty: Handle<crate::Type>,
813        mesh_output_type: MeshOutputType,
814    ) -> Result<(), WithSpan<EntryPointError>> {
815        if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) {
816            return Err(EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types));
817        }
818        let mut result_built_ins = crate::FastHashSet::default();
819        let mut ctx = VaryingContext {
820            stage: ep.stage,
821            output: true,
822            types: &module.types,
823            type_info: &self.types,
824            location_mask: &mut self.location_mask,
825            blend_src_mask: &mut self.blend_src_mask,
826            built_ins: &mut result_built_ins,
827            capabilities: self.capabilities,
828            flags: self.flags,
829            mesh_output_type,
830            has_task_payload: ep.task_payload.is_some(),
831        };
832        ctx.validate(ep, ty, None)
833            .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
834        if mesh_output_type == MeshOutputType::PrimitiveOutput {
835            let mut num_indices_builtins = 0;
836            if result_built_ins.contains(&crate::BuiltIn::PointIndex) {
837                num_indices_builtins += 1;
838            }
839            if result_built_ins.contains(&crate::BuiltIn::LineIndices) {
840                num_indices_builtins += 1;
841            }
842            if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) {
843                num_indices_builtins += 1;
844            }
845            if num_indices_builtins != 1 {
846                return Err(EntryPointError::InvalidMeshPrimitiveOutputType
847                    .with_span_handle(ty, &module.types));
848            }
849        } else if mesh_output_type == MeshOutputType::VertexOutput
850            && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
851        {
852            return Err(
853                EntryPointError::MissingVertexOutputPosition.with_span_handle(ty, &module.types)
854            );
855        }
856
857        Ok(())
858    }
859
860    pub(super) fn validate_entry_point(
861        &mut self,
862        ep: &crate::EntryPoint,
863        module: &crate::Module,
864        mod_info: &ModuleInfo,
865    ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
866        if matches!(
867            ep.stage,
868            crate::ShaderStage::Task | crate::ShaderStage::Mesh
869        ) && !self.capabilities.contains(Capabilities::MESH_SHADER)
870        {
871            return Err(
872                EntryPointError::UnsupportedCapability(Capabilities::MESH_SHADER).with_span(),
873            );
874        }
875        if ep.early_depth_test.is_some() {
876            let required = Capabilities::EARLY_DEPTH_TEST;
877            if !self.capabilities.contains(required) {
878                return Err(
879                    EntryPointError::Result(VaryingError::UnsupportedCapability(required))
880                        .with_span(),
881                );
882            }
883
884            if ep.stage != crate::ShaderStage::Fragment {
885                return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
886            }
887        }
888
889        if ep.stage.compute_like() {
890            if ep
891                .workgroup_size
892                .iter()
893                .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
894            {
895                return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
896            }
897        } else if ep.workgroup_size != [0; 3] {
898            return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
899        }
900
901        match (ep.stage, &ep.mesh_info) {
902            (crate::ShaderStage::Mesh, &None) => {
903                return Err(EntryPointError::ExpectedMeshShaderAttributes.with_span());
904            }
905            (crate::ShaderStage::Mesh, &Some(..)) => {}
906            (_, &Some(_)) => {
907                return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span());
908            }
909            (_, _) => {}
910        }
911
912        let mut info = self
913            .validate_function(&ep.function, module, mod_info, true)
914            .map_err(WithSpan::into_other)?;
915
916        // Validate the task shader payload.
917        match ep.stage {
918            // Task shaders must produce a payload.
919            crate::ShaderStage::Task => {
920                let Some(handle) = ep.task_payload else {
921                    return Err(EntryPointError::ExpectedTaskPayload.with_span());
922                };
923                if module.global_variables[handle].space != crate::AddressSpace::TaskPayload {
924                    return Err(EntryPointError::TaskPayloadWrongAddressSpace
925                        .with_span_handle(handle, &module.global_variables));
926                }
927                info.insert_global_use(GlobalUse::READ | GlobalUse::WRITE, handle);
928            }
929
930            // Mesh shaders may accept a payload.
931            crate::ShaderStage::Mesh => {
932                if let Some(handle) = ep.task_payload {
933                    if module.global_variables[handle].space != crate::AddressSpace::TaskPayload {
934                        return Err(EntryPointError::TaskPayloadWrongAddressSpace
935                            .with_span_handle(handle, &module.global_variables));
936                    }
937                    info.insert_global_use(GlobalUse::READ, handle);
938                }
939                if let Some(ref mesh_info) = ep.mesh_info {
940                    info.insert_global_use(GlobalUse::READ, mesh_info.output_variable);
941                }
942            }
943
944            // Other stages must not have a payload.
945            _ => {
946                if let Some(handle) = ep.task_payload {
947                    return Err(EntryPointError::UnexpectedTaskPayload
948                        .with_span_handle(handle, &module.global_variables));
949                }
950            }
951        }
952
953        {
954            use super::ShaderStages;
955
956            let stage_bit = match ep.stage {
957                crate::ShaderStage::Vertex => ShaderStages::VERTEX,
958                crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
959                crate::ShaderStage::Compute => ShaderStages::COMPUTE,
960                crate::ShaderStage::Mesh => ShaderStages::MESH,
961                crate::ShaderStage::Task => ShaderStages::TASK,
962            };
963
964            if !info.available_stages.contains(stage_bit) {
965                return Err(EntryPointError::ForbiddenStageOperations.with_span());
966            }
967        }
968
969        self.location_mask.clear();
970        let mut argument_built_ins = crate::FastHashSet::default();
971        // TODO: add span info to function arguments
972        for (index, fa) in ep.function.arguments.iter().enumerate() {
973            let mut ctx = VaryingContext {
974                stage: ep.stage,
975                output: false,
976                types: &module.types,
977                type_info: &self.types,
978                location_mask: &mut self.location_mask,
979                blend_src_mask: &mut self.blend_src_mask,
980                built_ins: &mut argument_built_ins,
981                capabilities: self.capabilities,
982                flags: self.flags,
983                mesh_output_type: MeshOutputType::None,
984                has_task_payload: ep.task_payload.is_some(),
985            };
986            ctx.validate(ep, fa.ty, fa.binding.as_ref())
987                .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
988        }
989
990        self.location_mask.clear();
991        if let Some(ref fr) = ep.function.result {
992            let mut result_built_ins = crate::FastHashSet::default();
993            let mut ctx = VaryingContext {
994                stage: ep.stage,
995                output: true,
996                types: &module.types,
997                type_info: &self.types,
998                location_mask: &mut self.location_mask,
999                blend_src_mask: &mut self.blend_src_mask,
1000                built_ins: &mut result_built_ins,
1001                capabilities: self.capabilities,
1002                flags: self.flags,
1003                mesh_output_type: MeshOutputType::None,
1004                has_task_payload: ep.task_payload.is_some(),
1005            };
1006            ctx.validate(ep, fr.ty, fr.binding.as_ref())
1007                .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
1008            if ep.stage == crate::ShaderStage::Vertex
1009                && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
1010            {
1011                return Err(EntryPointError::MissingVertexOutputPosition.with_span());
1012            }
1013            if ep.stage == crate::ShaderStage::Mesh {
1014                return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span());
1015            }
1016            // Task shaders must have a single `MeshTaskSize` output, and nothing else.
1017            if ep.stage == crate::ShaderStage::Task {
1018                let ok = result_built_ins.contains(&crate::BuiltIn::MeshTaskSize)
1019                    && result_built_ins.len() == 1
1020                    && self.location_mask.is_empty();
1021                if !ok {
1022                    return Err(EntryPointError::WrongTaskShaderEntryResult.with_span());
1023                }
1024            }
1025            if !self.blend_src_mask.is_empty() {
1026                info.dual_source_blending = true;
1027            }
1028        } else if ep.stage == crate::ShaderStage::Vertex {
1029            return Err(EntryPointError::MissingVertexOutputPosition.with_span());
1030        } else if ep.stage == crate::ShaderStage::Task {
1031            return Err(EntryPointError::WrongTaskShaderEntryResult.with_span());
1032        }
1033
1034        {
1035            let mut used_push_constants = module
1036                .global_variables
1037                .iter()
1038                .filter(|&(_, var)| var.space == crate::AddressSpace::PushConstant)
1039                .map(|(handle, _)| handle)
1040                .filter(|&handle| !info[handle].is_empty());
1041            // Check if there is more than one push constant, and error if so.
1042            // Use a loop for when returning multiple errors is supported.
1043            if let Some(handle) = used_push_constants.nth(1) {
1044                return Err(EntryPointError::MoreThanOnePushConstantUsed
1045                    .with_span_handle(handle, &module.global_variables));
1046            }
1047        }
1048
1049        self.ep_resource_bindings.clear();
1050        for (var_handle, var) in module.global_variables.iter() {
1051            let usage = info[var_handle];
1052            if usage.is_empty() {
1053                continue;
1054            }
1055
1056            if var.space == crate::AddressSpace::TaskPayload {
1057                if ep.task_payload != Some(var_handle) {
1058                    return Err(EntryPointError::WrongTaskPayloadUsed
1059                        .with_span_handle(var_handle, &module.global_variables));
1060                }
1061                let size = module.types[var.ty].inner.size(module.to_ctx());
1062                if size < 4 {
1063                    return Err(EntryPointError::TaskPayloadTooSmall(size)
1064                        .with_span_handle(var_handle, &module.global_variables));
1065                }
1066            }
1067
1068            let allowed_usage = match var.space {
1069                crate::AddressSpace::Function => unreachable!(),
1070                crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
1071                crate::AddressSpace::Storage { access } => storage_usage(access),
1072                crate::AddressSpace::Handle => match module.types[var.ty].inner {
1073                    crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
1074                        crate::TypeInner::Image {
1075                            class: crate::ImageClass::Storage { access, .. },
1076                            ..
1077                        } => storage_usage(access),
1078                        _ => GlobalUse::READ | GlobalUse::QUERY,
1079                    },
1080                    crate::TypeInner::Image {
1081                        class: crate::ImageClass::Storage { access, .. },
1082                        ..
1083                    } => storage_usage(access),
1084                    _ => GlobalUse::READ | GlobalUse::QUERY,
1085                },
1086                crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
1087                    GlobalUse::READ | GlobalUse::WRITE | GlobalUse::QUERY
1088                }
1089                crate::AddressSpace::TaskPayload => {
1090                    GlobalUse::READ
1091                        | GlobalUse::QUERY
1092                        | if ep.stage == crate::ShaderStage::Task {
1093                            GlobalUse::WRITE
1094                        } else {
1095                            GlobalUse::empty()
1096                        }
1097                }
1098                crate::AddressSpace::PushConstant => GlobalUse::READ,
1099            };
1100            if !allowed_usage.contains(usage) {
1101                log::warn!("\tUsage error for: {var:?}");
1102                log::warn!("\tAllowed usage: {allowed_usage:?}, requested: {usage:?}");
1103                return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
1104                    .with_span_handle(var_handle, &module.global_variables));
1105            }
1106
1107            if let Some(ref bind) = var.binding {
1108                if !self.ep_resource_bindings.insert(*bind) {
1109                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
1110                        return Err(EntryPointError::BindingCollision(var_handle)
1111                            .with_span_handle(var_handle, &module.global_variables));
1112                    }
1113                }
1114            }
1115        }
1116
1117        // If this is a `Mesh` entry point, check its vertex and primitive output types.
1118        // We verified previously that only mesh shaders can have `mesh_info`.
1119        if let &Some(ref mesh_info) = &ep.mesh_info {
1120            if module.global_variables[mesh_info.output_variable].space
1121                != crate::AddressSpace::WorkGroup
1122            {
1123                return Err(EntryPointError::WrongMeshOutputAddressSpace.with_span());
1124            }
1125
1126            let mut implied = module.analyze_mesh_shader_info(mesh_info.output_variable);
1127            if let Some(e) = implied.2 {
1128                return Err(e);
1129            }
1130
1131            if let Some(e) = mesh_info.max_vertices_override {
1132                if let crate::Expression::Override(o) = module.global_expressions[e] {
1133                    if implied.1[0] != Some(o) {
1134                        return Err(EntryPointError::BadMeshOutputVariableType.with_span());
1135                    }
1136                }
1137            }
1138            if let Some(e) = mesh_info.max_primitives_override {
1139                if let crate::Expression::Override(o) = module.global_expressions[e] {
1140                    if implied.1[1] != Some(o) {
1141                        return Err(EntryPointError::BadMeshOutputVariableType.with_span());
1142                    }
1143                }
1144            }
1145
1146            implied.0.max_vertices_override = mesh_info.max_vertices_override;
1147            implied.0.max_primitives_override = mesh_info.max_primitives_override;
1148            if implied.0 != *mesh_info {
1149                return Err(EntryPointError::BadMeshOutputVariableType.with_span());
1150            }
1151            if mesh_info.topology == crate::MeshOutputTopology::Points
1152                && !self
1153                    .capabilities
1154                    .contains(Capabilities::MESH_SHADER_POINT_TOPOLOGY)
1155            {
1156                return Err(EntryPointError::UnsupportedCapability(
1157                    Capabilities::MESH_SHADER_POINT_TOPOLOGY,
1158                )
1159                .with_span());
1160            }
1161
1162            self.validate_mesh_output_type(
1163                ep,
1164                module,
1165                mesh_info.vertex_output_type,
1166                MeshOutputType::VertexOutput,
1167            )?;
1168            self.validate_mesh_output_type(
1169                ep,
1170                module,
1171                mesh_info.primitive_output_type,
1172                MeshOutputType::PrimitiveOutput,
1173            )?;
1174        }
1175
1176        Ok(info)
1177    }
1178}