naga/valid/
interface.rs

1use alloc::vec::Vec;
2
3use bit_set::BitSet;
4
5use super::{
6    analyzer::{FunctionInfo, GlobalUse},
7    Capabilities, Disalignment, FunctionError, ModuleInfo, PushConstantError,
8};
9use crate::arena::{Handle, UniqueArena};
10use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
11
/// Maximum workgroup size: `1 << 14`, i.e. `0x4000` (16384).
///
/// NOTE(review): the code that reads this constant is outside this section;
/// presumably it bounds `EntryPoint::workgroup_size` and backs
/// `EntryPointError::OutOfRangeWorkgroupSize` — confirm.
const MAX_WORKGROUP_SIZE: u32 = 1 << 14;
13
/// Errors that can be reported when validating a module-scope global
/// variable (see `Validator::validate_global_var`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum GlobalVariableError {
    /// The variable cannot be used in this address space at all, e.g. a
    /// `Function`-space global, or a binding array declared outside the
    /// storage, uniform and handle address spaces.
    #[error("Usage isn't compatible with address space {0:?}")]
    InvalidUsage(crate::AddressSpace),
    /// The variable's type is not allowed in this address space.
    #[error("Type isn't compatible with address space {0:?}")]
    InvalidType(crate::AddressSpace),
    /// The type's validation flags (`seen`) lack some of the flags that the
    /// address space demands (`required`).
    #[error("Type flags {seen:?} do not meet the required {required:?}")]
    MissingTypeFlags {
        required: super::TypeFlags,
        seen: super::TypeFlags,
    },
    /// Declaring the variable needs a capability that is not enabled, such as
    /// `STORAGE_TEXTURE_16BIT_NORM_FORMATS` for 16-bit normalized storage
    /// texture formats.
    #[error("Capability {0:?} is not supported")]
    UnsupportedCapability(Capabilities),
    /// The binding decoration is missing or not applicable.
    #[error("Binding decoration is missing or not applicable")]
    InvalidBinding,
    /// The type's layout breaks the alignment rules of the address space.
    /// The handle is the type reported as offending by the layout check, and
    /// the [`Disalignment`] says what went wrong. For the storage and uniform
    /// address spaces this is only reported when
    /// `ValidationFlags::STRUCT_LAYOUTS` is set.
    #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
    Alignment(
        crate::AddressSpace,
        Handle<crate::Type>,
        #[source] Disalignment,
    ),
    /// The initializer is not an override-expression.
    #[error("Initializer must be an override-expression")]
    InitializerExprType,
    /// The initializer's type differs from the variable's type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    /// Initializers are not permitted in this address space.
    #[error("Initializer can't be used with address space {0:?}")]
    InitializerNotAllowed(crate::AddressSpace),
    /// A `Storage` variable whose access is exactly `STORE`, i.e. write-only.
    #[error("Storage address space doesn't support write-only access")]
    StorageAddressSpaceWriteOnlyNotSupported,
    /// The variable's type cannot be used as a push constant; the wrapped
    /// [`PushConstantError`] gives the reason.
    #[error("Type is not valid for use as a push constant")]
    InvalidPushConstantType(#[source] PushConstantError),
    /// A task payload variable has a zero-sized type.
    #[error("Task payload must not be zero-sized")]
    ZeroSizedTaskPayload,
}
49
/// Errors found while validating the bindings of an entry point's inputs or
/// outputs (its "varyings"), reported by `VaryingContext`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum VaryingError {
    /// A location-bound value's type has no scalar kind
    /// (`TypeInner::scalar_kind()` returned `None`).
    #[error("The type {0:?} does not match the varying")]
    InvalidType(Handle<crate::Type>),
    /// The type lacks `TypeFlags::IO_SHAREABLE`, so it cannot be bound to a
    /// location.
    #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
    NotIOShareableType(Handle<crate::Type>),
    /// A non-float location value that needs interpolation is not `flat`.
    #[error("Interpolation is not valid")]
    InvalidInterpolation,
    /// The given interpolation and sampling qualifiers cannot be combined.
    #[error("Cannot combine {interpolation:?} interpolation with the {sampling:?} sample type")]
    InvalidInterpolationSamplingCombination {
        interpolation: crate::Interpolation,
        sampling: crate::Sampling,
    },
    /// A float location value that needs interpolation has none specified.
    #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
    MissingInterpolation,
    /// The built-in may not be used at this stage, or in this direction
    /// (input vs. output).
    #[error("Built-in {0:?} is not available at this stage")]
    InvalidBuiltInStage(crate::BuiltIn),
    /// The value bound to the built-in has the wrong type.
    #[error("Built-in type for {0:?} is invalid")]
    InvalidBuiltInType(crate::BuiltIn),
    /// A non-struct argument or result has no binding.
    #[error("Entry point arguments and return values must all have bindings")]
    MissingBinding,
    /// The struct member at this index has no binding.
    #[error("Struct member {0} is missing a binding")]
    MemberMissingBinding(u32),
    /// Two location bindings (without `blend_src`) use the same location.
    #[error("Multiple bindings at location {location} are present")]
    BindingCollision { location: u32 },
    /// Two bindings use the same `blend_src` index.
    #[error("Multiple bindings use the same `blend_src` {blend_src}")]
    BindingCollisionBlendSrc { blend_src: u32 },
    /// The same built-in appears more than once. `Position` counts as a
    /// single built-in regardless of its `invariant` flag.
    #[error("Built-in {0:?} is present more than once")]
    DuplicateBuiltIn(crate::BuiltIn),
    /// A capability required by a built-in, by `blend_src`, or by `sample`
    /// sampling is not enabled.
    #[error("Capability {0:?} is not supported")]
    UnsupportedCapability(Capabilities),
    /// The named attribute was used on an input but is only valid on outputs.
    #[error("The attribute {0:?} is only valid as an output for stage {1:?}")]
    InvalidInputAttributeInStage(&'static str, crate::ShaderStage),
    /// The named attribute is not valid for this shader stage.
    #[error("The attribute {0:?} is not valid for stage {1:?}")]
    InvalidAttributeInStage(&'static str, crate::ShaderStage),
    /// `blend_src` was used with a location other than 0, or with an index
    /// other than 0 or 1.
    #[error("The `blend_src` attribute can only be used on location 0, only indices 0 and 1 are valid. Location was {location}, index was {blend_src}.")]
    InvalidBlendSrcIndex { location: u32, blend_src: u32 },
    /// `blend_src` was used, but not as exactly two outputs carrying indices
    /// 0 and 1.
    #[error("If `blend_src` is used, there must be exactly two outputs both with location 0, one with `blend_src(0)` and the other with `blend_src(1)`.")]
    IncompleteBlendSrcUsage,
    /// The two `blend_src` outputs have different types.
    #[error("If `blend_src` is used, both outputs must have the same type. `blend_src(0)` has type {blend_src_0_type:?} and `blend_src(1)` has type {blend_src_1_type:?}.")]
    BlendSrcOutputTypeMismatch {
        blend_src_0_type: Handle<crate::Type>,
        blend_src_1_type: Handle<crate::Type>,
    },
    /// `SubgroupId` or `SubgroupInvocationId` was used in an entry point
    /// whose workgroup size has a y or z extent greater than 1.
    #[error("Workgroup size is multi dimensional, `@builtin(subgroup_id)` and `@builtin(subgroup_invocation_id)` are not supported.")]
    InvalidMultiDimensionalSubgroupBuiltIn,
    /// `@per_primitive` was used somewhere other than a fragment input or a
    /// mesh primitive output.
    #[error("The `@per_primitive` attribute can only be used in fragment shader inputs or mesh shader primitive outputs")]
    InvalidPerPrimitive,
    /// A location member of a mesh primitive output lacks `@per_primitive`.
    #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")]
    MissingPerPrimitive,
    /// `@per_primitive` was used without the `MESH_SHADER` capability.
    #[error("The `MESH_SHADER` capability must be enabled to use per-primitive fragment inputs.")]
    PerPrimitiveNotAllowed,
}
104
/// Errors reported while validating an entry point.
///
/// Varying errors are wrapped either with the index of the offending
/// argument (`Argument`) or, via `#[from]`, as `Result`. Function-level
/// errors are wrapped, via `#[from]`, as `Function`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum EntryPointError {
    /// Entry points conflict with each other.
    #[error("Multiple conflicting entry points")]
    Conflict,
    /// A vertex shader does not output `@builtin(position)`.
    #[error("Vertex shaders must return a `@builtin(position)` output value")]
    MissingVertexOutputPosition,
    /// An early depth test was specified where it does not apply.
    #[error("Early depth test is not applicable")]
    UnexpectedEarlyDepthTest,
    /// A workgroup size was specified where it does not apply.
    #[error("Workgroup size is not applicable")]
    UnexpectedWorkgroupSize,
    /// The workgroup size is outside the accepted range.
    #[error("Workgroup size is out of range")]
    OutOfRangeWorkgroupSize,
    /// The entry point uses operations not permitted at its stage.
    #[error("Uses operations forbidden at this stage")]
    ForbiddenStageOperations,
    /// A global variable is used in a way (given by the [`GlobalUse`] flags)
    /// that its declaration does not allow.
    #[error("Global variable {0:?} is used incorrectly as {1:?}")]
    InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
    /// More than one push constant variable is used.
    #[error("More than 1 push constant variable is used")]
    MoreThanOnePushConstantUsed,
    /// The variable's resource binding collides with another resource.
    #[error("Bindings for {0:?} conflict with other resource")]
    BindingCollision(Handle<crate::GlobalVariable>),
    /// A varying error in the argument whose index is given.
    #[error("Argument {0} varying error")]
    Argument(u32, #[source] VaryingError),
    /// A varying error, presumably in the entry point's result value.
    #[error(transparent)]
    Result(#[from] VaryingError),
    /// An integer value at the given location is not `flat`-interpolated.
    #[error("Location {location} interpolation of an integer has to be flat")]
    InvalidIntegerInterpolation { location: u32 },
    /// An error from validating the entry point as a function.
    #[error(transparent)]
    Function(#[from] FunctionError),
    /// A mesh shader entry point has no mesh shader attributes.
    #[error("mesh shader entry point missing mesh shader attributes")]
    ExpectedMeshShaderAttributes,
    /// A non-mesh entry point has mesh shader attributes.
    #[error("Non mesh shader entry point cannot have mesh shader attributes")]
    UnexpectedMeshShaderAttributes,
    /// A non-mesh, non-task entry point has a task payload attribute.
    #[error("Non mesh/task shader entry point cannot have task payload attribute")]
    UnexpectedTaskPayload,
    /// The task payload is not declared as `var<task_payload>`.
    #[error("Task payload must be declared with `var<task_payload>`")]
    TaskPayloadWrongAddressSpace,
    /// A task payload is used without being declared with `@payload`.
    #[error("For a task payload to be used, it must be declared with @payload")]
    WrongTaskPayloadUsed,
    /// The vertex/primitive output types set do not match the mesh shader
    /// attributes.
    #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")]
    WrongMeshOutputType,
    /// A non-mesh entry point writes mesh output vertices or primitives.
    #[error("Only mesh shader entry points can write to mesh output vertices and primitives")]
    UnexpectedMeshShaderOutput,
    /// A mesh shader entry point declares a return type.
    #[error("Mesh shader entry point cannot have a return type")]
    UnexpectedMeshShaderEntryResult,
    /// A task shader entry point's result is not
    /// `@builtin(mesh_task_size) vec3<u32>`.
    #[error("Task shader entry point must return @builtin(mesh_task_size) vec3<u32>")]
    WrongTaskShaderEntryResult,
    /// A mesh output type is not a user-defined struct.
    #[error("Mesh output type must be a user-defined struct.")]
    InvalidMeshOutputType,
    /// A mesh primitive output does not have exactly one primitive-index
    /// built-in.
    #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")]
    InvalidMeshPrimitiveOutputType,
    /// A task shader declares no task payload output.
    #[error("Task shaders must declare a task payload output")]
    ExpectedTaskPayload,
    /// Mesh or task shaders are used without the `MESH_SHADER` capability.
    #[error(
        "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders."
    )]
    MeshShaderCapabilityDisabled,
}
163
164fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
165    let mut storage_usage = GlobalUse::QUERY;
166    if access.contains(crate::StorageAccess::LOAD) {
167        storage_usage |= GlobalUse::READ;
168    }
169    if access.contains(crate::StorageAccess::STORE) {
170        storage_usage |= GlobalUse::WRITE;
171    }
172    if access.contains(crate::StorageAccess::ATOMIC) {
173        storage_usage |= GlobalUse::ATOMIC;
174    }
175    storage_usage
176}
177
/// Which kind of mesh shader output, if any, a `VaryingContext` is checking.
///
/// In `VaryingContext::validate_impl` only `PrimitiveOutput` changes the
/// rules. It makes the `CullPrimitive`, `PointIndex`, `LineIndices` and
/// `TriangleIndices` built-ins visible, and it requires every location
/// binding to be `@per_primitive`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MeshOutputType {
    /// Presumably: the varyings are not mesh shader outputs.
    None,
    /// Presumably: the per-vertex outputs of a mesh shader (not referenced in
    /// this section).
    VertexOutput,
    /// The per-primitive outputs of a mesh shader.
    PrimitiveOutput,
}
184
/// State used while validating the bindings of one group of entry point
/// varyings (inputs or outputs, depending on `output`).
///
/// The masks and the built-in set accumulate across calls to `validate`, so
/// collisions between different arguments or members are detected.
struct VaryingContext<'a> {
    /// Stage of the entry point being validated.
    stage: crate::ShaderStage,
    /// `true` when the varyings are outputs of the stage, `false` for inputs.
    output: bool,
    /// The module's type arena, used to inspect binding types.
    types: &'a UniqueArena<crate::Type>,
    /// Per-type validation info, indexed by `Handle<Type>::index()`.
    /// NOTE(review): `&'a [TypeInfo]` would be more idiomatic than `&Vec`.
    type_info: &'a Vec<super::r#type::TypeInfo>,
    /// Locations already claimed by location bindings without `blend_src`.
    location_mask: &'a mut BitSet,
    /// `blend_src` indices already claimed.
    blend_src_mask: &'a mut BitSet,
    /// Built-ins already seen. `Position` is stored with `invariant: false`
    /// so that duplicates are found regardless of that flag.
    built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
    /// Capabilities the validator is allowed to use.
    capabilities: Capabilities,
    /// Validation flags. `BINDINGS` makes missing bindings and location
    /// collisions errors.
    flags: super::ValidationFlags,
    /// Which kind of mesh shader output, if any, is being validated.
    mesh_output_type: MeshOutputType,
    /// Whether the entry point has a task payload. A mesh shader with one may
    /// not read `DrawID`.
    has_task_payload: bool,
}
198
199impl VaryingContext<'_> {
200    fn validate_impl(
201        &mut self,
202        ep: &crate::EntryPoint,
203        ty: Handle<crate::Type>,
204        binding: &crate::Binding,
205    ) -> Result<(), VaryingError> {
206        use crate::{BuiltIn as Bi, ShaderStage as St, TypeInner as Ti, VectorSize as Vs};
207
208        let ty_inner = &self.types[ty].inner;
209        match *binding {
210            crate::Binding::BuiltIn(built_in) => {
211                // Ignore the `invariant` field for the sake of duplicate checks,
212                // but use the original in error messages.
213                let canonical = if let crate::BuiltIn::Position { .. } = built_in {
214                    crate::BuiltIn::Position { invariant: false }
215                } else {
216                    built_in
217                };
218
219                if self.built_ins.contains(&canonical) {
220                    return Err(VaryingError::DuplicateBuiltIn(built_in));
221                }
222                self.built_ins.insert(canonical);
223
224                let required = match built_in {
225                    Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
226                    Bi::CullDistance => Capabilities::CULL_DISTANCE,
227                    Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
228                    Bi::Barycentric => Capabilities::SHADER_BARYCENTRICS,
229                    Bi::ViewIndex => Capabilities::MULTIVIEW,
230                    Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
231                    Bi::NumSubgroups
232                    | Bi::SubgroupId
233                    | Bi::SubgroupSize
234                    | Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
235                    _ => Capabilities::empty(),
236                };
237                if !self.capabilities.contains(required) {
238                    return Err(VaryingError::UnsupportedCapability(required));
239                }
240
241                if matches!(
242                    built_in,
243                    crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
244                ) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
245                {
246                    return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
247                }
248
249                let (visible, type_good) = match built_in {
250                    Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
251                        self.stage == St::Vertex && !self.output,
252                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
253                    ),
254                    Bi::DrawID => (
255                        // Always allowed in task/vertex stage. Allowed in mesh stage if there is no task stage in the pipeline.
256                        (self.stage == St::Vertex
257                            || self.stage == St::Task
258                            || (self.stage == St::Mesh && !self.has_task_payload))
259                            && !self.output,
260                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
261                    ),
262                    Bi::ClipDistance | Bi::CullDistance => (
263                        (self.stage == St::Vertex || self.stage == St::Mesh) && self.output,
264                        match *ty_inner {
265                            Ti::Array { base, size, .. } => {
266                                self.types[base].inner == Ti::Scalar(crate::Scalar::F32)
267                                    && match size {
268                                        crate::ArraySize::Constant(non_zero) => non_zero.get() <= 8,
269                                        _ => false,
270                                    }
271                            }
272                            _ => false,
273                        },
274                    ),
275                    Bi::PointSize => (
276                        (self.stage == St::Vertex || self.stage == St::Mesh) && self.output,
277                        *ty_inner == Ti::Scalar(crate::Scalar::F32),
278                    ),
279                    Bi::PointCoord => (
280                        self.stage == St::Fragment && !self.output,
281                        *ty_inner
282                            == Ti::Vector {
283                                size: Vs::Bi,
284                                scalar: crate::Scalar::F32,
285                            },
286                    ),
287                    Bi::Position { .. } => (
288                        match self.stage {
289                            St::Vertex | St::Mesh => self.output,
290                            St::Fragment => !self.output,
291                            St::Compute | St::Task => false,
292                        },
293                        *ty_inner
294                            == Ti::Vector {
295                                size: Vs::Quad,
296                                scalar: crate::Scalar::F32,
297                            },
298                    ),
299                    Bi::ViewIndex => (
300                        match self.stage {
301                            St::Vertex | St::Fragment | St::Task | St::Mesh => !self.output,
302                            St::Compute => false,
303                        },
304                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
305                    ),
306                    Bi::FragDepth => (
307                        self.stage == St::Fragment && self.output,
308                        *ty_inner == Ti::Scalar(crate::Scalar::F32),
309                    ),
310                    Bi::FrontFacing => (
311                        self.stage == St::Fragment && !self.output,
312                        *ty_inner == Ti::Scalar(crate::Scalar::BOOL),
313                    ),
314                    Bi::PrimitiveIndex => (
315                        self.stage == St::Fragment && !self.output,
316                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
317                    ),
318                    Bi::Barycentric => (
319                        self.stage == St::Fragment && !self.output,
320                        *ty_inner
321                            == Ti::Vector {
322                                size: Vs::Tri,
323                                scalar: crate::Scalar::F32,
324                            },
325                    ),
326                    Bi::SampleIndex => (
327                        self.stage == St::Fragment && !self.output,
328                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
329                    ),
330                    Bi::SampleMask => (
331                        self.stage == St::Fragment,
332                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
333                    ),
334                    Bi::LocalInvocationIndex => (
335                        self.stage.compute_like() && !self.output,
336                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
337                    ),
338                    Bi::GlobalInvocationId
339                    | Bi::LocalInvocationId
340                    | Bi::WorkGroupId
341                    | Bi::WorkGroupSize
342                    | Bi::NumWorkGroups => (
343                        self.stage.compute_like() && !self.output,
344                        *ty_inner
345                            == Ti::Vector {
346                                size: Vs::Tri,
347                                scalar: crate::Scalar::U32,
348                            },
349                    ),
350                    Bi::NumSubgroups | Bi::SubgroupId => (
351                        self.stage.compute_like() && !self.output,
352                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
353                    ),
354                    Bi::SubgroupSize | Bi::SubgroupInvocationId => (
355                        match self.stage {
356                            St::Compute | St::Fragment | St::Task | St::Mesh => !self.output,
357                            St::Vertex => false,
358                        },
359                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
360                    ),
361                    Bi::CullPrimitive => (
362                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
363                        *ty_inner == Ti::Scalar(crate::Scalar::BOOL),
364                    ),
365                    Bi::PointIndex => (
366                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
367                        *ty_inner == Ti::Scalar(crate::Scalar::U32),
368                    ),
369                    Bi::LineIndices => (
370                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
371                        *ty_inner
372                            == Ti::Vector {
373                                size: Vs::Bi,
374                                scalar: crate::Scalar::U32,
375                            },
376                    ),
377                    Bi::TriangleIndices => (
378                        self.mesh_output_type == MeshOutputType::PrimitiveOutput,
379                        *ty_inner
380                            == Ti::Vector {
381                                size: Vs::Tri,
382                                scalar: crate::Scalar::U32,
383                            },
384                    ),
385                    Bi::MeshTaskSize => (
386                        self.stage == St::Task && self.output,
387                        *ty_inner
388                            == Ti::Vector {
389                                size: Vs::Tri,
390                                scalar: crate::Scalar::U32,
391                            },
392                    ),
393                };
394
395                if !visible {
396                    return Err(VaryingError::InvalidBuiltInStage(built_in));
397                }
398                if !type_good {
399                    log::warn!("Wrong builtin type: {ty_inner:?}");
400                    return Err(VaryingError::InvalidBuiltInType(built_in));
401                }
402            }
403            crate::Binding::Location {
404                location,
405                interpolation,
406                sampling,
407                blend_src,
408                per_primitive,
409            } => {
410                if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) {
411                    return Err(VaryingError::PerPrimitiveNotAllowed);
412                }
413                // Only IO-shareable types may be stored in locations.
414                if !self.type_info[ty.index()]
415                    .flags
416                    .contains(super::TypeFlags::IO_SHAREABLE)
417                {
418                    return Err(VaryingError::NotIOShareableType(ty));
419                }
420
421                // Check whether `per_primitive` is appropriate for this stage and direction.
422                if self.mesh_output_type == MeshOutputType::PrimitiveOutput {
423                    // All mesh shader `Location` outputs must be `per_primitive`.
424                    if !per_primitive {
425                        return Err(VaryingError::MissingPerPrimitive);
426                    }
427                } else if self.stage == crate::ShaderStage::Fragment && !self.output {
428                    // Fragment stage inputs may be `per_primitive`. We'll only
429                    // know if these are correct when the whole mesh pipeline is
430                    // created and we're paired with a specific mesh or vertex
431                    // shader.
432                } else if per_primitive {
433                    // All other `Location` bindings must not be `per_primitive`.
434                    return Err(VaryingError::InvalidPerPrimitive);
435                }
436
437                if let Some(blend_src) = blend_src {
438                    // `blend_src` is only valid if dual source blending was explicitly enabled,
439                    // see https://www.w3.org/TR/WGSL/#extension-dual_source_blending
440                    if !self
441                        .capabilities
442                        .contains(Capabilities::DUAL_SOURCE_BLENDING)
443                    {
444                        return Err(VaryingError::UnsupportedCapability(
445                            Capabilities::DUAL_SOURCE_BLENDING,
446                        ));
447                    }
448                    if self.stage != crate::ShaderStage::Fragment {
449                        return Err(VaryingError::InvalidAttributeInStage(
450                            "blend_src",
451                            self.stage,
452                        ));
453                    }
454                    if !self.output {
455                        return Err(VaryingError::InvalidInputAttributeInStage(
456                            "blend_src",
457                            self.stage,
458                        ));
459                    }
460                    if (blend_src != 0 && blend_src != 1) || location != 0 {
461                        return Err(VaryingError::InvalidBlendSrcIndex {
462                            location,
463                            blend_src,
464                        });
465                    }
466                    if !self.blend_src_mask.insert(blend_src as usize) {
467                        return Err(VaryingError::BindingCollisionBlendSrc { blend_src });
468                    }
469                } else if !self.location_mask.insert(location as usize)
470                    && self.flags.contains(super::ValidationFlags::BINDINGS)
471                {
472                    return Err(VaryingError::BindingCollision { location });
473                }
474
475                if let Some(interpolation) = interpolation {
476                    let invalid_sampling = match (interpolation, sampling) {
477                        (_, None)
478                        | (
479                            crate::Interpolation::Perspective | crate::Interpolation::Linear,
480                            Some(
481                                crate::Sampling::Center
482                                | crate::Sampling::Centroid
483                                | crate::Sampling::Sample,
484                            ),
485                        )
486                        | (
487                            crate::Interpolation::Flat,
488                            Some(crate::Sampling::First | crate::Sampling::Either),
489                        ) => None,
490                        (_, Some(invalid_sampling)) => Some(invalid_sampling),
491                    };
492                    if let Some(sampling) = invalid_sampling {
493                        return Err(VaryingError::InvalidInterpolationSamplingCombination {
494                            interpolation,
495                            sampling,
496                        });
497                    }
498                }
499
500                let needs_interpolation = match self.stage {
501                    crate::ShaderStage::Vertex => self.output,
502                    crate::ShaderStage::Fragment => !self.output && !per_primitive,
503                    crate::ShaderStage::Compute | crate::ShaderStage::Task => false,
504                    crate::ShaderStage::Mesh => self.output,
505                };
506
507                // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
508                // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
509                // qualifiers, so we won't complain about that here.
510                let _ = sampling;
511
512                let required = match sampling {
513                    Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
514                    _ => Capabilities::empty(),
515                };
516                if !self.capabilities.contains(required) {
517                    return Err(VaryingError::UnsupportedCapability(required));
518                }
519
520                match ty_inner.scalar_kind() {
521                    Some(crate::ScalarKind::Float) => {
522                        if needs_interpolation && interpolation.is_none() {
523                            return Err(VaryingError::MissingInterpolation);
524                        }
525                    }
526                    Some(_) => {
527                        if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
528                        {
529                            return Err(VaryingError::InvalidInterpolation);
530                        }
531                    }
532                    None => return Err(VaryingError::InvalidType(ty)),
533                }
534            }
535        }
536
537        Ok(())
538    }
539
540    fn validate(
541        &mut self,
542        ep: &crate::EntryPoint,
543        ty: Handle<crate::Type>,
544        binding: Option<&crate::Binding>,
545    ) -> Result<(), WithSpan<VaryingError>> {
546        let span_context = self.types.get_span_context(ty);
547        match binding {
548            Some(binding) => self
549                .validate_impl(ep, ty, binding)
550                .map_err(|e| e.with_span_context(span_context)),
551            None => {
552                match self.types[ty].inner {
553                    crate::TypeInner::Struct { ref members, .. } => {
554                        for (index, member) in members.iter().enumerate() {
555                            let span_context = self.types.get_span_context(ty);
556                            match member.binding {
557                                None => {
558                                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
559                                        return Err(VaryingError::MemberMissingBinding(
560                                            index as u32,
561                                        )
562                                        .with_span_context(span_context));
563                                    }
564                                }
565                                Some(ref binding) => self
566                                    .validate_impl(ep, member.ty, binding)
567                                    .map_err(|e| e.with_span_context(span_context))?,
568                            }
569                        }
570
571                        if !self.blend_src_mask.is_empty() {
572                            let span_context = self.types.get_span_context(ty);
573
574                            // If there's any blend_src usage, it must apply to all members of which there must be exactly 2.
575                            if members.len() != 2 || self.blend_src_mask.len() != 2 {
576                                return Err(VaryingError::IncompleteBlendSrcUsage
577                                    .with_span_context(span_context));
578                            }
579                            // Also, all members must have the same type.
580                            if members[0].ty != members[1].ty {
581                                return Err(VaryingError::BlendSrcOutputTypeMismatch {
582                                    blend_src_0_type: members[0].ty,
583                                    blend_src_1_type: members[1].ty,
584                                }
585                                .with_span_context(span_context));
586                            }
587                        }
588                    }
589                    _ => {
590                        if self.flags.contains(super::ValidationFlags::BINDINGS) {
591                            return Err(VaryingError::MissingBinding.with_span());
592                        }
593                    }
594                }
595                Ok(())
596            }
597        }
598    }
599}
600
601impl super::Validator {
602    pub(super) fn validate_global_var(
603        &self,
604        var: &crate::GlobalVariable,
605        gctx: crate::proc::GlobalCtx,
606        mod_info: &ModuleInfo,
607        global_expr_kind: &crate::proc::ExpressionKindTracker,
608    ) -> Result<(), GlobalVariableError> {
609        use super::TypeFlags;
610
611        log::debug!("var {var:?}");
612        let inner_ty = match gctx.types[var.ty].inner {
613            // A binding array is (mostly) supposed to behave the same as a
614            // series of individually bound resources, so we can (mostly)
615            // validate a `binding_array<T>` as if it were just a plain `T`.
616            crate::TypeInner::BindingArray { base, .. } => match var.space {
617                crate::AddressSpace::Storage { .. }
618                | crate::AddressSpace::Uniform
619                | crate::AddressSpace::Handle => base,
620                _ => return Err(GlobalVariableError::InvalidUsage(var.space)),
621            },
622            _ => var.ty,
623        };
624        let type_info = &self.types[inner_ty.index()];
625
626        let (required_type_flags, is_resource) = match var.space {
627            crate::AddressSpace::Function => {
628                return Err(GlobalVariableError::InvalidUsage(var.space))
629            }
630            crate::AddressSpace::Storage { access } => {
631                if let Err((ty_handle, disalignment)) = type_info.storage_layout {
632                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
633                        return Err(GlobalVariableError::Alignment(
634                            var.space,
635                            ty_handle,
636                            disalignment,
637                        ));
638                    }
639                }
640                if access == crate::StorageAccess::STORE {
641                    return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported);
642                }
643                (
644                    TypeFlags::DATA | TypeFlags::HOST_SHAREABLE | TypeFlags::CREATION_RESOLVED,
645                    true,
646                )
647            }
648            crate::AddressSpace::Uniform => {
649                if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
650                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
651                        return Err(GlobalVariableError::Alignment(
652                            var.space,
653                            ty_handle,
654                            disalignment,
655                        ));
656                    }
657                }
658                (
659                    TypeFlags::DATA
660                        | TypeFlags::COPY
661                        | TypeFlags::SIZED
662                        | TypeFlags::HOST_SHAREABLE
663                        | TypeFlags::CREATION_RESOLVED,
664                    true,
665                )
666            }
667            crate::AddressSpace::Handle => {
668                match gctx.types[inner_ty].inner {
669                    crate::TypeInner::Image { class, .. } => match class {
670                        crate::ImageClass::Storage {
671                            format:
672                                crate::StorageFormat::R16Unorm
673                                | crate::StorageFormat::R16Snorm
674                                | crate::StorageFormat::Rg16Unorm
675                                | crate::StorageFormat::Rg16Snorm
676                                | crate::StorageFormat::Rgba16Unorm
677                                | crate::StorageFormat::Rgba16Snorm,
678                            ..
679                        } => {
680                            if !self
681                                .capabilities
682                                .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
683                            {
684                                return Err(GlobalVariableError::UnsupportedCapability(
685                                    Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
686                                ));
687                            }
688                        }
689                        _ => {}
690                    },
691                    crate::TypeInner::Sampler { .. }
692                    | crate::TypeInner::AccelerationStructure { .. }
693                    | crate::TypeInner::RayQuery { .. } => {}
694                    _ => {
695                        return Err(GlobalVariableError::InvalidType(var.space));
696                    }
697                }
698
699                (TypeFlags::empty(), true)
700            }
701            crate::AddressSpace::Private => (
702                TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED,
703                false,
704            ),
705            crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false),
706            crate::AddressSpace::TaskPayload => {
707                if !self.capabilities.contains(Capabilities::MESH_SHADER) {
708                    return Err(GlobalVariableError::UnsupportedCapability(
709                        Capabilities::MESH_SHADER,
710                    ));
711                }
712                (TypeFlags::DATA | TypeFlags::SIZED, false)
713            }
714            crate::AddressSpace::PushConstant => {
715                if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
716                    return Err(GlobalVariableError::UnsupportedCapability(
717                        Capabilities::PUSH_CONSTANT,
718                    ));
719                }
720                if let Err(ref err) = type_info.push_constant_compatibility {
721                    return Err(GlobalVariableError::InvalidPushConstantType(err.clone()));
722                }
723                (
724                    TypeFlags::DATA
725                        | TypeFlags::COPY
726                        | TypeFlags::HOST_SHAREABLE
727                        | TypeFlags::SIZED,
728                    false,
729                )
730            }
731        };
732
733        if !type_info.flags.contains(required_type_flags) {
734            return Err(GlobalVariableError::MissingTypeFlags {
735                seen: type_info.flags,
736                required: required_type_flags,
737            });
738        }
739
740        if is_resource != var.binding.is_some() {
741            if self.flags.contains(super::ValidationFlags::BINDINGS) {
742                return Err(GlobalVariableError::InvalidBinding);
743            }
744        }
745
746        if var.space == crate::AddressSpace::TaskPayload {
747            let ty = &gctx.types[var.ty].inner;
748            // HLSL doesn't allow zero sized payloads.
749            if ty.try_size(gctx) == Some(0) {
750                return Err(GlobalVariableError::ZeroSizedTaskPayload);
751            }
752        }
753
754        if let Some(init) = var.init {
755            match var.space {
756                crate::AddressSpace::Private | crate::AddressSpace::Function => {}
757                _ => {
758                    return Err(GlobalVariableError::InitializerNotAllowed(var.space));
759                }
760            }
761
762            if !global_expr_kind.is_const_or_override(init) {
763                return Err(GlobalVariableError::InitializerExprType);
764            }
765
766            if !gctx.compare_types(
767                &crate::proc::TypeResolution::Handle(var.ty),
768                &mod_info[init],
769            ) {
770                return Err(GlobalVariableError::InitializerType);
771            }
772        }
773
774        Ok(())
775    }
776
777    /// Validate the mesh shader output type `ty`, used as `mesh_output_type`.
778    fn validate_mesh_output_type(
779        &mut self,
780        ep: &crate::EntryPoint,
781        module: &crate::Module,
782        ty: Handle<crate::Type>,
783        mesh_output_type: MeshOutputType,
784    ) -> Result<(), WithSpan<EntryPointError>> {
785        if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) {
786            return Err(EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types));
787        }
788        let mut result_built_ins = crate::FastHashSet::default();
789        let mut ctx = VaryingContext {
790            stage: ep.stage,
791            output: true,
792            types: &module.types,
793            type_info: &self.types,
794            location_mask: &mut self.location_mask,
795            blend_src_mask: &mut self.blend_src_mask,
796            built_ins: &mut result_built_ins,
797            capabilities: self.capabilities,
798            flags: self.flags,
799            mesh_output_type,
800            has_task_payload: ep.task_payload.is_some(),
801        };
802        ctx.validate(ep, ty, None)
803            .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
804        if mesh_output_type == MeshOutputType::PrimitiveOutput {
805            let mut num_indices_builtins = 0;
806            if result_built_ins.contains(&crate::BuiltIn::PointIndex) {
807                num_indices_builtins += 1;
808            }
809            if result_built_ins.contains(&crate::BuiltIn::LineIndices) {
810                num_indices_builtins += 1;
811            }
812            if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) {
813                num_indices_builtins += 1;
814            }
815            if num_indices_builtins != 1 {
816                return Err(EntryPointError::InvalidMeshPrimitiveOutputType
817                    .with_span_handle(ty, &module.types));
818            }
819        } else if mesh_output_type == MeshOutputType::VertexOutput
820            && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
821        {
822            return Err(
823                EntryPointError::MissingVertexOutputPosition.with_span_handle(ty, &module.types)
824            );
825        }
826
827        Ok(())
828    }
829
830    pub(super) fn validate_entry_point(
831        &mut self,
832        ep: &crate::EntryPoint,
833        module: &crate::Module,
834        mod_info: &ModuleInfo,
835    ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
836        if matches!(
837            ep.stage,
838            crate::ShaderStage::Task | crate::ShaderStage::Mesh
839        ) && !self.capabilities.contains(Capabilities::MESH_SHADER)
840        {
841            return Err(EntryPointError::MeshShaderCapabilityDisabled.with_span());
842        }
843        if ep.early_depth_test.is_some() {
844            let required = Capabilities::EARLY_DEPTH_TEST;
845            if !self.capabilities.contains(required) {
846                return Err(
847                    EntryPointError::Result(VaryingError::UnsupportedCapability(required))
848                        .with_span(),
849                );
850            }
851
852            if ep.stage != crate::ShaderStage::Fragment {
853                return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
854            }
855        }
856
857        if ep.stage.compute_like() {
858            if ep
859                .workgroup_size
860                .iter()
861                .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
862            {
863                return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
864            }
865        } else if ep.workgroup_size != [0; 3] {
866            return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
867        }
868
869        match (ep.stage, &ep.mesh_info) {
870            (crate::ShaderStage::Mesh, &None) => {
871                return Err(EntryPointError::ExpectedMeshShaderAttributes.with_span());
872            }
873            (_, &Some(_)) => {
874                return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span());
875            }
876            (_, _) => {}
877        }
878
879        let mut info = self
880            .validate_function(&ep.function, module, mod_info, true)
881            .map_err(WithSpan::into_other)?;
882
883        // Validate the task shader payload.
884        match ep.stage {
885            // Task shaders must produce a payload.
886            crate::ShaderStage::Task => {
887                let Some(handle) = ep.task_payload else {
888                    return Err(EntryPointError::ExpectedTaskPayload.with_span());
889                };
890                if module.global_variables[handle].space != crate::AddressSpace::TaskPayload {
891                    return Err(EntryPointError::TaskPayloadWrongAddressSpace
892                        .with_span_handle(handle, &module.global_variables));
893                }
894                info.insert_global_use(GlobalUse::READ | GlobalUse::WRITE, handle);
895            }
896
897            // Mesh shaders may accept a payload.
898            crate::ShaderStage::Mesh => {
899                if let Some(handle) = ep.task_payload {
900                    if module.global_variables[handle].space != crate::AddressSpace::TaskPayload {
901                        return Err(EntryPointError::TaskPayloadWrongAddressSpace
902                            .with_span_handle(handle, &module.global_variables));
903                    }
904                    info.insert_global_use(GlobalUse::READ, handle);
905                }
906            }
907
908            // Other stages must not have a payload.
909            _ => {
910                if let Some(handle) = ep.task_payload {
911                    return Err(EntryPointError::UnexpectedTaskPayload
912                        .with_span_handle(handle, &module.global_variables));
913                }
914            }
915        }
916
917        {
918            use super::ShaderStages;
919
920            let stage_bit = match ep.stage {
921                crate::ShaderStage::Vertex => ShaderStages::VERTEX,
922                crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
923                crate::ShaderStage::Compute => ShaderStages::COMPUTE,
924                crate::ShaderStage::Mesh => ShaderStages::MESH,
925                crate::ShaderStage::Task => ShaderStages::TASK,
926            };
927
928            if !info.available_stages.contains(stage_bit) {
929                return Err(EntryPointError::ForbiddenStageOperations.with_span());
930            }
931        }
932
933        self.location_mask.clear();
934        let mut argument_built_ins = crate::FastHashSet::default();
935        // TODO: add span info to function arguments
936        for (index, fa) in ep.function.arguments.iter().enumerate() {
937            let mut ctx = VaryingContext {
938                stage: ep.stage,
939                output: false,
940                types: &module.types,
941                type_info: &self.types,
942                location_mask: &mut self.location_mask,
943                blend_src_mask: &mut self.blend_src_mask,
944                built_ins: &mut argument_built_ins,
945                capabilities: self.capabilities,
946                flags: self.flags,
947                mesh_output_type: MeshOutputType::None,
948                has_task_payload: ep.task_payload.is_some(),
949            };
950            ctx.validate(ep, fa.ty, fa.binding.as_ref())
951                .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
952        }
953
954        self.location_mask.clear();
955        if let Some(ref fr) = ep.function.result {
956            let mut result_built_ins = crate::FastHashSet::default();
957            let mut ctx = VaryingContext {
958                stage: ep.stage,
959                output: true,
960                types: &module.types,
961                type_info: &self.types,
962                location_mask: &mut self.location_mask,
963                blend_src_mask: &mut self.blend_src_mask,
964                built_ins: &mut result_built_ins,
965                capabilities: self.capabilities,
966                flags: self.flags,
967                mesh_output_type: MeshOutputType::None,
968                has_task_payload: ep.task_payload.is_some(),
969            };
970            ctx.validate(ep, fr.ty, fr.binding.as_ref())
971                .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
972            if ep.stage == crate::ShaderStage::Vertex
973                && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
974            {
975                return Err(EntryPointError::MissingVertexOutputPosition.with_span());
976            }
977            if ep.stage == crate::ShaderStage::Mesh {
978                return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span());
979            }
980            // Task shaders must have a single `MeshTaskSize` output, and nothing else.
981            if ep.stage == crate::ShaderStage::Task {
982                let ok = result_built_ins.contains(&crate::BuiltIn::MeshTaskSize)
983                    && result_built_ins.len() == 1
984                    && self.location_mask.is_empty();
985                if !ok {
986                    return Err(EntryPointError::WrongTaskShaderEntryResult.with_span());
987                }
988            }
989            if !self.blend_src_mask.is_empty() {
990                info.dual_source_blending = true;
991            }
992        } else if ep.stage == crate::ShaderStage::Vertex {
993            return Err(EntryPointError::MissingVertexOutputPosition.with_span());
994        } else if ep.stage == crate::ShaderStage::Task {
995            return Err(EntryPointError::WrongTaskShaderEntryResult.with_span());
996        }
997
998        {
999            let mut used_push_constants = module
1000                .global_variables
1001                .iter()
1002                .filter(|&(_, var)| var.space == crate::AddressSpace::PushConstant)
1003                .map(|(handle, _)| handle)
1004                .filter(|&handle| !info[handle].is_empty());
1005            // Check if there is more than one push constant, and error if so.
1006            // Use a loop for when returning multiple errors is supported.
1007            if let Some(handle) = used_push_constants.nth(1) {
1008                return Err(EntryPointError::MoreThanOnePushConstantUsed
1009                    .with_span_handle(handle, &module.global_variables));
1010            }
1011        }
1012
1013        self.ep_resource_bindings.clear();
1014        for (var_handle, var) in module.global_variables.iter() {
1015            let usage = info[var_handle];
1016            if usage.is_empty() {
1017                continue;
1018            }
1019
1020            if var.space == crate::AddressSpace::TaskPayload {
1021                if ep.task_payload != Some(var_handle) {
1022                    return Err(EntryPointError::WrongTaskPayloadUsed
1023                        .with_span_handle(var_handle, &module.global_variables));
1024                }
1025            }
1026
1027            let allowed_usage = match var.space {
1028                crate::AddressSpace::Function => unreachable!(),
1029                crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
1030                crate::AddressSpace::Storage { access } => storage_usage(access),
1031                crate::AddressSpace::Handle => match module.types[var.ty].inner {
1032                    crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
1033                        crate::TypeInner::Image {
1034                            class: crate::ImageClass::Storage { access, .. },
1035                            ..
1036                        } => storage_usage(access),
1037                        _ => GlobalUse::READ | GlobalUse::QUERY,
1038                    },
1039                    crate::TypeInner::Image {
1040                        class: crate::ImageClass::Storage { access, .. },
1041                        ..
1042                    } => storage_usage(access),
1043                    _ => GlobalUse::READ | GlobalUse::QUERY,
1044                },
1045                crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
1046                    GlobalUse::READ | GlobalUse::WRITE | GlobalUse::QUERY
1047                }
1048                crate::AddressSpace::TaskPayload => {
1049                    GlobalUse::READ
1050                        | GlobalUse::QUERY
1051                        | if ep.stage == crate::ShaderStage::Task {
1052                            GlobalUse::WRITE
1053                        } else {
1054                            GlobalUse::empty()
1055                        }
1056                }
1057                crate::AddressSpace::PushConstant => GlobalUse::READ,
1058            };
1059            if !allowed_usage.contains(usage) {
1060                log::warn!("\tUsage error for: {var:?}");
1061                log::warn!("\tAllowed usage: {allowed_usage:?}, requested: {usage:?}");
1062                return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
1063                    .with_span_handle(var_handle, &module.global_variables));
1064            }
1065
1066            if let Some(ref bind) = var.binding {
1067                if !self.ep_resource_bindings.insert(*bind) {
1068                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
1069                        return Err(EntryPointError::BindingCollision(var_handle)
1070                            .with_span_handle(var_handle, &module.global_variables));
1071                    }
1072                }
1073            }
1074        }
1075
1076        // If this is a `Mesh` entry point, check its vertex and primitive output types.
1077        // We verified previously that only mesh shaders can have `mesh_info`.
1078        if let &Some(ref mesh_info) = &ep.mesh_info {
1079            // Mesh shaders don't return any value. All their results are supplied through
1080            // [`SetVertex`] and [`SetPrimitive`] calls.
1081            if let Some((used_vertex_type, _)) = info.mesh_shader_info.vertex_type {
1082                if used_vertex_type != mesh_info.vertex_output_type {
1083                    return Err(EntryPointError::WrongMeshOutputType
1084                        .with_span_handle(mesh_info.vertex_output_type, &module.types));
1085                }
1086            }
1087            if let Some((used_primitive_type, _)) = info.mesh_shader_info.primitive_type {
1088                if used_primitive_type != mesh_info.primitive_output_type {
1089                    return Err(EntryPointError::WrongMeshOutputType
1090                        .with_span_handle(mesh_info.primitive_output_type, &module.types));
1091                }
1092            }
1093
1094            self.validate_mesh_output_type(
1095                ep,
1096                module,
1097                mesh_info.vertex_output_type,
1098                MeshOutputType::VertexOutput,
1099            )?;
1100            self.validate_mesh_output_type(
1101                ep,
1102                module,
1103                mesh_info.primitive_output_type,
1104                MeshOutputType::PrimitiveOutput,
1105            )?;
1106        } else {
1107            // This is not a `Mesh` entry point, so ensure that it never tries to produce
1108            // vertices or primitives.
1109            if info.mesh_shader_info.vertex_type.is_some()
1110                || info.mesh_shader_info.primitive_type.is_some()
1111            {
1112                return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span());
1113            }
1114        }
1115
1116        Ok(info)
1117    }
1118}