wgpu_core/
validation.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString as _},
4    sync::Arc,
5    vec::Vec,
6};
7use core::fmt;
8
9use arrayvec::ArrayVec;
10use hashbrown::{hash_map::Entry, HashSet};
11use shader_io_deductions::{display_deductions_as_optional_list, MaxVertexShaderOutputDeduction};
12use thiserror::Error;
13use wgt::{
14    error::{ErrorType, WebGpuError},
15    BindGroupLayoutEntry, BindingType,
16};
17
18use crate::{
19    command::ColorAttachmentError, device::bgl, resource::InvalidResourceError,
20    validation::shader_io_deductions::MaxFragmentShaderInputDeduction, FastHashMap, FastHashSet,
21};
22
23pub mod shader_io_deductions;
24
25#[derive(Debug)]
26enum ResourceType {
27    Buffer {
28        size: wgt::BufferSize,
29    },
30    Texture {
31        dim: naga::ImageDimension,
32        arrayed: bool,
33        class: naga::ImageClass,
34    },
35    Sampler {
36        comparison: bool,
37    },
38    AccelerationStructure {
39        vertex_return: bool,
40    },
41}
42
/// Coarse category of a binding, used to report kind mismatches between a shader
/// global and a bind group layout entry (see [`BindingError::WrongType`]).
#[derive(Clone, Debug)]
pub enum BindingTypeName {
    /// A uniform or storage buffer.
    Buffer,
    /// A sampled or storage texture.
    Texture,
    /// A sampler.
    Sampler,
    /// An acceleration structure.
    AccelerationStructure,
    /// An external texture.
    ExternalTexture,
}
51
52impl From<&ResourceType> for BindingTypeName {
53    fn from(ty: &ResourceType) -> BindingTypeName {
54        match ty {
55            ResourceType::Buffer { .. } => BindingTypeName::Buffer,
56            ResourceType::Texture {
57                class: naga::ImageClass::External,
58                ..
59            } => BindingTypeName::ExternalTexture,
60            ResourceType::Texture { .. } => BindingTypeName::Texture,
61            ResourceType::Sampler { .. } => BindingTypeName::Sampler,
62            ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
63        }
64    }
65}
66
67impl From<&BindingType> for BindingTypeName {
68    fn from(ty: &BindingType) -> BindingTypeName {
69        match ty {
70            BindingType::Buffer { .. } => BindingTypeName::Buffer,
71            BindingType::Texture { .. } => BindingTypeName::Texture,
72            BindingType::StorageTexture { .. } => BindingTypeName::Texture,
73            BindingType::Sampler { .. } => BindingTypeName::Sampler,
74            BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
75            BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
76        }
77    }
78}
79
80#[derive(Debug)]
81struct Resource {
82    #[allow(unused)]
83    name: Option<String>,
84    bind: naga::ResourceBinding,
85    ty: ResourceType,
86    class: naga::AddressSpace,
87}
88
89#[derive(Clone, Copy, Debug, Eq, PartialEq)]
90enum NumericDimension {
91    Scalar,
92    Vector(naga::VectorSize),
93    Matrix(naga::VectorSize, naga::VectorSize),
94}
95
96impl fmt::Display for NumericDimension {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        match *self {
99            Self::Scalar => write!(f, ""),
100            Self::Vector(size) => write!(f, "x{}", size as u8),
101            Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
102        }
103    }
104}
105
106#[derive(Clone, Copy, Debug, Eq, PartialEq)]
107pub struct NumericType {
108    dim: NumericDimension,
109    scalar: naga::Scalar,
110}
111
112impl fmt::Display for NumericType {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(
115            f,
116            "{:?}{}{}",
117            self.scalar.kind,
118            self.scalar.width * 8,
119            self.dim
120        )
121    }
122}
123
124#[derive(Clone, Debug, Eq, PartialEq)]
125pub struct InterfaceVar {
126    pub ty: NumericType,
127    interpolation: Option<naga::Interpolation>,
128    sampling: Option<naga::Sampling>,
129    per_primitive: bool,
130}
131
132impl InterfaceVar {
133    pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
134        InterfaceVar {
135            ty: NumericType::from_vertex_format(format),
136            interpolation: None,
137            sampling: None,
138            per_primitive: false,
139        }
140    }
141}
142
143impl fmt::Display for InterfaceVar {
144    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145        write!(
146            f,
147            "{} interpolated as {:?} with sampling {:?}",
148            self.ty, self.interpolation, self.sampling
149        )
150    }
151}
152
153#[derive(Debug, Eq, PartialEq)]
154enum Varying {
155    Local { location: u32, iv: InterfaceVar },
156    BuiltIn(BuiltIn),
157}
158
/// A built-in inter-stage value, mirroring [`naga::BuiltIn`].
///
/// Unlike its naga counterpart, `ClipDistances` also records the array size;
/// [`BuiltIn::to_naga`] drops that extra information.
#[derive(Clone, Debug, Eq, PartialEq)]
enum BuiltIn {
    // Vertex-stage values.
    /// Position; `invariant` is forwarded to `naga::BuiltIn::Position`.
    Position { invariant: bool },
    ViewIndex,
    BaseInstance,
    BaseVertex,
    /// Clip distances, together with the size of the declared array.
    ClipDistances { array_size: u32 },
    CullDistance,
    InstanceIndex,
    PointSize,
    VertexIndex,
    DrawIndex,

    // Fragment-stage values.
    FragDepth,
    PointCoord,
    FrontFacing,
    PrimitiveIndex,
    /// Barycentric coordinates; `perspective` is forwarded to naga.
    Barycentric { perspective: bool },
    SampleIndex,
    SampleMask,

    // Compute / workgroup / subgroup values.
    GlobalInvocationId,
    LocalInvocationId,
    LocalInvocationIndex,
    WorkGroupId,
    WorkGroupSize,
    NumWorkGroups,
    NumSubgroups,
    SubgroupId,
    SubgroupSize,
    SubgroupInvocationId,

    // Task / mesh values.
    MeshTaskSize,
    CullPrimitive,
    PointIndex,
    LineIndices,
    TriangleIndices,
    VertexCount,
    Vertices,
    PrimitiveCount,
    Primitives,

    // Ray tracing values.
    RayInvocationId,
    NumRayInvocations,
    InstanceCustomData,
    GeometryIndex,
    WorldRayOrigin,
    WorldRayDirection,
    ObjectRayOrigin,
    ObjectRayDirection,
    RayTmin,
    RayTCurrentMax,
    ObjectToWorld,
    WorldToObject,
    HitKind,
}
211
212impl BuiltIn {
213    pub fn to_naga(&self) -> naga::BuiltIn {
214        match self {
215            &Self::Position { invariant } => naga::BuiltIn::Position { invariant },
216            Self::ViewIndex => naga::BuiltIn::ViewIndex,
217            Self::BaseInstance => naga::BuiltIn::BaseInstance,
218            Self::BaseVertex => naga::BuiltIn::BaseVertex,
219            Self::ClipDistances { .. } => naga::BuiltIn::ClipDistances,
220            Self::CullDistance => naga::BuiltIn::CullDistance,
221            Self::InstanceIndex => naga::BuiltIn::InstanceIndex,
222            Self::PointSize => naga::BuiltIn::PointSize,
223            Self::VertexIndex => naga::BuiltIn::VertexIndex,
224            Self::DrawIndex => naga::BuiltIn::DrawIndex,
225            Self::FragDepth => naga::BuiltIn::FragDepth,
226            Self::PointCoord => naga::BuiltIn::PointCoord,
227            Self::FrontFacing => naga::BuiltIn::FrontFacing,
228            Self::PrimitiveIndex => naga::BuiltIn::PrimitiveIndex,
229            Self::Barycentric { perspective } => naga::BuiltIn::Barycentric {
230                perspective: *perspective,
231            },
232            Self::SampleIndex => naga::BuiltIn::SampleIndex,
233            Self::SampleMask => naga::BuiltIn::SampleMask,
234            Self::GlobalInvocationId => naga::BuiltIn::GlobalInvocationId,
235            Self::LocalInvocationId => naga::BuiltIn::LocalInvocationId,
236            Self::LocalInvocationIndex => naga::BuiltIn::LocalInvocationIndex,
237            Self::WorkGroupId => naga::BuiltIn::WorkGroupId,
238            Self::WorkGroupSize => naga::BuiltIn::WorkGroupSize,
239            Self::NumWorkGroups => naga::BuiltIn::NumWorkGroups,
240            Self::NumSubgroups => naga::BuiltIn::NumSubgroups,
241            Self::SubgroupId => naga::BuiltIn::SubgroupId,
242            Self::SubgroupSize => naga::BuiltIn::SubgroupSize,
243            Self::SubgroupInvocationId => naga::BuiltIn::SubgroupInvocationId,
244            Self::MeshTaskSize => naga::BuiltIn::MeshTaskSize,
245            Self::CullPrimitive => naga::BuiltIn::CullPrimitive,
246            Self::PointIndex => naga::BuiltIn::PointIndex,
247            Self::LineIndices => naga::BuiltIn::LineIndices,
248            Self::TriangleIndices => naga::BuiltIn::TriangleIndices,
249            Self::VertexCount => naga::BuiltIn::VertexCount,
250            Self::Vertices => naga::BuiltIn::Vertices,
251            Self::PrimitiveCount => naga::BuiltIn::PrimitiveCount,
252            Self::Primitives => naga::BuiltIn::Primitives,
253            Self::RayInvocationId => naga::BuiltIn::RayInvocationId,
254            Self::NumRayInvocations => naga::BuiltIn::NumRayInvocations,
255            Self::InstanceCustomData => naga::BuiltIn::InstanceCustomData,
256            Self::GeometryIndex => naga::BuiltIn::GeometryIndex,
257            Self::WorldRayOrigin => naga::BuiltIn::WorldRayOrigin,
258            Self::WorldRayDirection => naga::BuiltIn::WorldRayDirection,
259            Self::ObjectRayOrigin => naga::BuiltIn::ObjectRayOrigin,
260            Self::ObjectRayDirection => naga::BuiltIn::ObjectRayDirection,
261            Self::RayTmin => naga::BuiltIn::RayTmin,
262            Self::RayTCurrentMax => naga::BuiltIn::RayTCurrentMax,
263            Self::ObjectToWorld => naga::BuiltIn::ObjectToWorld,
264            Self::WorldToObject => naga::BuiltIn::WorldToObject,
265            Self::HitKind => naga::BuiltIn::HitKind,
266        }
267    }
268}
269
270#[allow(unused)]
271#[derive(Debug)]
272struct SpecializationConstant {
273    id: u32,
274    ty: NumericType,
275}
276
277#[derive(Debug)]
278struct EntryPointMeshInfo {
279    max_vertices: u32,
280    max_primitives: u32,
281    primitive_topology: wgt::PrimitiveTopology,
282}
283
284#[derive(Debug, Default)]
285struct EntryPoint {
286    inputs: Vec<Varying>,
287    outputs: Vec<Varying>,
288    resources: Vec<naga::Handle<Resource>>,
289    #[allow(unused)]
290    spec_constants: Vec<SpecializationConstant>,
291    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
292    workgroup_size: [u32; 3],
293    dual_source_blending: bool,
294    task_payload_size: Option<u32>,
295    mesh_info: Option<EntryPointMeshInfo>,
296    immediate_slots_required: naga::valid::ImmediateSlots,
297}
298
299#[derive(Debug)]
300pub struct Interface {
301    limits: wgt::Limits,
302    resources: naga::Arena<Resource>,
303    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
304    pub(crate) immediate_size: u32,
305}
306
307#[derive(Debug)]
308pub struct PassthroughInterface {
309    pub entry_point_names: HashSet<String>,
310}
311
312// Most shaders will use a standard interface which is very large.
313// Passthrough shaders have a much smaller interface. No reason to
314// box the standard interface though.
315#[expect(clippy::large_enum_variant)]
316#[derive(Debug)]
317pub enum ShaderMetaData {
318    Interface(Interface),
319    Passthrough(PassthroughInterface),
320}
321impl ShaderMetaData {
322    pub fn interface(&self) -> Option<&Interface> {
323        match self {
324            Self::Interface(i) => Some(i),
325            Self::Passthrough(_) => None,
326        }
327    }
328}
329
330#[derive(Clone, Debug, Error)]
331#[non_exhaustive]
332pub enum BindingError {
333    #[error("Binding is missing from the pipeline layout")]
334    Missing,
335    #[error("Visibility flags don't include the shader stage")]
336    Invisible,
337    #[error(
338        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
339    )]
340    WrongType {
341        binding: BindingTypeName,
342        shader: BindingTypeName,
343    },
344    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
345    WrongAddressSpace {
346        binding: naga::AddressSpace,
347        shader: naga::AddressSpace,
348    },
349    #[error("Address space {space:?} is not a valid Buffer address space")]
350    WrongBufferAddressSpace { space: naga::AddressSpace },
351    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
352    WrongBufferSize {
353        buffer_size: wgt::BufferSize,
354        min_binding_size: wgt::BufferSize,
355    },
356    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
357    WrongTextureViewDimension {
358        dim: naga::ImageDimension,
359        is_array: bool,
360        binding: BindingType,
361    },
362    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
363    WrongTextureClass {
364        binding: naga::ImageClass,
365        shader: naga::ImageClass,
366    },
367    #[error("Comparison flag doesn't match the shader")]
368    WrongSamplerComparison,
369    #[error("Derived bind group layout type is not consistent between stages")]
370    InconsistentlyDerivedType,
371    #[error("Texture format {0:?} is not supported for storage use")]
372    BadStorageFormat(wgt::TextureFormat),
373}
374
375impl WebGpuError for BindingError {
376    fn webgpu_error_type(&self) -> ErrorType {
377        ErrorType::Validation
378    }
379}
380
381#[derive(Clone, Debug, Error)]
382#[non_exhaustive]
383pub enum FilteringError {
384    #[error("Integer textures can't be sampled with a filtering sampler")]
385    Integer,
386    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
387    Float,
388}
389
390impl WebGpuError for FilteringError {
391    fn webgpu_error_type(&self) -> ErrorType {
392        ErrorType::Validation
393    }
394}
395
396#[derive(Clone, Debug, Error)]
397#[non_exhaustive]
398pub enum InputError {
399    #[error("Input is not provided by the earlier stage in the pipeline")]
400    Missing,
401    #[error("Input type is not compatible with the provided {0}")]
402    WrongType(NumericType),
403    #[error("Input interpolation doesn't match provided {0:?}")]
404    InterpolationMismatch(Option<naga::Interpolation>),
405    #[error("Input sampling doesn't match provided {0:?}")]
406    SamplingMismatch(Option<naga::Sampling>),
407    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
408    WrongPerPrimitive { pipeline_input: bool, shader: bool },
409}
410
411impl WebGpuError for InputError {
412    fn webgpu_error_type(&self) -> ErrorType {
413        ErrorType::Validation
414    }
415}
416
417/// Errors produced when validating a programmable stage of a pipeline.
418#[derive(Clone, Debug, Error)]
419#[non_exhaustive]
420pub enum StageError {
421    #[error(transparent)]
422    InvalidWorkgroupSize(#[from] InvalidWorkgroupSizeError),
423    #[error("Unable to find entry point '{0}'")]
424    MissingEntryPoint(String),
425    #[error("Shader global {0:?} is not available in the pipeline layout")]
426    Binding(naga::ResourceBinding, #[source] BindingError),
427    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
428    Filtering {
429        texture: naga::ResourceBinding,
430        sampler: naga::ResourceBinding,
431        #[source]
432        error: FilteringError,
433    },
434    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
435    Input {
436        location: wgt::ShaderLocation,
437        var: InterfaceVar,
438        #[source]
439        error: InputError,
440    },
441    #[error(
442        "Unable to select an entry point: no entry point was found in the provided shader module"
443    )]
444    NoEntryPointFound,
445    #[error(
446        "Unable to select an entry point: \
447        multiple entry points were found in the provided shader module, \
448        but no entry point was specified"
449    )]
450    MultipleEntryPointsFound,
451    #[error(transparent)]
452    InvalidResource(#[from] InvalidResourceError),
453    #[error(
454        "vertex shader output location Location[{location}] ({var}) exceeds the \
455        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
456        // NOTE: Remember: the limit is 0-based for indices.
457        limit - 1,
458        display_deductions_as_optional_list(deductions, |d| d.for_location())
459    )]
460    VertexOutputLocationTooLarge {
461        location: u32,
462        var: InterfaceVar,
463        limit: u32,
464        deductions: Vec<MaxVertexShaderOutputDeduction>,
465    },
466    #[error(
467        "found {num_found} user-defined vertex shader output variables, which exceeds the \
468        `max_inter_stage_shader_variables` limit ({limit}){}",
469        display_deductions_as_optional_list(deductions, |d| d.for_variables())
470    )]
471    TooManyUserDefinedVertexOutputs {
472        num_found: u32,
473        limit: u32,
474        deductions: Vec<MaxVertexShaderOutputDeduction>,
475    },
476    #[error(
477        "fragment shader input location Location[{location}] ({var}) exceeds the \
478        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
479        // NOTE: Remember: the limit is 0-based for indices.
480        limit - 1,
481        // NOTE: WebGPU spec. validation for fragment inputs is expressed in terms of variables
482        // (unlike vertex outputs), so we use `MaxFragmentShaderInputDeduction::for_variables` here
483        // (and not a non-existent `for_locations`).
484        display_deductions_as_optional_list(deductions, |d| d.for_variables())
485    )]
486    FragmentInputLocationTooLarge {
487        location: u32,
488        var: InterfaceVar,
489        limit: u32,
490        deductions: Vec<MaxFragmentShaderInputDeduction>,
491    },
492    #[error(
493        "found {num_found} user-defined fragment shader input variables, which exceeds the \
494        `max_inter_stage_shader_variables` limit ({limit}){}",
495        display_deductions_as_optional_list(deductions, |d| d.for_variables())
496    )]
497    TooManyUserDefinedFragmentInputs {
498        num_found: u32,
499        limit: u32,
500        deductions: Vec<MaxFragmentShaderInputDeduction>,
501    },
502    #[error(
503        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
504    )]
505    ColorAttachmentLocationTooLarge {
506        location: u32,
507        var: InterfaceVar,
508        limit: u32,
509    },
510    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
511    TooManyMeshVertices { limit: u32, value: u32 },
512    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
513    TooManyMeshPrimitives { limit: u32, value: u32 },
514    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
515    TaskPayloadTooLarge { limit: u32, value: u32 },
516    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
517    TaskPayloadMustMatch {
518        input: Option<u32>,
519        shader: Option<u32>,
520    },
521    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
522    InvalidPrimitiveIndex,
523    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
524    MissingPrimitiveIndex,
525    #[error("DrawId cannot be used in a mesh shader in a pipeline with a task shader")]
526    DrawIdError,
527    #[error("Pipeline uses dual-source blending, but the shader does not support it")]
528    InvalidDualSourceBlending,
529    #[error("Fragment shader writes depth, but pipeline does not have a depth attachment")]
530    MissingFragDepthAttachment,
531    #[error("Per vertex fragment inputs can only be used in triangle primitive pipelines")]
532    PerVertexNotTriangles,
533    #[error("Mesh shader pipelines must have primitive topology of TriangleList, LineList or PointList, and this must match with what the mesh shader declares.")]
534    MeshTopologyMismatch,
535}
536
537impl WebGpuError for StageError {
538    fn webgpu_error_type(&self) -> ErrorType {
539        match self {
540            Self::Binding(_, e) => e.webgpu_error_type(),
541            Self::InvalidResource(e) => e.webgpu_error_type(),
542            Self::Filtering {
543                texture: _,
544                sampler: _,
545                error,
546            } => error.webgpu_error_type(),
547            Self::Input {
548                location: _,
549                var: _,
550                error,
551            } => error.webgpu_error_type(),
552            Self::InvalidWorkgroupSize { .. }
553            | Self::MissingEntryPoint(..)
554            | Self::NoEntryPointFound
555            | Self::MultipleEntryPointsFound
556            | Self::VertexOutputLocationTooLarge { .. }
557            | Self::TooManyUserDefinedVertexOutputs { .. }
558            | Self::FragmentInputLocationTooLarge { .. }
559            | Self::TooManyUserDefinedFragmentInputs { .. }
560            | Self::ColorAttachmentLocationTooLarge { .. }
561            | Self::TooManyMeshVertices { .. }
562            | Self::TooManyMeshPrimitives { .. }
563            | Self::TaskPayloadTooLarge { .. }
564            | Self::TaskPayloadMustMatch { .. }
565            | Self::InvalidPrimitiveIndex
566            | Self::MissingPrimitiveIndex
567            | Self::DrawIdError
568            | Self::InvalidDualSourceBlending
569            | Self::MissingFragDepthAttachment
570            | Self::PerVertexNotTriangles
571            | Self::MeshTopologyMismatch => ErrorType::Validation,
572        }
573    }
574}
575
576pub use wgpu_naga_bridge::map_storage_format_from_naga;
577pub use wgpu_naga_bridge::map_storage_format_to_naga;
578
579impl Resource {
580    fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
581        match self.ty {
582            ResourceType::Buffer { size } => {
583                let min_size = match entry.ty {
584                    BindingType::Buffer {
585                        ty,
586                        has_dynamic_offset: _,
587                        min_binding_size,
588                    } => {
589                        let class = match ty {
590                            wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
591                            wgt::BufferBindingType::Storage { read_only } => {
592                                let mut naga_access = naga::StorageAccess::LOAD;
593                                naga_access.set(naga::StorageAccess::STORE, !read_only);
594                                naga::AddressSpace::Storage {
595                                    access: naga_access,
596                                }
597                            }
598                        };
599                        if self.class != class {
600                            return Err(BindingError::WrongAddressSpace {
601                                binding: class,
602                                shader: self.class,
603                            });
604                        }
605                        min_binding_size
606                    }
607                    _ => {
608                        return Err(BindingError::WrongType {
609                            binding: (&entry.ty).into(),
610                            shader: (&self.ty).into(),
611                        })
612                    }
613                };
614                match min_size {
615                    Some(non_zero) if non_zero < size => {
616                        return Err(BindingError::WrongBufferSize {
617                            buffer_size: size,
618                            min_binding_size: non_zero,
619                        })
620                    }
621                    _ => (),
622                }
623            }
624            ResourceType::Sampler { comparison } => match entry.ty {
625                BindingType::Sampler(ty) => {
626                    if (ty == wgt::SamplerBindingType::Comparison) != comparison {
627                        return Err(BindingError::WrongSamplerComparison);
628                    }
629                }
630                _ => {
631                    return Err(BindingError::WrongType {
632                        binding: (&entry.ty).into(),
633                        shader: (&self.ty).into(),
634                    })
635                }
636            },
637            ResourceType::Texture {
638                dim,
639                arrayed,
640                class: shader_class,
641            } => {
642                let view_dimension = match entry.ty {
643                    BindingType::Texture { view_dimension, .. }
644                    | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
645                    BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
646                    _ => {
647                        return Err(BindingError::WrongTextureViewDimension {
648                            dim,
649                            is_array: false,
650                            binding: entry.ty,
651                        })
652                    }
653                };
654                if arrayed {
655                    match (dim, view_dimension) {
656                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
657                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
658                        _ => {
659                            return Err(BindingError::WrongTextureViewDimension {
660                                dim,
661                                is_array: true,
662                                binding: entry.ty,
663                            })
664                        }
665                    }
666                } else {
667                    match (dim, view_dimension) {
668                        (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
669                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
670                        (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
671                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
672                        _ => {
673                            return Err(BindingError::WrongTextureViewDimension {
674                                dim,
675                                is_array: false,
676                                binding: entry.ty,
677                            })
678                        }
679                    }
680                }
681                match entry.ty {
682                    BindingType::Texture {
683                        sample_type,
684                        view_dimension: _,
685                        multisampled: multi,
686                    } => {
687                        let binding_class = match sample_type {
688                            wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
689                                kind: naga::ScalarKind::Float,
690                                multi,
691                            },
692                            wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
693                                kind: naga::ScalarKind::Sint,
694                                multi,
695                            },
696                            wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
697                                kind: naga::ScalarKind::Uint,
698                                multi,
699                            },
700                            wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
701                        };
702                        if shader_class == binding_class {
703                            Ok(())
704                        } else {
705                            Err(binding_class)
706                        }
707                    }
708                    BindingType::StorageTexture {
709                        access: wgt_binding_access,
710                        format: wgt_binding_format,
711                        view_dimension: _,
712                    } => {
713                        const LOAD_STORE: naga::StorageAccess =
714                            naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
715                        let binding_format = map_storage_format_to_naga(wgt_binding_format)
716                            .ok_or(BindingError::BadStorageFormat(wgt_binding_format))?;
717                        let binding_access = match wgt_binding_access {
718                            wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
719                            wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
720                            wgt::StorageTextureAccess::ReadWrite => LOAD_STORE,
721                            wgt::StorageTextureAccess::Atomic => {
722                                naga::StorageAccess::ATOMIC | LOAD_STORE
723                            }
724                        };
725                        match shader_class {
726                            // Formats must match exactly. A write-only shader (but not a
727                            // read-only shader) is compatible with a read-write binding.
728                            naga::ImageClass::Storage {
729                                format: shader_format,
730                                access: shader_access,
731                            } if shader_format == binding_format
732                                && (shader_access == binding_access
733                                    || shader_access == naga::StorageAccess::STORE
734                                        && binding_access == LOAD_STORE) =>
735                            {
736                                Ok(())
737                            }
738                            _ => Err(naga::ImageClass::Storage {
739                                format: binding_format,
740                                access: binding_access,
741                            }),
742                        }
743                    }
744                    BindingType::ExternalTexture => {
745                        let binding_class = naga::ImageClass::External;
746                        if shader_class == binding_class {
747                            Ok(())
748                        } else {
749                            Err(binding_class)
750                        }
751                    }
752                    _ => {
753                        return Err(BindingError::WrongType {
754                            binding: (&entry.ty).into(),
755                            shader: (&self.ty).into(),
756                        })
757                    }
758                }
759                .map_err(|binding_class| BindingError::WrongTextureClass {
760                    binding: binding_class,
761                    shader: shader_class,
762                })?;
763            }
764            ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
765                BindingType::AccelerationStructure {
766                    vertex_return: entry_vertex_return,
767                } if vertex_return == entry_vertex_return => (),
768                _ => {
769                    return Err(BindingError::WrongType {
770                        binding: (&entry.ty).into(),
771                        shader: (&self.ty).into(),
772                    })
773                }
774            },
775        };
776
777        Ok(())
778    }
779
780    fn derive_binding_type(
781        &self,
782        is_reffed_by_sampler_in_entrypoint: bool,
783    ) -> Result<BindingType, BindingError> {
784        Ok(match self.ty {
785            ResourceType::Buffer { size } => BindingType::Buffer {
786                ty: match self.class {
787                    naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
788                    naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
789                        read_only: access == naga::StorageAccess::LOAD,
790                    },
791                    _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
792                },
793                has_dynamic_offset: false,
794                min_binding_size: Some(size),
795            },
796            ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
797                wgt::SamplerBindingType::Comparison
798            } else {
799                wgt::SamplerBindingType::Filtering
800            }),
801            ResourceType::Texture {
802                dim,
803                arrayed,
804                class,
805            } => {
806                let view_dimension = match dim {
807                    naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
808                    naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
809                    naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
810                    naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
811                    naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
812                    naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
813                };
814                match class {
815                    naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
816                        sample_type: match kind {
817                            naga::ScalarKind::Float => wgt::TextureSampleType::Float {
818                                filterable: is_reffed_by_sampler_in_entrypoint,
819                            },
820                            naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
821                            naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
822                            naga::ScalarKind::AbstractInt
823                            | naga::ScalarKind::AbstractFloat
824                            | naga::ScalarKind::Bool => unreachable!(),
825                        },
826                        view_dimension,
827                        multisampled: multi,
828                    },
829                    naga::ImageClass::Depth { multi } => BindingType::Texture {
830                        sample_type: wgt::TextureSampleType::Depth,
831                        view_dimension,
832                        multisampled: multi,
833                    },
834                    naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
835                        access: {
836                            const LOAD_STORE: naga::StorageAccess =
837                                naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
838                            match access {
839                                naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
840                                naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
841                                LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
842                                _ if access.contains(naga::StorageAccess::ATOMIC) => {
843                                    wgt::StorageTextureAccess::Atomic
844                                }
845                                _ => unreachable!(),
846                            }
847                        },
848                        view_dimension,
849                        format: {
850                            let f = map_storage_format_from_naga(format);
851                            let original = map_storage_format_to_naga(f)
852                                .ok_or(BindingError::BadStorageFormat(f))?;
853                            debug_assert_eq!(format, original);
854                            f
855                        },
856                    },
857                    naga::ImageClass::External => BindingType::ExternalTexture,
858                }
859            }
860            ResourceType::AccelerationStructure { vertex_return } => {
861                BindingType::AccelerationStructure { vertex_return }
862            }
863        })
864    }
865}
866
impl NumericType {
    /// Returns the numeric type a shader sees when reading a vertex attribute of `format`.
    ///
    /// Normalized (`Unorm*`/`Snorm*`) and half-float formats are reported as `f32`, 8- and
    /// 16-bit integer formats are widened to 32-bit integers of the same signedness, and
    /// `Float64*` formats are reported as `f64`.
    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
        use naga::{Scalar, VectorSize as Vs};
        use wgt::VertexFormat as Vf;

        let (dim, scalar) = match format {
            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
            }
            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
            }
            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
            }
            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
            }
            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Vf::Unorm8x2
            | Vf::Snorm8x2
            | Vf::Unorm16x2
            | Vf::Snorm16x2
            | Vf::Float16x2
            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Vf::Unorm8x4
            | Vf::Snorm8x4
            | Vf::Unorm16x4
            | Vf::Snorm16x4
            | Vf::Float16x4
            | Vf::Float32x4
            | Vf::Unorm10_10_10_2
            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
        };

        NumericType {
            dim,
            //Note: Shader always sees data as int, uint, or float.
            // It doesn't know if the original is normalized in a tighter form.
            scalar,
        }
    }

    /// Returns the numeric type a shader uses for texels of `format`.
    ///
    /// Normalized, float, sRGB, packed-float and block-compressed (BC/ETC2/EAC/ASTC) formats
    /// are reported as `f32` with one component per channel; integer formats as 32-bit
    /// integers, except `R64Uint`, which is `u64`.
    ///
    /// # Panics
    ///
    /// Panics for depth/stencil formats and for the multi-planar `NV12` and `P010` formats,
    /// which have no single per-texel numeric type here.
    fn from_texture_format(format: wgt::TextureFormat) -> Self {
        use naga::{Scalar, VectorSize as Vs};
        use wgt::TextureFormat as Tf;

        let (dim, scalar) = match format {
            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
            }
            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
            }
            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
            }
            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Rgba8Unorm
            | Tf::Rgba8UnormSrgb
            | Tf::Rgba8Snorm
            | Tf::Bgra8Unorm
            | Tf::Bgra8UnormSrgb
            | Tf::Rgb10a2Unorm
            | Tf::Rgba16Float
            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
            }
            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
            }
            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Tf::Stencil8
            | Tf::Depth16Unorm
            | Tf::Depth32Float
            | Tf::Depth32FloatStencil8
            | Tf::Depth24Plus
            | Tf::Depth24PlusStencil8 => {
                panic!("Unexpected depth format")
            }
            Tf::NV12 => panic!("Unexpected nv12 format"),
            Tf::P010 => panic!("Unexpected p010 format"),
            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Tf::Bc1RgbaUnorm
            | Tf::Bc1RgbaUnormSrgb
            | Tf::Bc2RgbaUnorm
            | Tf::Bc2RgbaUnormSrgb
            | Tf::Bc3RgbaUnorm
            | Tf::Bc3RgbaUnormSrgb
            | Tf::Bc7RgbaUnorm
            | Tf::Bc7RgbaUnormSrgb
            | Tf::Etc2Rgb8A1Unorm
            | Tf::Etc2Rgb8A1UnormSrgb
            | Tf::Etc2Rgba8Unorm
            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
            }
            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
            }
            Tf::Astc {
                block: _,
                channel: _,
            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
        };

        NumericType {
            dim,
            //Note: Shader always sees data as int, uint, or float.
            // It doesn't know if the original is normalized in a tighter form.
            scalar,
        }
    }

    /// Returns `true` if a value of type `self` is covered by type `other`.
    ///
    /// The scalar kinds must be equal and `self`'s scalar width must not exceed `other`'s.
    /// Dimensions are compatible when:
    /// - both are scalars, or `self` is a scalar and `other` is any vector;
    /// - both are vectors and `self`'s size is not larger than `other`'s;
    /// - both are matrices with identical column and row counts.
    ///
    /// Every other combination (e.g. a vector against a scalar) is rejected.
    fn is_subtype_of(&self, other: &NumericType) -> bool {
        if self.scalar.width > other.scalar.width {
            return false;
        }
        if self.scalar.kind != other.scalar.kind {
            return false;
        }
        match (self.dim, other.dim) {
            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
                c0 == c1 && r0 == r1
            }
            _ => false,
        }
    }
}
1023
1024/// Return true if the fragment `format` is covered by the provided `output`.
1025pub fn check_texture_format(
1026    format: wgt::TextureFormat,
1027    output: &NumericType,
1028) -> Result<(), NumericType> {
1029    let nt = NumericType::from_texture_format(format);
1030    if nt.is_subtype_of(output) {
1031        Ok(())
1032    } else {
1033        Err(nt)
1034    }
1035}
1036
/// Where the bind group layouts that a shader stage's resources are checked against
/// (or recorded into) come from.
pub enum BindingLayoutSource {
    /// No pipeline layout was provided; the binding layout is derived from the shader.
    ///
    /// Holds one entry map per bind group. This will be filled in by the shader binding
    /// validation, as it iterates the shader's interfaces.
    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
    /// The binding layout is provided by the user in BGLs.
    ///
    /// This will be validated against the shader's interfaces.
    Provided(Arc<crate::binding_model::PipelineLayout>),
}
1047
1048impl BindingLayoutSource {
1049    pub fn new_derived(limits: &wgt::Limits) -> Self {
1050        let mut array = ArrayVec::new();
1051        for _ in 0..limits.max_bind_groups {
1052            array.push(Default::default());
1053        }
1054        BindingLayoutSource::Derived(Box::new(array))
1055    }
1056}
1057
/// Inter-stage interface state used while validating a pipeline's shader stages.
///
/// `Interface::check_stage` takes one of these as `inputs` and returns one. Presumably the
/// value returned for one stage is passed as the `inputs` of the next. NOTE(review): confirm
/// at the call sites.
#[derive(Debug, Clone, Default)]
pub struct StageIo {
    /// User-defined inter-stage variables, keyed by their shader location.
    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
    /// Task payload size, which must match between the task and mesh shaders.
    pub task_payload_size: Option<u32>,
    /// On DX12, a fragment shader cannot read the primitive index after a mesh shader that
    /// does not write it. Therefore, we track across shader stages whether the primitive
    /// index is written (a vertex shader counts as providing it).
    ///
    /// This is `Some` only if the previous stage was a mesh shader.
    pub primitive_index: Option<bool>,
}
1070
1071impl Interface {
    /// Appends the varyings described by type `ty` (with optional `binding`) to `list`.
    ///
    /// It is called for entry-point arguments, results, and mesh-shader output types:
    /// - Structs are flattened by recursing into their members, each with its own member
    ///   binding; the outer `binding` is ignored.
    /// - Scalars, vectors and matrices become a `Varying::Local` for `@location` bindings, or
    ///   a `Varying::BuiltIn` for built-in bindings.
    /// - An `f32` array bound to the `clip_distances` built-in becomes
    ///   `BuiltIn::ClipDistances { array_size }`.
    /// - Any other type is skipped with a debug log.
    /// - A non-struct value without a binding is skipped with an error log.
    fn populate(
        list: &mut Vec<Varying>,
        binding: Option<&naga::Binding>,
        ty: naga::Handle<naga::Type>,
        arena: &naga::UniqueArena<naga::Type>,
    ) {
        // Only used for `@location` bindings below; built-ins ignore it.
        let numeric_ty = match arena[ty].inner {
            naga::TypeInner::Scalar(scalar) => NumericType {
                dim: NumericDimension::Scalar,
                scalar,
            },
            naga::TypeInner::Vector { size, scalar } => NumericType {
                dim: NumericDimension::Vector(size),
                scalar,
            },
            naga::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => NumericType {
                dim: NumericDimension::Matrix(columns, rows),
                scalar,
            },
            naga::TypeInner::Struct { ref members, .. } => {
                // Each member carries its own binding, so recurse per member.
                for member in members {
                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
                }
                return;
            }
            // Arrays are only accepted for the `clip_distances` built-in; all other arrays
            // fall through to the catch-all arm below.
            naga::TypeInner::Array { base, size, stride }
                if matches!(
                    binding,
                    Some(naga::Binding::BuiltIn(naga::BuiltIn::ClipDistances)),
                ) =>
            {
                // NOTE: We should already have validated these in `naga`.
                debug_assert_eq!(
                    &arena[base].inner,
                    &naga::TypeInner::Scalar(naga::Scalar::F32)
                );
                debug_assert_eq!(stride, 4);

                let naga::ArraySize::Constant(array_size) = size else {
                    // NOTE: Based on the
                    // [spec](https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types):
                    //
                    // > The only valid use of a fixed-size array with an element count that is an
                    // > override-expression that is not a const-expression is as a memory view in
                    // > the workgroup address space.
                    unreachable!("non-constant array size for `clip_distances`")
                };
                let array_size = array_size.get();

                list.push(Varying::BuiltIn(BuiltIn::ClipDistances { array_size }));
                return;
            }
            ref other => {
                //Note: technically this should be at least `log::error`, but
                // the reality is - every shader coming from `glslc` outputs an array
                // of clip distances and hits this path :(
                // So we lower it to `log::debug` to be less annoying as
                // there's nothing the user can do about it.
                log::debug!("Unexpected varying type: {other:?}");
                return;
            }
        };

        let varying = match binding {
            Some(&naga::Binding::Location {
                location,
                interpolation,
                sampling,
                per_primitive,
                blend_src: _,
            }) => Varying::Local {
                location,
                iv: InterfaceVar {
                    ty: numeric_ty,
                    interpolation,
                    sampling,
                    per_primitive,
                },
            },
            // One-to-one mapping from naga's built-ins to the validation-side `BuiltIn`.
            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(match built_in {
                naga::BuiltIn::Position { invariant } => BuiltIn::Position { invariant },
                naga::BuiltIn::ViewIndex => BuiltIn::ViewIndex,
                naga::BuiltIn::BaseInstance => BuiltIn::BaseInstance,
                naga::BuiltIn::BaseVertex => BuiltIn::BaseVertex,
                // Array-typed `clip_distances` returned early above. This assumes naga
                // validation rejects any non-array `clip_distances` binding. NOTE(review): confirm.
                naga::BuiltIn::ClipDistances => unreachable!(),
                naga::BuiltIn::CullDistance => BuiltIn::CullDistance,
                naga::BuiltIn::InstanceIndex => BuiltIn::InstanceIndex,
                naga::BuiltIn::PointSize => BuiltIn::PointSize,
                naga::BuiltIn::VertexIndex => BuiltIn::VertexIndex,
                naga::BuiltIn::DrawIndex => BuiltIn::DrawIndex,
                naga::BuiltIn::FragDepth => BuiltIn::FragDepth,
                naga::BuiltIn::PointCoord => BuiltIn::PointCoord,
                naga::BuiltIn::FrontFacing => BuiltIn::FrontFacing,
                naga::BuiltIn::PrimitiveIndex => BuiltIn::PrimitiveIndex,
                naga::BuiltIn::Barycentric { perspective } => BuiltIn::Barycentric { perspective },
                naga::BuiltIn::SampleIndex => BuiltIn::SampleIndex,
                naga::BuiltIn::SampleMask => BuiltIn::SampleMask,
                naga::BuiltIn::GlobalInvocationId => BuiltIn::GlobalInvocationId,
                naga::BuiltIn::LocalInvocationId => BuiltIn::LocalInvocationId,
                naga::BuiltIn::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
                naga::BuiltIn::WorkGroupId => BuiltIn::WorkGroupId,
                naga::BuiltIn::WorkGroupSize => BuiltIn::WorkGroupSize,
                naga::BuiltIn::NumWorkGroups => BuiltIn::NumWorkGroups,
                naga::BuiltIn::NumSubgroups => BuiltIn::NumSubgroups,
                naga::BuiltIn::SubgroupId => BuiltIn::SubgroupId,
                naga::BuiltIn::SubgroupSize => BuiltIn::SubgroupSize,
                naga::BuiltIn::SubgroupInvocationId => BuiltIn::SubgroupInvocationId,
                naga::BuiltIn::MeshTaskSize => BuiltIn::MeshTaskSize,
                naga::BuiltIn::CullPrimitive => BuiltIn::CullPrimitive,
                naga::BuiltIn::PointIndex => BuiltIn::PointIndex,
                naga::BuiltIn::LineIndices => BuiltIn::LineIndices,
                naga::BuiltIn::TriangleIndices => BuiltIn::TriangleIndices,
                naga::BuiltIn::VertexCount => BuiltIn::VertexCount,
                naga::BuiltIn::Vertices => BuiltIn::Vertices,
                naga::BuiltIn::PrimitiveCount => BuiltIn::PrimitiveCount,
                naga::BuiltIn::Primitives => BuiltIn::Primitives,
                naga::BuiltIn::RayInvocationId => BuiltIn::RayInvocationId,
                naga::BuiltIn::NumRayInvocations => BuiltIn::NumRayInvocations,
                naga::BuiltIn::InstanceCustomData => BuiltIn::InstanceCustomData,
                naga::BuiltIn::GeometryIndex => BuiltIn::GeometryIndex,
                naga::BuiltIn::WorldRayOrigin => BuiltIn::WorldRayOrigin,
                naga::BuiltIn::WorldRayDirection => BuiltIn::WorldRayDirection,
                naga::BuiltIn::ObjectRayOrigin => BuiltIn::ObjectRayOrigin,
                naga::BuiltIn::ObjectRayDirection => BuiltIn::ObjectRayDirection,
                naga::BuiltIn::RayTmin => BuiltIn::RayTmin,
                naga::BuiltIn::RayTCurrentMax => BuiltIn::RayTCurrentMax,
                naga::BuiltIn::ObjectToWorld => BuiltIn::ObjectToWorld,
                naga::BuiltIn::WorldToObject => BuiltIn::WorldToObject,
                naga::BuiltIn::HitKind => BuiltIn::HitKind,
            }),
            // A non-struct value with no binding cannot be matched across stages; skip it.
            None => {
                log::error!("Missing binding for a varying");
                return;
            }
        };
        list.push(varying);
    }
1213
1214    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
1215        let mut resources = naga::Arena::new();
1216        let mut resource_mapping = FastHashMap::default();
1217        for (var_handle, var) in module.global_variables.iter() {
1218            let bind = match var.binding {
1219                Some(br) => br,
1220                _ => continue,
1221            };
1222            let naga_ty = &module.types[var.ty].inner;
1223
1224            let inner_ty = match *naga_ty {
1225                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
1226                ref ty => ty,
1227            };
1228
1229            let ty = match *inner_ty {
1230                naga::TypeInner::Image {
1231                    dim,
1232                    arrayed,
1233                    class,
1234                } => ResourceType::Texture {
1235                    dim,
1236                    arrayed,
1237                    class,
1238                },
1239                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
1240                naga::TypeInner::AccelerationStructure { vertex_return } => {
1241                    ResourceType::AccelerationStructure { vertex_return }
1242                }
1243                ref other => ResourceType::Buffer {
1244                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
1245                },
1246            };
1247            let handle = resources.append(
1248                Resource {
1249                    name: var.name.clone(),
1250                    bind,
1251                    ty,
1252                    class: var.space,
1253                },
1254                Default::default(),
1255            );
1256            resource_mapping.insert(var_handle, handle);
1257        }
1258
1259        let immediate_size = naga::valid::ImmediateSlots::size_for_module(module);
1260
1261        let mut entry_points = FastHashMap::default();
1262        entry_points.reserve(module.entry_points.len());
1263        for (index, entry_point) in module.entry_points.iter().enumerate() {
1264            let info = info.get_entry_point(index);
1265            let mut ep = EntryPoint::default();
1266            for arg in entry_point.function.arguments.iter() {
1267                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
1268            }
1269            if let Some(ref result) = entry_point.function.result {
1270                Self::populate(
1271                    &mut ep.outputs,
1272                    result.binding.as_ref(),
1273                    result.ty,
1274                    &module.types,
1275                );
1276            }
1277
1278            for (var_handle, var) in module.global_variables.iter() {
1279                let usage = info[var_handle];
1280                if !usage.is_empty() && var.binding.is_some() {
1281                    ep.resources.push(resource_mapping[&var_handle]);
1282                }
1283            }
1284
1285            for key in info.sampling_set.iter() {
1286                ep.sampling_pairs
1287                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
1288            }
1289            ep.dual_source_blending = info.dual_source_blending;
1290            ep.workgroup_size = entry_point.workgroup_size;
1291            ep.immediate_slots_required = info.immediate_slots_used;
1292
1293            if let Some(task_payload) = entry_point.task_payload {
1294                ep.task_payload_size = Some(
1295                    module.types[module.global_variables[task_payload].ty]
1296                        .inner
1297                        .size(module.to_ctx()),
1298                );
1299            }
1300            if let Some(ref mesh_info) = entry_point.mesh_info {
1301                ep.mesh_info = Some(EntryPointMeshInfo {
1302                    max_vertices: mesh_info.max_vertices,
1303                    max_primitives: mesh_info.max_primitives,
1304                    primitive_topology: match mesh_info.topology {
1305                        naga::MeshOutputTopology::Triangles => wgt::PrimitiveTopology::TriangleList,
1306                        naga::MeshOutputTopology::Lines => wgt::PrimitiveTopology::LineList,
1307                        naga::MeshOutputTopology::Points => wgt::PrimitiveTopology::PointList,
1308                    },
1309                });
1310                Self::populate(
1311                    &mut ep.outputs,
1312                    None,
1313                    mesh_info.vertex_output_type,
1314                    &module.types,
1315                );
1316                Self::populate(
1317                    &mut ep.outputs,
1318                    None,
1319                    mesh_info.primitive_output_type,
1320                    &module.types,
1321                );
1322            }
1323
1324            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
1325        }
1326
1327        Self {
1328            limits,
1329            resources,
1330            entry_points,
1331            immediate_size,
1332        }
1333    }
1334
1335    pub fn immediate_slots_required(
1336        &self,
1337        stage: naga::ShaderStage,
1338        entry_point_name: &str,
1339    ) -> naga::valid::ImmediateSlots {
1340        self.entry_points
1341            .get(&(stage, entry_point_name.to_string()))
1342            .map_or(Default::default(), |ep| ep.immediate_slots_required)
1343    }
1344
1345    pub fn finalize_entry_point_name(
1346        &self,
1347        stage: naga::ShaderStage,
1348        entry_point_name: Option<&str>,
1349    ) -> Result<String, StageError> {
1350        entry_point_name
1351            .map(|ep| ep.to_string())
1352            .map(Ok)
1353            .unwrap_or_else(|| {
1354                let mut entry_points = self
1355                    .entry_points
1356                    .keys()
1357                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
1358                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1359                if entry_points.next().is_some() {
1360                    return Err(StageError::MultipleEntryPointsFound);
1361                }
1362                Ok(first.clone())
1363            })
1364    }
1365
1366    /// Among other things, this implements some validation logic defined by the WebGPU spec. at
1367    /// <https://www.w3.org/TR/webgpu/#abstract-opdef-validating-inter-stage-interfaces>.
1368    pub fn check_stage(
1369        &self,
1370        layouts: &mut BindingLayoutSource,
1371        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
1372        entry_point_name: &str,
1373        shader_stage: ShaderStageForValidation,
1374        inputs: StageIo,
1375        primitive_topology: Option<wgt::PrimitiveTopology>,
1376    ) -> Result<StageIo, StageError> {
1377        // Since a shader module can have multiple entry points with the same name,
1378        // we need to look for one with the right execution model.
1379        let pair = (shader_stage.to_naga(), entry_point_name.to_string());
1380        let entry_point = match self.entry_points.get(&pair) {
1381            Some(some) => some,
1382            None => return Err(StageError::MissingEntryPoint(pair.1)),
1383        };
1384        let (_, entry_point_name) = pair;
1385
1386        let stage_bit = shader_stage.to_wgt_bit();
1387
1388        // check resources visibility
1389        for &handle in entry_point.resources.iter() {
1390            let res = &self.resources[handle];
1391            let result = 'err: {
1392                match layouts {
1393                    BindingLayoutSource::Provided(pipeline_layout) => {
1394                        // update the required binding size for this buffer
1395                        if let ResourceType::Buffer { size } = res.ty {
1396                            match shader_binding_sizes.entry(res.bind) {
1397                                Entry::Occupied(e) => {
1398                                    *e.into_mut() = size.max(*e.get());
1399                                }
1400                                Entry::Vacant(e) => {
1401                                    e.insert(size);
1402                                }
1403                            }
1404                        }
1405
1406                        let Some(entry) =
1407                            pipeline_layout.get_bgl_entry(res.bind.group, res.bind.binding)
1408                        else {
1409                            break 'err Err(BindingError::Missing);
1410                        };
1411
1412                        if !entry.visibility.contains(stage_bit) {
1413                            break 'err Err(BindingError::Invisible);
1414                        }
1415
1416                        res.check_binding_use(entry)
1417                    }
1418                    BindingLayoutSource::Derived(layouts) => {
1419                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
1420                            break 'err Err(BindingError::Missing);
1421                        };
1422
1423                        let ty = match res.derive_binding_type(
1424                            entry_point
1425                                .sampling_pairs
1426                                .iter()
1427                                .any(|&(im, _samp)| im == handle),
1428                        ) {
1429                            Ok(ty) => ty,
1430                            Err(error) => break 'err Err(error),
1431                        };
1432
1433                        match map.entry(res.bind.binding) {
1434                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
1435                                break 'err Err(BindingError::InconsistentlyDerivedType)
1436                            }
1437                            indexmap::map::Entry::Occupied(e) => {
1438                                e.into_mut().visibility |= stage_bit;
1439                            }
1440                            indexmap::map::Entry::Vacant(e) => {
1441                                e.insert(BindGroupLayoutEntry {
1442                                    binding: res.bind.binding,
1443                                    ty,
1444                                    visibility: stage_bit,
1445                                    count: None,
1446                                });
1447                            }
1448                        }
1449                        Ok(())
1450                    }
1451                }
1452            };
1453            if let Err(error) = result {
1454                return Err(StageError::Binding(res.bind, error));
1455            }
1456        }
1457
1458        // Check the compatibility between textures and samplers
1459        //
1460        // We only need to do this if the binding layout is provided by the user, as derived
1461        // layouts will inherently be correctly tagged.
1462        if let BindingLayoutSource::Provided(pipeline_layout) = layouts {
1463            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
1464                let texture_bind = &self.resources[texture_handle].bind;
1465                let sampler_bind = &self.resources[sampler_handle].bind;
1466                let texture_layout = pipeline_layout
1467                    .get_bgl_entry(texture_bind.group, texture_bind.binding)
1468                    .unwrap();
1469                let sampler_layout = pipeline_layout
1470                    .get_bgl_entry(sampler_bind.group, sampler_bind.binding)
1471                    .unwrap();
1472                assert!(texture_layout.visibility.contains(stage_bit));
1473                assert!(sampler_layout.visibility.contains(stage_bit));
1474
1475                let sampler_filtering = matches!(
1476                    sampler_layout.ty,
1477                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
1478                );
1479                let texture_sample_type = match texture_layout.ty {
1480                    BindingType::Texture { sample_type, .. } => sample_type,
1481                    BindingType::ExternalTexture => {
1482                        wgt::TextureSampleType::Float { filterable: true }
1483                    }
1484                    _ => unreachable!(),
1485                };
1486
1487                let error = match (sampler_filtering, texture_sample_type) {
1488                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
1489                        Some(FilteringError::Float)
1490                    }
1491                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
1492                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
1493                    _ => None,
1494                };
1495
1496                if let Some(error) = error {
1497                    return Err(StageError::Filtering {
1498                        texture: *texture_bind,
1499                        sampler: *sampler_bind,
1500                        error,
1501                    });
1502                }
1503            }
1504        }
1505
1506        // check workgroup size limits
1507        if shader_stage.to_naga().compute_like() {
1508            let total = match shader_stage.to_naga() {
1509                naga::ShaderStage::Compute => check_workgroup_sizes(
1510                    &entry_point.workgroup_size,
1511                    &[
1512                        self.limits.max_compute_workgroup_size_x,
1513                        self.limits.max_compute_workgroup_size_y,
1514                        self.limits.max_compute_workgroup_size_z,
1515                    ],
1516                    "max_compute_workgroup_size_*",
1517                    self.limits.max_compute_invocations_per_workgroup,
1518                    "max_compute_invocations_per_workgroup",
1519                )?,
1520                naga::ShaderStage::Task => check_workgroup_sizes(
1521                    &entry_point.workgroup_size,
1522                    &[
1523                        self.limits.max_task_invocations_per_dimension,
1524                        self.limits.max_task_invocations_per_dimension,
1525                        self.limits.max_task_invocations_per_dimension,
1526                    ],
1527                    "max_task_invocations_per_dimension",
1528                    self.limits.max_task_invocations_per_workgroup,
1529                    "max_task_invocations_per_workgroup",
1530                )?,
1531                naga::ShaderStage::Mesh => check_workgroup_sizes(
1532                    &entry_point.workgroup_size,
1533                    &[
1534                        self.limits.max_mesh_invocations_per_dimension,
1535                        self.limits.max_mesh_invocations_per_dimension,
1536                        self.limits.max_mesh_invocations_per_dimension,
1537                    ],
1538                    "max_mesh_invocations_per_dimension",
1539                    self.limits.max_mesh_invocations_per_workgroup,
1540                    "max_mesh_invocations_per_workgroup",
1541                )?,
1542                _ => unreachable!(),
1543            };
1544            if total == 0 {
1545                return Err(StageError::InvalidWorkgroupSize(
1546                    InvalidWorkgroupSizeError::Zero {
1547                        dimensions: entry_point.workgroup_size,
1548                    },
1549                ));
1550            }
1551        }
1552
1553        let mut this_stage_primitive_index = false;
1554        let mut has_draw_id = false;
1555        let mut has_per_vertex = false;
1556
1557        // check inputs compatibility
1558        for input in entry_point.inputs.iter() {
1559            match *input {
1560                Varying::Local { location, ref iv } => {
1561                    let result = inputs
1562                        .varyings
1563                        .get(&location)
1564                        .ok_or(InputError::Missing)
1565                        .and_then(|provided| {
1566                            let (compatible, per_primitive_correct) = match shader_stage.to_naga() {
1567                                // For vertex attributes, there are defaults filled out
1568                                // by the driver if data is not provided.
1569                                naga::ShaderStage::Vertex => {
1570                                    let is_compatible =
1571                                        iv.ty.scalar.kind == provided.ty.scalar.kind;
1572                                    // vertex inputs don't count towards inter-stage
1573                                    (is_compatible, !iv.per_primitive)
1574                                }
1575                                naga::ShaderStage::Fragment => {
1576                                    if iv.interpolation != provided.interpolation {
1577                                        return Err(InputError::InterpolationMismatch(
1578                                            provided.interpolation,
1579                                        ));
1580                                    }
1581                                    if iv.sampling != provided.sampling {
1582                                        return Err(InputError::SamplingMismatch(
1583                                            provided.sampling,
1584                                        ));
1585                                    }
1586                                    (
1587                                        iv.ty.is_subtype_of(&provided.ty),
1588                                        iv.per_primitive == provided.per_primitive,
1589                                    )
1590                                }
1591                                // These can't have varying inputs
1592                                naga::ShaderStage::Compute
1593                                | naga::ShaderStage::Task
1594                                | naga::ShaderStage::Mesh => (false, false),
1595                                naga::ShaderStage::RayGeneration
1596                                | naga::ShaderStage::AnyHit
1597                                | naga::ShaderStage::ClosestHit
1598                                | naga::ShaderStage::Miss => {
1599                                    unreachable!()
1600                                }
1601                            };
1602                            if !compatible {
1603                                return Err(InputError::WrongType(provided.ty));
1604                            } else if !per_primitive_correct {
1605                                return Err(InputError::WrongPerPrimitive {
1606                                    pipeline_input: provided.per_primitive,
1607                                    shader: iv.per_primitive,
1608                                });
1609                            }
1610                            Ok(())
1611                        });
1612
1613                    if let Err(error) = result {
1614                        return Err(StageError::Input {
1615                            location,
1616                            var: iv.clone(),
1617                            error,
1618                        });
1619                    }
1620                    has_per_vertex |= iv.interpolation == Some(naga::Interpolation::PerVertex);
1621                }
1622                Varying::BuiltIn(BuiltIn::PrimitiveIndex) => {
1623                    this_stage_primitive_index = true;
1624                }
1625                Varying::BuiltIn(BuiltIn::DrawIndex) => {
1626                    has_draw_id = true;
1627                }
1628                Varying::BuiltIn(_) => {}
1629            }
1630        }
1631
1632        match shader_stage {
1633            ShaderStageForValidation::Vertex {
1634                topology,
1635                compare_function,
1636            } => {
1637                let mut max_vertex_shader_output_variables =
1638                    self.limits.max_inter_stage_shader_variables;
1639                let mut max_vertex_shader_output_location = max_vertex_shader_output_variables - 1;
1640
1641                let point_list_deduction = if topology == wgt::PrimitiveTopology::PointList {
1642                    Some(MaxVertexShaderOutputDeduction::PointListPrimitiveTopology)
1643                } else {
1644                    None
1645                };
1646
1647                let clip_distance_deductions = entry_point.outputs.iter().filter_map(|output| {
1648                    if let &Varying::BuiltIn(BuiltIn::ClipDistances { array_size }) = output {
1649                        Some(MaxVertexShaderOutputDeduction::ClipDistances { array_size })
1650                    } else {
1651                        None
1652                    }
1653                });
1654                debug_assert!(
1655                    clip_distance_deductions.clone().count() <= 1,
1656                    "multiple `clip_distances` outputs found"
1657                );
1658
1659                let deductions = point_list_deduction
1660                    .into_iter()
1661                    .chain(clip_distance_deductions);
1662
1663                for deduction in deductions.clone() {
1664                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1665                    // ever exceed the minimum variables available.
1666                    max_vertex_shader_output_variables = max_vertex_shader_output_variables
1667                        .checked_sub(deduction.for_variables())
1668                        .unwrap();
1669                    max_vertex_shader_output_location = max_vertex_shader_output_location
1670                        .checked_sub(deduction.for_location())
1671                        .unwrap();
1672                }
1673
1674                let mut num_user_defined_outputs = 0;
1675
1676                for output in entry_point.outputs.iter() {
1677                    match *output {
1678                        Varying::Local { ref iv, location } => {
1679                            if location > max_vertex_shader_output_location {
1680                                return Err(StageError::VertexOutputLocationTooLarge {
1681                                    location,
1682                                    var: iv.clone(),
1683                                    limit: self.limits.max_inter_stage_shader_variables,
1684                                    deductions: deductions.collect(),
1685                                });
1686                            }
1687                            num_user_defined_outputs += 1;
1688                        }
1689                        Varying::BuiltIn(_) => {}
1690                    };
1691
1692                    if let Some(
1693                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
1694                    ) = compare_function
1695                    {
1696                        if let Varying::BuiltIn(BuiltIn::Position { invariant: false }) = *output {
1697                            log::warn!(
1698                                concat!(
1699                                    "Vertex shader with entry point {} outputs a ",
1700                                    "@builtin(position) without the @invariant attribute and ",
1701                                    "is used in a pipeline with {cmp:?}. On some machines, ",
1702                                    "this can cause bad artifacting as {cmp:?} assumes the ",
1703                                    "values output from the vertex shader exactly match the ",
1704                                    "value in the depth buffer. The @invariant attribute on the ",
1705                                    "@builtin(position) vertex output ensures that the exact ",
1706                                    "same pixel depths are used every render."
1707                                ),
1708                                entry_point_name,
1709                                cmp = cmp
1710                            );
1711                        }
1712                    }
1713                }
1714
1715                if num_user_defined_outputs > max_vertex_shader_output_variables {
1716                    return Err(StageError::TooManyUserDefinedVertexOutputs {
1717                        num_found: num_user_defined_outputs,
1718                        limit: self.limits.max_inter_stage_shader_variables,
1719                        deductions: deductions.collect(),
1720                    });
1721                }
1722            }
1723            ShaderStageForValidation::Fragment {
1724                dual_source_blending,
1725                has_depth_attachment,
1726            } => {
1727                let mut max_fragment_shader_input_variables =
1728                    self.limits.max_inter_stage_shader_variables;
1729
1730                let deductions = entry_point.inputs.iter().filter_map(|output| match output {
1731                    Varying::Local { .. } => None,
1732                    Varying::BuiltIn(builtin) => {
1733                        MaxFragmentShaderInputDeduction::from_inter_stage_builtin(builtin.to_naga())
1734                            .or_else(|| {
1735                                unreachable!(
1736                                    concat!(
1737                                        "unexpected built-in provided; ",
1738                                        "{:?} is not used for fragment stage input",
1739                                    ),
1740                                    builtin
1741                                )
1742                            })
1743                    }
1744                });
1745
1746                for deduction in deductions.clone() {
1747                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1748                    // ever exceed the minimum variables available.
1749                    max_fragment_shader_input_variables = max_fragment_shader_input_variables
1750                        .checked_sub(deduction.for_variables())
1751                        .unwrap();
1752                }
1753
1754                let mut num_user_defined_inputs = 0;
1755
1756                for output in entry_point.inputs.iter() {
1757                    match *output {
1758                        Varying::Local { ref iv, location } => {
1759                            if location >= self.limits.max_inter_stage_shader_variables {
1760                                return Err(StageError::FragmentInputLocationTooLarge {
1761                                    location,
1762                                    var: iv.clone(),
1763                                    limit: self.limits.max_inter_stage_shader_variables,
1764                                    deductions: deductions.collect(),
1765                                });
1766                            }
1767                            num_user_defined_inputs += 1;
1768                        }
1769                        Varying::BuiltIn(_) => {}
1770                    };
1771                }
1772
1773                if num_user_defined_inputs > max_fragment_shader_input_variables {
1774                    return Err(StageError::TooManyUserDefinedFragmentInputs {
1775                        num_found: num_user_defined_inputs,
1776                        limit: self.limits.max_inter_stage_shader_variables,
1777                        deductions: deductions.collect(),
1778                    });
1779                }
1780
1781                for output in &entry_point.outputs {
1782                    let &Varying::Local { location, ref iv } = output else {
1783                        continue;
1784                    };
1785                    if location >= self.limits.max_color_attachments {
1786                        return Err(StageError::ColorAttachmentLocationTooLarge {
1787                            location,
1788                            var: iv.clone(),
1789                            limit: self.limits.max_color_attachments,
1790                        });
1791                    }
1792                }
1793
1794                // If the pipeline uses dual-source blending, then the shader
1795                // must configure appropriate I/O, but it is not an error to
1796                // use a shader that defines the I/O in a pipeline that only
1797                // uses one blend source.
1798                if dual_source_blending && !entry_point.dual_source_blending {
1799                    return Err(StageError::InvalidDualSourceBlending);
1800                }
1801
1802                if entry_point
1803                    .outputs
1804                    .contains(&Varying::BuiltIn(BuiltIn::FragDepth))
1805                    && !has_depth_attachment
1806                {
1807                    return Err(StageError::MissingFragDepthAttachment);
1808                }
1809            }
1810            ShaderStageForValidation::Mesh => {
1811                for output in &entry_point.outputs {
1812                    if matches!(output, Varying::BuiltIn(BuiltIn::PrimitiveIndex)) {
1813                        this_stage_primitive_index = true;
1814                    }
1815                }
1816            }
1817            _ => (),
1818        }
1819
1820        if let Some(ref mesh_info) = entry_point.mesh_info {
1821            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
1822                return Err(StageError::TooManyMeshVertices {
1823                    limit: self.limits.max_mesh_output_vertices,
1824                    value: mesh_info.max_vertices,
1825                });
1826            }
1827            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
1828                return Err(StageError::TooManyMeshPrimitives {
1829                    limit: self.limits.max_mesh_output_primitives,
1830                    value: mesh_info.max_primitives,
1831                });
1832            }
1833            if primitive_topology != Some(mesh_info.primitive_topology) {
1834                return Err(StageError::MeshTopologyMismatch);
1835            }
1836        }
1837        if let Some(task_payload_size) = entry_point.task_payload_size {
1838            if task_payload_size > self.limits.max_task_payload_size {
1839                return Err(StageError::TaskPayloadTooLarge {
1840                    limit: self.limits.max_task_payload_size,
1841                    value: task_payload_size,
1842                });
1843            }
1844        }
1845        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1846            && entry_point.task_payload_size != inputs.task_payload_size
1847        {
1848            return Err(StageError::TaskPayloadMustMatch {
1849                input: inputs.task_payload_size,
1850                shader: entry_point.task_payload_size,
1851            });
1852        }
1853
1854        // Fragment shader primitive index is treated like a varying
1855        if shader_stage.to_naga() == naga::ShaderStage::Fragment
1856            && this_stage_primitive_index
1857            && inputs.primitive_index == Some(false)
1858        {
1859            return Err(StageError::InvalidPrimitiveIndex);
1860        } else if shader_stage.to_naga() == naga::ShaderStage::Fragment
1861            && !this_stage_primitive_index
1862            && inputs.primitive_index == Some(true)
1863        {
1864            return Err(StageError::MissingPrimitiveIndex);
1865        }
1866        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1867            && inputs.task_payload_size.is_some()
1868            && has_draw_id
1869        {
1870            return Err(StageError::DrawIdError);
1871        }
1872
1873        if primitive_topology.is_none_or(|e| !e.is_triangles()) && has_per_vertex {
1874            return Err(StageError::PerVertexNotTriangles);
1875        }
1876
1877        let outputs = entry_point
1878            .outputs
1879            .iter()
1880            .filter_map(|output| match *output {
1881                Varying::Local { location, ref iv } => Some((location, iv.clone())),
1882                Varying::BuiltIn(_) => None,
1883            })
1884            .collect();
1885
1886        Ok(StageIo {
1887            task_payload_size: entry_point.task_payload_size,
1888            varyings: outputs,
1889            primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
1890                Some(this_stage_primitive_index)
1891            } else {
1892                None
1893            },
1894        })
1895    }
1896
1897    pub fn fragment_uses_dual_source_blending(
1898        &self,
1899        entry_point_name: &str,
1900    ) -> Result<bool, StageError> {
1901        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
1902        self.entry_points
1903            .get(&pair)
1904            .ok_or(StageError::MissingEntryPoint(pair.1))
1905            .map(|ep| ep.dual_source_blending)
1906    }
1907}
1908
1909pub fn check_color_attachment_count(
1910    num_attachments: usize,
1911    limit: u32,
1912) -> Result<(), ColorAttachmentError> {
1913    let limit = usize::try_from(limit).unwrap();
1914    if num_attachments > limit {
1915        return Err(ColorAttachmentError::TooMany {
1916            given: num_attachments,
1917            limit,
1918        });
1919    }
1920
1921    Ok(())
1922}
1923
1924/// Validate a list of color attachment formats against `maxColorAttachmentBytesPerSample`.
1925///
1926/// The color attachments can be from a render pass descriptor or a pipeline descriptor.
1927///
1928/// Implements <https://gpuweb.github.io/gpuweb/#abstract-opdef-calculating-color-attachment-bytes-per-sample>.
1929pub fn validate_color_attachment_bytes_per_sample(
1930    attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1931    limit: u32,
1932) -> Result<(), ColorAttachmentError> {
1933    let mut total_bytes_per_sample: u32 = 0;
1934    for format in attachment_formats {
1935        let byte_cost = format.target_pixel_byte_cost().unwrap();
1936        let alignment = format.target_component_alignment().unwrap();
1937
1938        total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1939        total_bytes_per_sample += byte_cost;
1940    }
1941
1942    if total_bytes_per_sample > limit {
1943        return Err(ColorAttachmentError::TooManyBytesPerSample {
1944            total: total_bytes_per_sample,
1945            limit,
1946        });
1947    }
1948
1949    Ok(())
1950}
1951
/// Error describing a workgroup size that is not acceptable.
///
/// `LimitExceeded` is produced by `check_workgroup_sizes`. `Zero` is produced by
/// callers that reject a workgroup whose total invocation count is zero.
#[derive(Clone, Debug, Error)]
pub enum InvalidWorkgroupSizeError {
    /// At least one dimension exceeds its per-dimension limit, or the product of
    /// all dimensions exceeds the total invocation limit.
    #[error(
        "Workgroup size {dimensions:?} ({total} total invocations) must be less or equal to \
        the per-dimension limit `Limits::{per_dimension_limits_desc}` of {per_dimension_limits:?} \
        and the total invocation limit `Limits::{total_limit_desc}` of {total_limit}"
    )]
    LimitExceeded {
        /// The requested X/Y/Z workgroup size.
        dimensions: [u32; 3],
        /// The X/Y/Z limits each dimension was checked against.
        per_dimension_limits: [u32; 3],
        /// Name of the `Limits` field(s) supplying `per_dimension_limits`; only used in
        /// the error message.
        per_dimension_limits_desc: &'static str,
        /// Product of `dimensions`, computed with saturating multiplication.
        total: u32,
        /// Maximum allowed total number of invocations.
        total_limit: u32,
        /// Name of the `Limits` field supplying `total_limit`; only used in the error
        /// message.
        total_limit_desc: &'static str,
    },
    /// The product of the workgroup dimensions is zero, i.e. at least one
    /// dimension is zero.
    #[error("Workgroup sizes {dimensions:?} must be positive")]
    Zero { dimensions: [u32; 3] },
}
1970
1971/// Check X/Y/Z workgroup sizes against per-dimension and overall limits.
1972///
1973/// This function does not check that the sizes are non-zero. In a dispatch, it is legal for
1974/// the size to be zero. In shader or pipeline creation, it is an error for the size to be
1975/// zero, and the caller must check that.
1976pub(crate) fn check_workgroup_sizes(
1977    sizes: &[u32; 3],
1978    per_dimension_limits: &[u32; 3],
1979    per_dimension_limits_desc: &'static str,
1980    total_limit: u32,
1981    total_limit_desc: &'static str,
1982) -> Result<u32, InvalidWorkgroupSizeError> {
1983    let total = sizes
1984        .iter()
1985        .fold(1u32, |total, &dim| total.saturating_mul(dim));
1986
1987    let invalid_total_invocations = total > total_limit;
1988
1989    let dimension_too_large = sizes
1990        .iter()
1991        .zip(per_dimension_limits.iter())
1992        .any(|(dim, limit)| dim > limit);
1993
1994    if invalid_total_invocations || dimension_too_large {
1995        Err(InvalidWorkgroupSizeError::LimitExceeded {
1996            dimensions: *sizes,
1997            per_dimension_limits: *per_dimension_limits,
1998            per_dimension_limits_desc,
1999            total,
2000            total_limit,
2001            total_limit_desc,
2002        })
2003    } else {
2004        Ok(total)
2005    }
2006}
2007
2008pub enum ShaderStageForValidation {
2009    Vertex {
2010        topology: wgt::PrimitiveTopology,
2011        compare_function: Option<wgt::CompareFunction>,
2012    },
2013    Mesh,
2014    Fragment {
2015        dual_source_blending: bool,
2016        has_depth_attachment: bool,
2017    },
2018    Compute,
2019    Task,
2020}
2021
2022impl ShaderStageForValidation {
2023    pub fn to_naga(&self) -> naga::ShaderStage {
2024        match self {
2025            Self::Vertex { .. } => naga::ShaderStage::Vertex,
2026            Self::Mesh => naga::ShaderStage::Mesh,
2027            Self::Fragment { .. } => naga::ShaderStage::Fragment,
2028            Self::Compute => naga::ShaderStage::Compute,
2029            Self::Task => naga::ShaderStage::Task,
2030        }
2031    }
2032
2033    pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
2034        match self {
2035            Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
2036            Self::Mesh => wgt::ShaderStages::MESH,
2037            Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
2038            Self::Compute => wgt::ShaderStages::COMPUTE,
2039            Self::Task => wgt::ShaderStages::TASK,
2040        }
2041    }
2042}