wgpu_core/
validation.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString as _},
4    sync::Arc,
5    vec::Vec,
6};
7use core::fmt;
8
9use arrayvec::ArrayVec;
10use hashbrown::hash_map::Entry;
11use shader_io_deductions::{display_deductions_as_optional_list, MaxVertexShaderOutputDeduction};
12use thiserror::Error;
13use wgt::{
14    error::{ErrorType, WebGpuError},
15    BindGroupLayoutEntry, BindingType,
16};
17
18use crate::{
19    device::bgl, resource::InvalidResourceError,
20    validation::shader_io_deductions::MaxFragmentShaderInputDeduction, FastHashMap, FastHashSet,
21};
22
23pub mod shader_io_deductions;
24
/// Shader-side description of a global resource used by a shader module.
///
/// It is compared against bind group layout entries in `Resource::check_binding_use`
/// and turned into a layout binding type by `Resource::derive_binding_type`.
#[derive(Debug)]
enum ResourceType {
    /// A uniform or storage buffer. `size` is the size the shader requires; it is
    /// compared against the layout entry's `min_binding_size`.
    Buffer {
        size: wgt::BufferSize,
    },
    /// A texture of any class (sampled, depth, storage or external), with its
    /// dimensionality and whether it is an array texture.
    Texture {
        dim: naga::ImageDimension,
        arrayed: bool,
        class: naga::ImageClass,
    },
    /// A sampler. `comparison` is true for comparison samplers.
    Sampler {
        comparison: bool,
    },
    /// An acceleration structure. `vertex_return` must match the layout entry's flag.
    AccelerationStructure {
        vertex_return: bool,
    },
}
42
/// Coarse category of a binding, used to report shader/layout kind mismatches
/// (see [`BindingError::WrongType`]).
#[derive(Clone, Debug)]
pub enum BindingTypeName {
    /// A uniform or storage buffer.
    Buffer,
    /// A sampled, depth or storage texture.
    Texture,
    /// A sampler of any kind.
    Sampler,
    /// A ray-tracing acceleration structure.
    AccelerationStructure,
    /// An external texture; reported separately from ordinary textures.
    ExternalTexture,
}
51
52impl From<&ResourceType> for BindingTypeName {
53    fn from(ty: &ResourceType) -> BindingTypeName {
54        match ty {
55            ResourceType::Buffer { .. } => BindingTypeName::Buffer,
56            ResourceType::Texture {
57                class: naga::ImageClass::External,
58                ..
59            } => BindingTypeName::ExternalTexture,
60            ResourceType::Texture { .. } => BindingTypeName::Texture,
61            ResourceType::Sampler { .. } => BindingTypeName::Sampler,
62            ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
63        }
64    }
65}
66
67impl From<&BindingType> for BindingTypeName {
68    fn from(ty: &BindingType) -> BindingTypeName {
69        match ty {
70            BindingType::Buffer { .. } => BindingTypeName::Buffer,
71            BindingType::Texture { .. } => BindingTypeName::Texture,
72            BindingType::StorageTexture { .. } => BindingTypeName::Texture,
73            BindingType::Sampler { .. } => BindingTypeName::Sampler,
74            BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
75            BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
76        }
77    }
78}
79
/// A global resource used by a shader module: its binding slot, type and address space.
#[derive(Debug)]
struct Resource {
    // Only observed through the derived `Debug` impl in the code visible here, hence
    // `#[allow(unused)]`. NOTE(review): presumably kept for diagnostics — confirm.
    #[allow(unused)]
    name: Option<String>,
    /// Binding slot the resource is declared at (reported in `StageError::Binding`).
    bind: naga::ResourceBinding,
    /// Kind of resource and its shader-side properties.
    ty: ResourceType,
    /// Address space of the global; for buffers it must match the layout entry's
    /// uniform/storage/read-only kind.
    class: naga::AddressSpace,
}
88
/// Shape of a numeric value: a scalar, a vector of some size, or a matrix given as
/// (columns, rows).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum NumericDimension {
    Scalar,
    Vector(naga::VectorSize),
    Matrix(naga::VectorSize, naga::VectorSize),
}
95
96impl fmt::Display for NumericDimension {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        match *self {
99            Self::Scalar => write!(f, ""),
100            Self::Vector(size) => write!(f, "x{}", size as u8),
101            Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
102        }
103    }
104}
105
/// Scalar type and shape of a numeric shader value (for example a stage input or
/// output, or a vertex attribute).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NumericType {
    /// Scalar, vector or matrix shape.
    dim: NumericDimension,
    /// Scalar kind and width in bytes (`Display` multiplies the width by 8).
    scalar: naga::Scalar,
}
111
112impl fmt::Display for NumericType {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(
115            f,
116            "{:?}{}{}",
117            self.scalar.kind,
118            self.scalar.width * 8,
119            self.dim
120        )
121    }
122}
123
/// A user-defined inter-stage variable (or vertex attribute) together with the
/// qualifiers that must agree between the stages that exchange it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InterfaceVar {
    /// Numeric type of the variable.
    pub ty: NumericType,
    /// Interpolation qualifier, if any.
    interpolation: Option<naga::Interpolation>,
    /// Sampling qualifier, if any.
    sampling: Option<naga::Sampling>,
    /// Whether the variable is per-primitive (see `InputError::WrongPerPrimitive`).
    per_primitive: bool,
}
131
132impl InterfaceVar {
133    pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
134        InterfaceVar {
135            ty: NumericType::from_vertex_format(format),
136            interpolation: None,
137            sampling: None,
138            per_primitive: false,
139        }
140    }
141}
142
143impl fmt::Display for InterfaceVar {
144    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145        write!(
146            f,
147            "{} interpolated as {:?} with sampling {:?}",
148            self.ty, self.interpolation, self.sampling
149        )
150    }
151}
152
/// One entry-point input or output.
#[derive(Debug, Eq, PartialEq)]
enum Varying {
    /// A user-defined variable at `location`.
    Local { location: u32, iv: InterfaceVar },
    /// A built-in value.
    BuiltIn(BuiltIn),
}
158
/// Local mirror of `naga::BuiltIn`, converted back with [`BuiltIn::to_naga`].
///
/// The only difference is `ClipDistances`, which also records the array size of the
/// clip-distances variable; that size is dropped when converting to naga.
#[derive(Clone, Debug, Eq, PartialEq)]
enum BuiltIn {
    Position { invariant: bool },
    ViewIndex,
    BaseInstance,
    BaseVertex,
    ClipDistances { array_size: u32 },
    CullDistance,
    InstanceIndex,
    PointSize,
    VertexIndex,
    DrawIndex,
    FragDepth,
    PointCoord,
    FrontFacing,
    PrimitiveIndex,
    Barycentric { perspective: bool },
    SampleIndex,
    SampleMask,
    GlobalInvocationId,
    LocalInvocationId,
    LocalInvocationIndex,
    WorkGroupId,
    WorkGroupSize,
    NumWorkGroups,
    NumSubgroups,
    SubgroupId,
    SubgroupSize,
    SubgroupInvocationId,
    MeshTaskSize,
    CullPrimitive,
    PointIndex,
    LineIndices,
    TriangleIndices,
    VertexCount,
    Vertices,
    PrimitiveCount,
    Primitives,
    RayInvocationId,
    NumRayInvocations,
    InstanceCustomData,
    GeometryIndex,
    WorldRayOrigin,
    WorldRayDirection,
    ObjectRayOrigin,
    ObjectRayDirection,
    RayTmin,
    RayTCurrentMax,
    ObjectToWorld,
    WorldToObject,
    HitKind,
}
211
212impl BuiltIn {
213    pub fn to_naga(&self) -> naga::BuiltIn {
214        match self {
215            &Self::Position { invariant } => naga::BuiltIn::Position { invariant },
216            Self::ViewIndex => naga::BuiltIn::ViewIndex,
217            Self::BaseInstance => naga::BuiltIn::BaseInstance,
218            Self::BaseVertex => naga::BuiltIn::BaseVertex,
219            Self::ClipDistances { .. } => naga::BuiltIn::ClipDistances,
220            Self::CullDistance => naga::BuiltIn::CullDistance,
221            Self::InstanceIndex => naga::BuiltIn::InstanceIndex,
222            Self::PointSize => naga::BuiltIn::PointSize,
223            Self::VertexIndex => naga::BuiltIn::VertexIndex,
224            Self::DrawIndex => naga::BuiltIn::DrawIndex,
225            Self::FragDepth => naga::BuiltIn::FragDepth,
226            Self::PointCoord => naga::BuiltIn::PointCoord,
227            Self::FrontFacing => naga::BuiltIn::FrontFacing,
228            Self::PrimitiveIndex => naga::BuiltIn::PrimitiveIndex,
229            Self::Barycentric { perspective } => naga::BuiltIn::Barycentric {
230                perspective: *perspective,
231            },
232            Self::SampleIndex => naga::BuiltIn::SampleIndex,
233            Self::SampleMask => naga::BuiltIn::SampleMask,
234            Self::GlobalInvocationId => naga::BuiltIn::GlobalInvocationId,
235            Self::LocalInvocationId => naga::BuiltIn::LocalInvocationId,
236            Self::LocalInvocationIndex => naga::BuiltIn::LocalInvocationIndex,
237            Self::WorkGroupId => naga::BuiltIn::WorkGroupId,
238            Self::WorkGroupSize => naga::BuiltIn::WorkGroupSize,
239            Self::NumWorkGroups => naga::BuiltIn::NumWorkGroups,
240            Self::NumSubgroups => naga::BuiltIn::NumSubgroups,
241            Self::SubgroupId => naga::BuiltIn::SubgroupId,
242            Self::SubgroupSize => naga::BuiltIn::SubgroupSize,
243            Self::SubgroupInvocationId => naga::BuiltIn::SubgroupInvocationId,
244            Self::MeshTaskSize => naga::BuiltIn::MeshTaskSize,
245            Self::CullPrimitive => naga::BuiltIn::CullPrimitive,
246            Self::PointIndex => naga::BuiltIn::PointIndex,
247            Self::LineIndices => naga::BuiltIn::LineIndices,
248            Self::TriangleIndices => naga::BuiltIn::TriangleIndices,
249            Self::VertexCount => naga::BuiltIn::VertexCount,
250            Self::Vertices => naga::BuiltIn::Vertices,
251            Self::PrimitiveCount => naga::BuiltIn::PrimitiveCount,
252            Self::Primitives => naga::BuiltIn::Primitives,
253            Self::RayInvocationId => naga::BuiltIn::RayInvocationId,
254            Self::NumRayInvocations => naga::BuiltIn::NumRayInvocations,
255            Self::InstanceCustomData => naga::BuiltIn::InstanceCustomData,
256            Self::GeometryIndex => naga::BuiltIn::GeometryIndex,
257            Self::WorldRayOrigin => naga::BuiltIn::WorldRayOrigin,
258            Self::WorldRayDirection => naga::BuiltIn::WorldRayDirection,
259            Self::ObjectRayOrigin => naga::BuiltIn::ObjectRayOrigin,
260            Self::ObjectRayDirection => naga::BuiltIn::ObjectRayDirection,
261            Self::RayTmin => naga::BuiltIn::RayTmin,
262            Self::RayTCurrentMax => naga::BuiltIn::RayTCurrentMax,
263            Self::ObjectToWorld => naga::BuiltIn::ObjectToWorld,
264            Self::WorldToObject => naga::BuiltIn::WorldToObject,
265            Self::HitKind => naga::BuiltIn::HitKind,
266        }
267    }
268}
269
// Not read in the code visible here, hence `#[allow(unused)]`.
/// A specialization constant declared by an entry point: its id and numeric type.
#[allow(unused)]
#[derive(Debug)]
struct SpecializationConstant {
    id: u32,
    ty: NumericType,
}
276
/// Output bounds declared by a mesh shader entry point. They are checked against
/// `Limits::max_mesh_output_vertices` / `max_mesh_output_primitives`
/// (see `StageError::TooManyMeshVertices` / `TooManyMeshPrimitives`).
#[derive(Debug)]
struct EntryPointMeshInfo {
    max_vertices: u32,
    max_primitives: u32,
}
282
/// Reflection data gathered for a single shader entry point.
#[derive(Debug, Default)]
struct EntryPoint {
    /// Stage inputs: user-defined locations and built-ins.
    inputs: Vec<Varying>,
    /// Stage outputs: user-defined locations and built-ins.
    outputs: Vec<Varying>,
    /// Handles into `Interface::resources` for the globals this entry point uses.
    resources: Vec<naga::Handle<Resource>>,
    // Not read in the code visible here, hence `#[allow(unused)]`.
    #[allow(unused)]
    spec_constants: Vec<SpecializationConstant>,
    /// Pairs of resources used together for sampling. NOTE(review): presumably
    /// (texture, sampler) pairs — confirm against the code that fills this set.
    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
    /// Workgroup size declared by the entry point, one value per dimension.
    workgroup_size: [u32; 3],
    /// Whether the entry point uses dual-source blending
    /// (see `StageError::InvalidDualSourceBlending`).
    dual_source_blending: bool,
    /// Task payload size, if the entry point declares a task payload.
    task_payload_size: Option<u32>,
    /// Output bounds, present for mesh shader entry points.
    mesh_info: Option<EntryPointMeshInfo>,
}
296
/// Reflected interface of a whole shader module: the limits it is checked against,
/// its global resources, and every entry point keyed by (stage, name).
#[derive(Debug)]
pub struct Interface {
    /// Device limits. NOTE(review): presumably consulted by the stage validation
    /// code outside this view — confirm.
    limits: wgt::Limits,
    /// Arena of global resources, referenced by handle from each `EntryPoint`.
    resources: naga::Arena<Resource>,
    /// Entry points by shader stage and entry point name.
    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
}
303
/// Reasons a shader global is incompatible with the pipeline layout entry at its
/// binding slot. Every variant is reported as a WebGPU validation error.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum BindingError {
    #[error("Binding is missing from the pipeline layout")]
    Missing,
    #[error("Visibility flags don't include the shader stage")]
    Invisible,
    /// The layout entry is a different kind of binding than the shader uses.
    #[error(
        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
    )]
    WrongType {
        binding: BindingTypeName,
        shader: BindingTypeName,
    },
    /// A buffer's uniform/storage (and read-only) kind differs between layout and shader.
    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
    WrongAddressSpace {
        binding: naga::AddressSpace,
        shader: naga::AddressSpace,
    },
    /// Returned when deriving a buffer binding from a global that is in neither the
    /// uniform nor the storage address space.
    #[error("Address space {space:?} is not a valid Buffer address space")]
    WrongBufferAddressSpace { space: naga::AddressSpace },
    /// The layout's `min_binding_size` is smaller than the size the shader requires.
    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
    WrongBufferSize {
        buffer_size: wgt::BufferSize,
        min_binding_size: wgt::BufferSize,
    },
    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
    WrongTextureViewDimension {
        dim: naga::ImageDimension,
        is_array: bool,
        binding: BindingType,
    },
    /// Sample type, multisampling, storage format/access or external-ness differs.
    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
    WrongTextureClass {
        binding: naga::ImageClass,
        shader: naga::ImageClass,
    },
    #[error("Comparison flag doesn't match the shader")]
    WrongSamplerComparison,
    #[error("Derived bind group layout type is not consistent between stages")]
    InconsistentlyDerivedType,
    /// The layout's storage texture format has no naga storage-format equivalent.
    #[error("Texture format {0:?} is not supported for storage use")]
    BadStorageFormat(wgt::TextureFormat),
}
348
349impl WebGpuError for BindingError {
350    fn webgpu_error_type(&self) -> ErrorType {
351        ErrorType::Validation
352    }
353}
354
/// Reasons a texture cannot be sampled through a filtering sampler it is paired with.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum FilteringError {
    #[error("Integer textures can't be sampled with a filtering sampler")]
    Integer,
    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
    Float,
}
363
364impl WebGpuError for FilteringError {
365    fn webgpu_error_type(&self) -> ErrorType {
366        ErrorType::Validation
367    }
368}
369
/// Reasons a stage input is incompatible with what the previous stage provides.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum InputError {
    #[error("Input is not provided by the earlier stage in the pipeline")]
    Missing,
    /// Carries the type provided by the earlier stage.
    #[error("Input type is not compatible with the provided {0}")]
    WrongType(NumericType),
    /// Carries the interpolation provided by the earlier stage.
    #[error("Input interpolation doesn't match provided {0:?}")]
    InterpolationMismatch(Option<naga::Interpolation>),
    /// Carries the sampling provided by the earlier stage.
    #[error("Input sampling doesn't match provided {0:?}")]
    SamplingMismatch(Option<naga::Sampling>),
    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
    WrongPerPrimitive { pipeline_input: bool, shader: bool },
}
384
385impl WebGpuError for InputError {
386    fn webgpu_error_type(&self) -> ErrorType {
387        ErrorType::Validation
388    }
389}
390
391/// Errors produced when validating a programmable stage of a pipeline.
392#[derive(Clone, Debug, Error)]
393#[non_exhaustive]
394pub enum StageError {
395    #[error(
396        "Shader entry point's workgroup size {dimensions:?} ({total} total invocations) must be \
397        less or equal to the per-dimension limit `Limits::{per_dimension_limits_desc}` of \
398        {per_dimension_limits:?} and the total invocation limit `Limits::{total_limit_desc}` of \
399        {total_limit}"
400    )]
401    InvalidWorkgroupSize {
402        dimensions: [u32; 3],
403        per_dimension_limits: [u32; 3],
404        per_dimension_limits_desc: &'static str,
405        total: u32,
406        total_limit: u32,
407        total_limit_desc: &'static str,
408    },
409    #[error("Unable to find entry point '{0}'")]
410    MissingEntryPoint(String),
411    #[error("Shader global {0:?} is not available in the pipeline layout")]
412    Binding(naga::ResourceBinding, #[source] BindingError),
413    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
414    Filtering {
415        texture: naga::ResourceBinding,
416        sampler: naga::ResourceBinding,
417        #[source]
418        error: FilteringError,
419    },
420    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
421    Input {
422        location: wgt::ShaderLocation,
423        var: InterfaceVar,
424        #[source]
425        error: InputError,
426    },
427    #[error(
428        "Unable to select an entry point: no entry point was found in the provided shader module"
429    )]
430    NoEntryPointFound,
431    #[error(
432        "Unable to select an entry point: \
433        multiple entry points were found in the provided shader module, \
434        but no entry point was specified"
435    )]
436    MultipleEntryPointsFound,
437    #[error(transparent)]
438    InvalidResource(#[from] InvalidResourceError),
439    #[error(
440        "vertex shader output location Location[{location}] ({var}) exceeds the \
441        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
442        // NOTE: Remember: the limit is 0-based for indices.
443        limit - 1,
444        display_deductions_as_optional_list(deductions, |d| d.for_location())
445    )]
446    VertexOutputLocationTooLarge {
447        location: u32,
448        var: InterfaceVar,
449        limit: u32,
450        deductions: Vec<MaxVertexShaderOutputDeduction>,
451    },
452    #[error(
453        "found {num_found} user-defined vertex shader output variables, which exceeds the \
454        `max_inter_stage_shader_variables` limit ({limit}){}",
455        display_deductions_as_optional_list(deductions, |d| d.for_variables())
456    )]
457    TooManyUserDefinedVertexOutputs {
458        num_found: u32,
459        limit: u32,
460        deductions: Vec<MaxVertexShaderOutputDeduction>,
461    },
462    #[error(
463        "fragment shader input location Location[{location}] ({var}) exceeds the \
464        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
465        // NOTE: Remember: the limit is 0-based for indices.
466        limit - 1,
467        // NOTE: WebGPU spec. validation for fragment inputs is expressed in terms of variables
468        // (unlike vertex outputs), so we use `MaxFragmentShaderInputDeduction::for_variables` here
469        // (and not a non-existent `for_locations`).
470        display_deductions_as_optional_list(deductions, |d| d.for_variables())
471    )]
472    FragmentInputLocationTooLarge {
473        location: u32,
474        var: InterfaceVar,
475        limit: u32,
476        deductions: Vec<MaxFragmentShaderInputDeduction>,
477    },
478    #[error(
479        "found {num_found} user-defined fragment shader input variables, which exceeds the \
480        `max_inter_stage_shader_variables` limit ({limit}){}",
481        display_deductions_as_optional_list(deductions, |d| d.for_variables())
482    )]
483    TooManyUserDefinedFragmentInputs {
484        num_found: u32,
485        limit: u32,
486        deductions: Vec<MaxFragmentShaderInputDeduction>,
487    },
488    #[error(
489        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
490    )]
491    ColorAttachmentLocationTooLarge {
492        location: u32,
493        var: InterfaceVar,
494        limit: u32,
495    },
496    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
497    TooManyMeshVertices { limit: u32, value: u32 },
498    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
499    TooManyMeshPrimitives { limit: u32, value: u32 },
500    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
501    TaskPayloadTooLarge { limit: u32, value: u32 },
502    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
503    TaskPayloadMustMatch {
504        input: Option<u32>,
505        shader: Option<u32>,
506    },
507    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
508    InvalidPrimitiveIndex,
509    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
510    MissingPrimitiveIndex,
511    #[error("DrawId cannot be used in the same pipeline as a task shader")]
512    DrawIdError,
513    #[error("Pipeline uses dual-source blending, but the shader does not support it")]
514    InvalidDualSourceBlending,
515    #[error("Fragment shader writes depth, but pipeline does not have a depth attachment")]
516    MissingFragDepthAttachment,
517}
518
519impl WebGpuError for StageError {
520    fn webgpu_error_type(&self) -> ErrorType {
521        match self {
522            Self::Binding(_, e) => e.webgpu_error_type(),
523            Self::InvalidResource(e) => e.webgpu_error_type(),
524            Self::Filtering {
525                texture: _,
526                sampler: _,
527                error,
528            } => error.webgpu_error_type(),
529            Self::Input {
530                location: _,
531                var: _,
532                error,
533            } => error.webgpu_error_type(),
534            Self::InvalidWorkgroupSize { .. }
535            | Self::MissingEntryPoint(..)
536            | Self::NoEntryPointFound
537            | Self::MultipleEntryPointsFound
538            | Self::VertexOutputLocationTooLarge { .. }
539            | Self::TooManyUserDefinedVertexOutputs { .. }
540            | Self::FragmentInputLocationTooLarge { .. }
541            | Self::TooManyUserDefinedFragmentInputs { .. }
542            | Self::ColorAttachmentLocationTooLarge { .. }
543            | Self::TooManyMeshVertices { .. }
544            | Self::TooManyMeshPrimitives { .. }
545            | Self::TaskPayloadTooLarge { .. }
546            | Self::TaskPayloadMustMatch { .. }
547            | Self::InvalidPrimitiveIndex
548            | Self::MissingPrimitiveIndex
549            | Self::DrawIdError
550            | Self::InvalidDualSourceBlending
551            | Self::MissingFragDepthAttachment => ErrorType::Validation,
552        }
553    }
554}
555
556pub use wgpu_naga_bridge::map_storage_format_from_naga;
557pub use wgpu_naga_bridge::map_storage_format_to_naga;
558
559impl Resource {
560    fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
561        match self.ty {
562            ResourceType::Buffer { size } => {
563                let min_size = match entry.ty {
564                    BindingType::Buffer {
565                        ty,
566                        has_dynamic_offset: _,
567                        min_binding_size,
568                    } => {
569                        let class = match ty {
570                            wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
571                            wgt::BufferBindingType::Storage { read_only } => {
572                                let mut naga_access = naga::StorageAccess::LOAD;
573                                naga_access.set(naga::StorageAccess::STORE, !read_only);
574                                naga::AddressSpace::Storage {
575                                    access: naga_access,
576                                }
577                            }
578                        };
579                        if self.class != class {
580                            return Err(BindingError::WrongAddressSpace {
581                                binding: class,
582                                shader: self.class,
583                            });
584                        }
585                        min_binding_size
586                    }
587                    _ => {
588                        return Err(BindingError::WrongType {
589                            binding: (&entry.ty).into(),
590                            shader: (&self.ty).into(),
591                        })
592                    }
593                };
594                match min_size {
595                    Some(non_zero) if non_zero < size => {
596                        return Err(BindingError::WrongBufferSize {
597                            buffer_size: size,
598                            min_binding_size: non_zero,
599                        })
600                    }
601                    _ => (),
602                }
603            }
604            ResourceType::Sampler { comparison } => match entry.ty {
605                BindingType::Sampler(ty) => {
606                    if (ty == wgt::SamplerBindingType::Comparison) != comparison {
607                        return Err(BindingError::WrongSamplerComparison);
608                    }
609                }
610                _ => {
611                    return Err(BindingError::WrongType {
612                        binding: (&entry.ty).into(),
613                        shader: (&self.ty).into(),
614                    })
615                }
616            },
617            ResourceType::Texture {
618                dim,
619                arrayed,
620                class: shader_class,
621            } => {
622                let view_dimension = match entry.ty {
623                    BindingType::Texture { view_dimension, .. }
624                    | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
625                    BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
626                    _ => {
627                        return Err(BindingError::WrongTextureViewDimension {
628                            dim,
629                            is_array: false,
630                            binding: entry.ty,
631                        })
632                    }
633                };
634                if arrayed {
635                    match (dim, view_dimension) {
636                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
637                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
638                        _ => {
639                            return Err(BindingError::WrongTextureViewDimension {
640                                dim,
641                                is_array: true,
642                                binding: entry.ty,
643                            })
644                        }
645                    }
646                } else {
647                    match (dim, view_dimension) {
648                        (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
649                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
650                        (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
651                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
652                        _ => {
653                            return Err(BindingError::WrongTextureViewDimension {
654                                dim,
655                                is_array: false,
656                                binding: entry.ty,
657                            })
658                        }
659                    }
660                }
661                match entry.ty {
662                    BindingType::Texture {
663                        sample_type,
664                        view_dimension: _,
665                        multisampled: multi,
666                    } => {
667                        let binding_class = match sample_type {
668                            wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
669                                kind: naga::ScalarKind::Float,
670                                multi,
671                            },
672                            wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
673                                kind: naga::ScalarKind::Sint,
674                                multi,
675                            },
676                            wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
677                                kind: naga::ScalarKind::Uint,
678                                multi,
679                            },
680                            wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
681                        };
682                        if shader_class == binding_class {
683                            Ok(())
684                        } else {
685                            Err(binding_class)
686                        }
687                    }
688                    BindingType::StorageTexture {
689                        access: wgt_binding_access,
690                        format: wgt_binding_format,
691                        view_dimension: _,
692                    } => {
693                        const LOAD_STORE: naga::StorageAccess =
694                            naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
695                        let binding_format = map_storage_format_to_naga(wgt_binding_format)
696                            .ok_or(BindingError::BadStorageFormat(wgt_binding_format))?;
697                        let binding_access = match wgt_binding_access {
698                            wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
699                            wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
700                            wgt::StorageTextureAccess::ReadWrite => LOAD_STORE,
701                            wgt::StorageTextureAccess::Atomic => {
702                                naga::StorageAccess::ATOMIC | LOAD_STORE
703                            }
704                        };
705                        match shader_class {
706                            // Formats must match exactly. A write-only shader (but not a
707                            // read-only shader) is compatible with a read-write binding.
708                            naga::ImageClass::Storage {
709                                format: shader_format,
710                                access: shader_access,
711                            } if shader_format == binding_format
712                                && (shader_access == binding_access
713                                    || shader_access == naga::StorageAccess::STORE
714                                        && binding_access == LOAD_STORE) =>
715                            {
716                                Ok(())
717                            }
718                            _ => Err(naga::ImageClass::Storage {
719                                format: binding_format,
720                                access: binding_access,
721                            }),
722                        }
723                    }
724                    BindingType::ExternalTexture => {
725                        let binding_class = naga::ImageClass::External;
726                        if shader_class == binding_class {
727                            Ok(())
728                        } else {
729                            Err(binding_class)
730                        }
731                    }
732                    _ => {
733                        return Err(BindingError::WrongType {
734                            binding: (&entry.ty).into(),
735                            shader: (&self.ty).into(),
736                        })
737                    }
738                }
739                .map_err(|binding_class| BindingError::WrongTextureClass {
740                    binding: binding_class,
741                    shader: shader_class,
742                })?;
743            }
744            ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
745                BindingType::AccelerationStructure {
746                    vertex_return: entry_vertex_return,
747                } if vertex_return == entry_vertex_return => (),
748                _ => {
749                    return Err(BindingError::WrongType {
750                        binding: (&entry.ty).into(),
751                        shader: (&self.ty).into(),
752                    })
753                }
754            },
755        };
756
757        Ok(())
758    }
759
760    fn derive_binding_type(
761        &self,
762        is_reffed_by_sampler_in_entrypoint: bool,
763    ) -> Result<BindingType, BindingError> {
764        Ok(match self.ty {
765            ResourceType::Buffer { size } => BindingType::Buffer {
766                ty: match self.class {
767                    naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
768                    naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
769                        read_only: access == naga::StorageAccess::LOAD,
770                    },
771                    _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
772                },
773                has_dynamic_offset: false,
774                min_binding_size: Some(size),
775            },
776            ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
777                wgt::SamplerBindingType::Comparison
778            } else {
779                wgt::SamplerBindingType::Filtering
780            }),
781            ResourceType::Texture {
782                dim,
783                arrayed,
784                class,
785            } => {
786                let view_dimension = match dim {
787                    naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
788                    naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
789                    naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
790                    naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
791                    naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
792                    naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
793                };
794                match class {
795                    naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
796                        sample_type: match kind {
797                            naga::ScalarKind::Float => wgt::TextureSampleType::Float {
798                                filterable: is_reffed_by_sampler_in_entrypoint,
799                            },
800                            naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
801                            naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
802                            naga::ScalarKind::AbstractInt
803                            | naga::ScalarKind::AbstractFloat
804                            | naga::ScalarKind::Bool => unreachable!(),
805                        },
806                        view_dimension,
807                        multisampled: multi,
808                    },
809                    naga::ImageClass::Depth { multi } => BindingType::Texture {
810                        sample_type: wgt::TextureSampleType::Depth,
811                        view_dimension,
812                        multisampled: multi,
813                    },
814                    naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
815                        access: {
816                            const LOAD_STORE: naga::StorageAccess =
817                                naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
818                            match access {
819                                naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
820                                naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
821                                LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
822                                _ if access.contains(naga::StorageAccess::ATOMIC) => {
823                                    wgt::StorageTextureAccess::Atomic
824                                }
825                                _ => unreachable!(),
826                            }
827                        },
828                        view_dimension,
829                        format: {
830                            let f = map_storage_format_from_naga(format);
831                            let original = map_storage_format_to_naga(f)
832                                .ok_or(BindingError::BadStorageFormat(f))?;
833                            debug_assert_eq!(format, original);
834                            f
835                        },
836                    },
837                    naga::ImageClass::External => BindingType::ExternalTexture,
838                }
839            }
840            ResourceType::AccelerationStructure { vertex_return } => {
841                BindingType::AccelerationStructure { vertex_return }
842            }
843        })
844    }
845}
846
847impl NumericType {
848    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
849        use naga::{Scalar, VectorSize as Vs};
850        use wgt::VertexFormat as Vf;
851
852        let (dim, scalar) = match format {
853            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
854            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
855                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
856            }
857            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
858            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
859                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
860            }
861            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
862            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
863                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
864            }
865            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
866            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
867                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
868            }
869            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
870                (NumericDimension::Scalar, Scalar::F32)
871            }
872            Vf::Unorm8x2
873            | Vf::Snorm8x2
874            | Vf::Unorm16x2
875            | Vf::Snorm16x2
876            | Vf::Float16x2
877            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
878            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
879            Vf::Unorm8x4
880            | Vf::Snorm8x4
881            | Vf::Unorm16x4
882            | Vf::Snorm16x4
883            | Vf::Float16x4
884            | Vf::Float32x4
885            | Vf::Unorm10_10_10_2
886            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
887            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
888            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
889            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
890            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
891        };
892
893        NumericType {
894            dim,
895            //Note: Shader always sees data as int, uint, or float.
896            // It doesn't know if the original is normalized in a tighter form.
897            scalar,
898        }
899    }
900
901    fn from_texture_format(format: wgt::TextureFormat) -> Self {
902        use naga::{Scalar, VectorSize as Vs};
903        use wgt::TextureFormat as Tf;
904
905        let (dim, scalar) = match format {
906            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
907                (NumericDimension::Scalar, Scalar::F32)
908            }
909            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
910            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
911            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
912                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
913            }
914            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
915            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
916                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
917            }
918            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
919                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
920            }
921            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
922            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
923            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
924            Tf::Rgba8Unorm
925            | Tf::Rgba8UnormSrgb
926            | Tf::Rgba8Snorm
927            | Tf::Bgra8Unorm
928            | Tf::Bgra8UnormSrgb
929            | Tf::Rgb10a2Unorm
930            | Tf::Rgba16Float
931            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
932            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
933                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
934            }
935            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
936                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
937            }
938            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
939            Tf::Stencil8
940            | Tf::Depth16Unorm
941            | Tf::Depth32Float
942            | Tf::Depth32FloatStencil8
943            | Tf::Depth24Plus
944            | Tf::Depth24PlusStencil8 => {
945                panic!("Unexpected depth format")
946            }
947            Tf::NV12 => panic!("Unexpected nv12 format"),
948            Tf::P010 => panic!("Unexpected p010 format"),
949            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
950            Tf::Bc1RgbaUnorm
951            | Tf::Bc1RgbaUnormSrgb
952            | Tf::Bc2RgbaUnorm
953            | Tf::Bc2RgbaUnormSrgb
954            | Tf::Bc3RgbaUnorm
955            | Tf::Bc3RgbaUnormSrgb
956            | Tf::Bc7RgbaUnorm
957            | Tf::Bc7RgbaUnormSrgb
958            | Tf::Etc2Rgb8A1Unorm
959            | Tf::Etc2Rgb8A1UnormSrgb
960            | Tf::Etc2Rgba8Unorm
961            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
962            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
963                (NumericDimension::Scalar, Scalar::F32)
964            }
965            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
966                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
967            }
968            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
969                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
970            }
971            Tf::Astc {
972                block: _,
973                channel: _,
974            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
975        };
976
977        NumericType {
978            dim,
979            //Note: Shader always sees data as int, uint, or float.
980            // It doesn't know if the original is normalized in a tighter form.
981            scalar,
982        }
983    }
984
985    fn is_subtype_of(&self, other: &NumericType) -> bool {
986        if self.scalar.width > other.scalar.width {
987            return false;
988        }
989        if self.scalar.kind != other.scalar.kind {
990            return false;
991        }
992        match (self.dim, other.dim) {
993            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
994            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
995            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
996            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
997                c0 == c1 && r0 == r1
998            }
999            _ => false,
1000        }
1001    }
1002}
1003
1004/// Return true if the fragment `format` is covered by the provided `output`.
1005pub fn check_texture_format(
1006    format: wgt::TextureFormat,
1007    output: &NumericType,
1008) -> Result<(), NumericType> {
1009    let nt = NumericType::from_texture_format(format);
1010    if nt.is_subtype_of(output) {
1011        Ok(())
1012    } else {
1013        Err(nt)
1014    }
1015}
1016
/// Where the bind group layouts checked during shader stage validation come from.
///
/// `Interface::check_stage` takes this mutably. For `Provided` layouts it validates the stage's
/// resources against them; for `Derived` layouts it records entries inferred from the stage's
/// resources.
pub enum BindingLayoutSource {
    /// The binding layout is derived from the shader stages themselves rather than supplied.
    ///
    /// Holds one entry map per bind group, indexed by group number. This will be filled in by
    /// the shader binding validation, as it iterates the shader's interfaces.
    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
    /// The binding layout is provided by the user in BGLs.
    ///
    /// This will be validated against the shader's interfaces.
    Provided(Arc<crate::binding_model::PipelineLayout>),
}
1027
1028impl BindingLayoutSource {
1029    pub fn new_derived(limits: &wgt::Limits) -> Self {
1030        let mut array = ArrayVec::new();
1031        for _ in 0..limits.max_bind_groups {
1032            array.push(Default::default());
1033        }
1034        BindingLayoutSource::Derived(Box::new(array))
1035    }
1036}
1037
/// Interface information carried from one shader stage to the next during pipeline validation.
///
/// `Interface::check_stage` receives a `StageIo` as its `inputs` and returns a `StageIo` for the
/// stage it validated. Presumably that result is fed to the following stage — confirm at the
/// call site.
#[derive(Debug, Clone, Default)]
pub struct StageIo {
    /// User-defined varyings keyed by their `@location` index.
    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
    /// This must match between mesh & task shaders. `None` when the stage declares no task
    /// payload.
    pub task_payload_size: Option<u32>,
    /// Fragment shaders cannot input primitive index on mesh shaders that don't output it on DX12.
    /// Therefore, we track between shader stages if primitive index is written (or if vertex shader
    /// is used).
    ///
    /// This is Some if it was a mesh shader.
    pub primitive_index: Option<bool>,
}
1050
1051impl Interface {
1052    fn populate(
1053        list: &mut Vec<Varying>,
1054        binding: Option<&naga::Binding>,
1055        ty: naga::Handle<naga::Type>,
1056        arena: &naga::UniqueArena<naga::Type>,
1057    ) {
1058        let numeric_ty = match arena[ty].inner {
1059            naga::TypeInner::Scalar(scalar) => NumericType {
1060                dim: NumericDimension::Scalar,
1061                scalar,
1062            },
1063            naga::TypeInner::Vector { size, scalar } => NumericType {
1064                dim: NumericDimension::Vector(size),
1065                scalar,
1066            },
1067            naga::TypeInner::Matrix {
1068                columns,
1069                rows,
1070                scalar,
1071            } => NumericType {
1072                dim: NumericDimension::Matrix(columns, rows),
1073                scalar,
1074            },
1075            naga::TypeInner::Struct { ref members, .. } => {
1076                for member in members {
1077                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
1078                }
1079                return;
1080            }
1081            naga::TypeInner::Array { base, size, stride }
1082                if matches!(
1083                    binding,
1084                    Some(naga::Binding::BuiltIn(naga::BuiltIn::ClipDistances)),
1085                ) =>
1086            {
1087                // NOTE: We should already have validated these in `naga`.
1088                debug_assert_eq!(
1089                    &arena[base].inner,
1090                    &naga::TypeInner::Scalar(naga::Scalar::F32)
1091                );
1092                debug_assert_eq!(stride, 4);
1093
1094                let naga::ArraySize::Constant(array_size) = size else {
1095                    // NOTE: Based on the
1096                    // [spec](https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types):
1097                    //
1098                    // > The only valid use of a fixed-size array with an element count that is an
1099                    // > override-expression that is not a const-expression is as a memory view in
1100                    // > the workgroup address space.
1101                    unreachable!("non-constant array size for `clip_distances`")
1102                };
1103                let array_size = array_size.get();
1104
1105                list.push(Varying::BuiltIn(BuiltIn::ClipDistances { array_size }));
1106                return;
1107            }
1108            ref other => {
1109                //Note: technically this should be at least `log::error`, but
1110                // the reality is - every shader coming from `glslc` outputs an array
1111                // of clip distances and hits this path :(
1112                // So we lower it to `log::debug` to be less annoying as
1113                // there's nothing the user can do about it.
1114                log::debug!("Unexpected varying type: {other:?}");
1115                return;
1116            }
1117        };
1118
1119        let varying = match binding {
1120            Some(&naga::Binding::Location {
1121                location,
1122                interpolation,
1123                sampling,
1124                per_primitive,
1125                blend_src: _,
1126            }) => Varying::Local {
1127                location,
1128                iv: InterfaceVar {
1129                    ty: numeric_ty,
1130                    interpolation,
1131                    sampling,
1132                    per_primitive,
1133                },
1134            },
1135            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(match built_in {
1136                naga::BuiltIn::Position { invariant } => BuiltIn::Position { invariant },
1137                naga::BuiltIn::ViewIndex => BuiltIn::ViewIndex,
1138                naga::BuiltIn::BaseInstance => BuiltIn::BaseInstance,
1139                naga::BuiltIn::BaseVertex => BuiltIn::BaseVertex,
1140                naga::BuiltIn::ClipDistances => unreachable!(),
1141                naga::BuiltIn::CullDistance => BuiltIn::CullDistance,
1142                naga::BuiltIn::InstanceIndex => BuiltIn::InstanceIndex,
1143                naga::BuiltIn::PointSize => BuiltIn::PointSize,
1144                naga::BuiltIn::VertexIndex => BuiltIn::VertexIndex,
1145                naga::BuiltIn::DrawIndex => BuiltIn::DrawIndex,
1146                naga::BuiltIn::FragDepth => BuiltIn::FragDepth,
1147                naga::BuiltIn::PointCoord => BuiltIn::PointCoord,
1148                naga::BuiltIn::FrontFacing => BuiltIn::FrontFacing,
1149                naga::BuiltIn::PrimitiveIndex => BuiltIn::PrimitiveIndex,
1150                naga::BuiltIn::Barycentric { perspective } => BuiltIn::Barycentric { perspective },
1151                naga::BuiltIn::SampleIndex => BuiltIn::SampleIndex,
1152                naga::BuiltIn::SampleMask => BuiltIn::SampleMask,
1153                naga::BuiltIn::GlobalInvocationId => BuiltIn::GlobalInvocationId,
1154                naga::BuiltIn::LocalInvocationId => BuiltIn::LocalInvocationId,
1155                naga::BuiltIn::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
1156                naga::BuiltIn::WorkGroupId => BuiltIn::WorkGroupId,
1157                naga::BuiltIn::WorkGroupSize => BuiltIn::WorkGroupSize,
1158                naga::BuiltIn::NumWorkGroups => BuiltIn::NumWorkGroups,
1159                naga::BuiltIn::NumSubgroups => BuiltIn::NumSubgroups,
1160                naga::BuiltIn::SubgroupId => BuiltIn::SubgroupId,
1161                naga::BuiltIn::SubgroupSize => BuiltIn::SubgroupSize,
1162                naga::BuiltIn::SubgroupInvocationId => BuiltIn::SubgroupInvocationId,
1163                naga::BuiltIn::MeshTaskSize => BuiltIn::MeshTaskSize,
1164                naga::BuiltIn::CullPrimitive => BuiltIn::CullPrimitive,
1165                naga::BuiltIn::PointIndex => BuiltIn::PointIndex,
1166                naga::BuiltIn::LineIndices => BuiltIn::LineIndices,
1167                naga::BuiltIn::TriangleIndices => BuiltIn::TriangleIndices,
1168                naga::BuiltIn::VertexCount => BuiltIn::VertexCount,
1169                naga::BuiltIn::Vertices => BuiltIn::Vertices,
1170                naga::BuiltIn::PrimitiveCount => BuiltIn::PrimitiveCount,
1171                naga::BuiltIn::Primitives => BuiltIn::Primitives,
1172                naga::BuiltIn::RayInvocationId => BuiltIn::RayInvocationId,
1173                naga::BuiltIn::NumRayInvocations => BuiltIn::NumRayInvocations,
1174                naga::BuiltIn::InstanceCustomData => BuiltIn::InstanceCustomData,
1175                naga::BuiltIn::GeometryIndex => BuiltIn::GeometryIndex,
1176                naga::BuiltIn::WorldRayOrigin => BuiltIn::WorldRayOrigin,
1177                naga::BuiltIn::WorldRayDirection => BuiltIn::WorldRayDirection,
1178                naga::BuiltIn::ObjectRayOrigin => BuiltIn::ObjectRayOrigin,
1179                naga::BuiltIn::ObjectRayDirection => BuiltIn::ObjectRayDirection,
1180                naga::BuiltIn::RayTmin => BuiltIn::RayTmin,
1181                naga::BuiltIn::RayTCurrentMax => BuiltIn::RayTCurrentMax,
1182                naga::BuiltIn::ObjectToWorld => BuiltIn::ObjectToWorld,
1183                naga::BuiltIn::WorldToObject => BuiltIn::WorldToObject,
1184                naga::BuiltIn::HitKind => BuiltIn::HitKind,
1185            }),
1186            None => {
1187                log::error!("Missing binding for a varying");
1188                return;
1189            }
1190        };
1191        list.push(varying);
1192    }
1193
1194    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
1195        let mut resources = naga::Arena::new();
1196        let mut resource_mapping = FastHashMap::default();
1197        for (var_handle, var) in module.global_variables.iter() {
1198            let bind = match var.binding {
1199                Some(br) => br,
1200                _ => continue,
1201            };
1202            let naga_ty = &module.types[var.ty].inner;
1203
1204            let inner_ty = match *naga_ty {
1205                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
1206                ref ty => ty,
1207            };
1208
1209            let ty = match *inner_ty {
1210                naga::TypeInner::Image {
1211                    dim,
1212                    arrayed,
1213                    class,
1214                } => ResourceType::Texture {
1215                    dim,
1216                    arrayed,
1217                    class,
1218                },
1219                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
1220                naga::TypeInner::AccelerationStructure { vertex_return } => {
1221                    ResourceType::AccelerationStructure { vertex_return }
1222                }
1223                ref other => ResourceType::Buffer {
1224                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
1225                },
1226            };
1227            let handle = resources.append(
1228                Resource {
1229                    name: var.name.clone(),
1230                    bind,
1231                    ty,
1232                    class: var.space,
1233                },
1234                Default::default(),
1235            );
1236            resource_mapping.insert(var_handle, handle);
1237        }
1238
1239        let mut entry_points = FastHashMap::default();
1240        entry_points.reserve(module.entry_points.len());
1241        for (index, entry_point) in module.entry_points.iter().enumerate() {
1242            let info = info.get_entry_point(index);
1243            let mut ep = EntryPoint::default();
1244            for arg in entry_point.function.arguments.iter() {
1245                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
1246            }
1247            if let Some(ref result) = entry_point.function.result {
1248                Self::populate(
1249                    &mut ep.outputs,
1250                    result.binding.as_ref(),
1251                    result.ty,
1252                    &module.types,
1253                );
1254            }
1255
1256            for (var_handle, var) in module.global_variables.iter() {
1257                let usage = info[var_handle];
1258                if !usage.is_empty() && var.binding.is_some() {
1259                    ep.resources.push(resource_mapping[&var_handle]);
1260                }
1261            }
1262
1263            for key in info.sampling_set.iter() {
1264                ep.sampling_pairs
1265                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
1266            }
1267            ep.dual_source_blending = info.dual_source_blending;
1268            ep.workgroup_size = entry_point.workgroup_size;
1269
1270            if let Some(task_payload) = entry_point.task_payload {
1271                ep.task_payload_size = Some(
1272                    module.types[module.global_variables[task_payload].ty]
1273                        .inner
1274                        .size(module.to_ctx()),
1275                );
1276            }
1277            if let Some(ref mesh_info) = entry_point.mesh_info {
1278                ep.mesh_info = Some(EntryPointMeshInfo {
1279                    max_vertices: mesh_info.max_vertices,
1280                    max_primitives: mesh_info.max_primitives,
1281                });
1282                Self::populate(
1283                    &mut ep.outputs,
1284                    None,
1285                    mesh_info.vertex_output_type,
1286                    &module.types,
1287                );
1288                Self::populate(
1289                    &mut ep.outputs,
1290                    None,
1291                    mesh_info.primitive_output_type,
1292                    &module.types,
1293                );
1294            }
1295
1296            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
1297        }
1298
1299        Self {
1300            limits,
1301            resources,
1302            entry_points,
1303        }
1304    }
1305
1306    pub fn finalize_entry_point_name(
1307        &self,
1308        stage: naga::ShaderStage,
1309        entry_point_name: Option<&str>,
1310    ) -> Result<String, StageError> {
1311        entry_point_name
1312            .map(|ep| ep.to_string())
1313            .map(Ok)
1314            .unwrap_or_else(|| {
1315                let mut entry_points = self
1316                    .entry_points
1317                    .keys()
1318                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
1319                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1320                if entry_points.next().is_some() {
1321                    return Err(StageError::MultipleEntryPointsFound);
1322                }
1323                Ok(first.clone())
1324            })
1325    }
1326
1327    /// Among other things, this implements some validation logic defined by the WebGPU spec. at
1328    /// <https://www.w3.org/TR/webgpu/#abstract-opdef-validating-inter-stage-interfaces>.
1329    pub fn check_stage(
1330        &self,
1331        layouts: &mut BindingLayoutSource,
1332        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
1333        entry_point_name: &str,
1334        shader_stage: ShaderStageForValidation,
1335        inputs: StageIo,
1336    ) -> Result<StageIo, StageError> {
1337        // Since a shader module can have multiple entry points with the same name,
1338        // we need to look for one with the right execution model.
1339        let pair = (shader_stage.to_naga(), entry_point_name.to_string());
1340        let entry_point = match self.entry_points.get(&pair) {
1341            Some(some) => some,
1342            None => return Err(StageError::MissingEntryPoint(pair.1)),
1343        };
1344        let (_, entry_point_name) = pair;
1345
1346        let stage_bit = shader_stage.to_wgt_bit();
1347
1348        // check resources visibility
1349        for &handle in entry_point.resources.iter() {
1350            let res = &self.resources[handle];
1351            let result = 'err: {
1352                match layouts {
1353                    BindingLayoutSource::Provided(pipeline_layout) => {
1354                        // update the required binding size for this buffer
1355                        if let ResourceType::Buffer { size } = res.ty {
1356                            match shader_binding_sizes.entry(res.bind) {
1357                                Entry::Occupied(e) => {
1358                                    *e.into_mut() = size.max(*e.get());
1359                                }
1360                                Entry::Vacant(e) => {
1361                                    e.insert(size);
1362                                }
1363                            }
1364                        }
1365
1366                        let Some(entry) =
1367                            pipeline_layout.get_bgl_entry(res.bind.group, res.bind.binding)
1368                        else {
1369                            break 'err Err(BindingError::Missing);
1370                        };
1371
1372                        if !entry.visibility.contains(stage_bit) {
1373                            break 'err Err(BindingError::Invisible);
1374                        }
1375
1376                        res.check_binding_use(entry)
1377                    }
1378                    BindingLayoutSource::Derived(layouts) => {
1379                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
1380                            break 'err Err(BindingError::Missing);
1381                        };
1382
1383                        let ty = match res.derive_binding_type(
1384                            entry_point
1385                                .sampling_pairs
1386                                .iter()
1387                                .any(|&(im, _samp)| im == handle),
1388                        ) {
1389                            Ok(ty) => ty,
1390                            Err(error) => break 'err Err(error),
1391                        };
1392
1393                        match map.entry(res.bind.binding) {
1394                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
1395                                break 'err Err(BindingError::InconsistentlyDerivedType)
1396                            }
1397                            indexmap::map::Entry::Occupied(e) => {
1398                                e.into_mut().visibility |= stage_bit;
1399                            }
1400                            indexmap::map::Entry::Vacant(e) => {
1401                                e.insert(BindGroupLayoutEntry {
1402                                    binding: res.bind.binding,
1403                                    ty,
1404                                    visibility: stage_bit,
1405                                    count: None,
1406                                });
1407                            }
1408                        }
1409                        Ok(())
1410                    }
1411                }
1412            };
1413            if let Err(error) = result {
1414                return Err(StageError::Binding(res.bind, error));
1415            }
1416        }
1417
1418        // Check the compatibility between textures and samplers
1419        //
1420        // We only need to do this if the binding layout is provided by the user, as derived
1421        // layouts will inherently be correctly tagged.
1422        if let BindingLayoutSource::Provided(pipeline_layout) = layouts {
1423            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
1424                let texture_bind = &self.resources[texture_handle].bind;
1425                let sampler_bind = &self.resources[sampler_handle].bind;
1426                let texture_layout = pipeline_layout
1427                    .get_bgl_entry(texture_bind.group, texture_bind.binding)
1428                    .unwrap();
1429                let sampler_layout = pipeline_layout
1430                    .get_bgl_entry(sampler_bind.group, sampler_bind.binding)
1431                    .unwrap();
1432                assert!(texture_layout.visibility.contains(stage_bit));
1433                assert!(sampler_layout.visibility.contains(stage_bit));
1434
1435                let sampler_filtering = matches!(
1436                    sampler_layout.ty,
1437                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
1438                );
1439                let texture_sample_type = match texture_layout.ty {
1440                    BindingType::Texture { sample_type, .. } => sample_type,
1441                    BindingType::ExternalTexture => {
1442                        wgt::TextureSampleType::Float { filterable: true }
1443                    }
1444                    _ => unreachable!(),
1445                };
1446
1447                let error = match (sampler_filtering, texture_sample_type) {
1448                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
1449                        Some(FilteringError::Float)
1450                    }
1451                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
1452                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
1453                    _ => None,
1454                };
1455
1456                if let Some(error) = error {
1457                    return Err(StageError::Filtering {
1458                        texture: *texture_bind,
1459                        sampler: *sampler_bind,
1460                        error,
1461                    });
1462                }
1463            }
1464        }
1465
1466        // check workgroup size limits
1467        if shader_stage.to_naga().compute_like() {
1468            let (
1469                max_workgroup_size_limits,
1470                max_workgroup_size_total,
1471                per_dimension_limit,
1472                total_limit,
1473            ) = match shader_stage.to_naga() {
1474                naga::ShaderStage::Compute => (
1475                    [
1476                        self.limits.max_compute_workgroup_size_x,
1477                        self.limits.max_compute_workgroup_size_y,
1478                        self.limits.max_compute_workgroup_size_z,
1479                    ],
1480                    self.limits.max_compute_invocations_per_workgroup,
1481                    "max_compute_workgroup_size_*",
1482                    "max_compute_invocations_per_workgroup",
1483                ),
1484                naga::ShaderStage::Task => (
1485                    [
1486                        self.limits.max_task_invocations_per_dimension,
1487                        self.limits.max_task_invocations_per_dimension,
1488                        self.limits.max_task_invocations_per_dimension,
1489                    ],
1490                    self.limits.max_task_invocations_per_workgroup,
1491                    "max_task_invocations_per_dimension",
1492                    "max_task_invocations_per_workgroup",
1493                ),
1494                naga::ShaderStage::Mesh => (
1495                    [
1496                        self.limits.max_mesh_invocations_per_dimension,
1497                        self.limits.max_mesh_invocations_per_dimension,
1498                        self.limits.max_mesh_invocations_per_dimension,
1499                    ],
1500                    self.limits.max_mesh_invocations_per_workgroup,
1501                    "max_mesh_invocations_per_dimension",
1502                    "max_mesh_invocations_per_workgroup",
1503                ),
1504                _ => unreachable!(),
1505            };
1506
1507            let total_invocations = entry_point
1508                .workgroup_size
1509                .iter()
1510                .fold(1u32, |total, &dim| total.saturating_mul(dim));
1511            let invalid_total_invocations =
1512                total_invocations > max_workgroup_size_total || total_invocations == 0;
1513
1514            assert_eq!(
1515                entry_point.workgroup_size.len(),
1516                max_workgroup_size_limits.len()
1517            );
1518            let dimension_too_large = entry_point
1519                .workgroup_size
1520                .iter()
1521                .zip(max_workgroup_size_limits.iter())
1522                .any(|(dim, limit)| dim > limit);
1523
1524            if invalid_total_invocations || dimension_too_large {
1525                return Err(StageError::InvalidWorkgroupSize {
1526                    dimensions: entry_point.workgroup_size,
1527                    total: total_invocations,
1528                    per_dimension_limits: max_workgroup_size_limits,
1529                    total_limit: max_workgroup_size_total,
1530                    per_dimension_limits_desc: per_dimension_limit,
1531                    total_limit_desc: total_limit,
1532                });
1533            }
1534        }
1535
1536        let mut this_stage_primitive_index = false;
1537        let mut has_draw_id = false;
1538
1539        // check inputs compatibility
1540        for input in entry_point.inputs.iter() {
1541            match *input {
1542                Varying::Local { location, ref iv } => {
1543                    let result = inputs
1544                        .varyings
1545                        .get(&location)
1546                        .ok_or(InputError::Missing)
1547                        .and_then(|provided| {
1548                            let (compatible, per_primitive_correct) = match shader_stage.to_naga() {
1549                                // For vertex attributes, there are defaults filled out
1550                                // by the driver if data is not provided.
1551                                naga::ShaderStage::Vertex => {
1552                                    let is_compatible =
1553                                        iv.ty.scalar.kind == provided.ty.scalar.kind;
1554                                    // vertex inputs don't count towards inter-stage
1555                                    (is_compatible, !iv.per_primitive)
1556                                }
1557                                naga::ShaderStage::Fragment => {
1558                                    if iv.interpolation != provided.interpolation {
1559                                        return Err(InputError::InterpolationMismatch(
1560                                            provided.interpolation,
1561                                        ));
1562                                    }
1563                                    if iv.sampling != provided.sampling {
1564                                        return Err(InputError::SamplingMismatch(
1565                                            provided.sampling,
1566                                        ));
1567                                    }
1568                                    (
1569                                        iv.ty.is_subtype_of(&provided.ty),
1570                                        iv.per_primitive == provided.per_primitive,
1571                                    )
1572                                }
1573                                // These can't have varying inputs
1574                                naga::ShaderStage::Compute
1575                                | naga::ShaderStage::Task
1576                                | naga::ShaderStage::Mesh => (false, false),
1577                                naga::ShaderStage::RayGeneration
1578                                | naga::ShaderStage::AnyHit
1579                                | naga::ShaderStage::ClosestHit
1580                                | naga::ShaderStage::Miss => {
1581                                    unreachable!()
1582                                }
1583                            };
1584                            if !compatible {
1585                                return Err(InputError::WrongType(provided.ty));
1586                            } else if !per_primitive_correct {
1587                                return Err(InputError::WrongPerPrimitive {
1588                                    pipeline_input: provided.per_primitive,
1589                                    shader: iv.per_primitive,
1590                                });
1591                            }
1592                            Ok(())
1593                        });
1594
1595                    if let Err(error) = result {
1596                        return Err(StageError::Input {
1597                            location,
1598                            var: iv.clone(),
1599                            error,
1600                        });
1601                    }
1602                }
1603                Varying::BuiltIn(BuiltIn::PrimitiveIndex) => {
1604                    this_stage_primitive_index = true;
1605                }
1606                Varying::BuiltIn(BuiltIn::DrawIndex) => {
1607                    has_draw_id = true;
1608                }
1609                Varying::BuiltIn(_) => {}
1610            }
1611        }
1612
1613        match shader_stage {
1614            ShaderStageForValidation::Vertex {
1615                topology,
1616                compare_function,
1617            } => {
1618                let mut max_vertex_shader_output_variables =
1619                    self.limits.max_inter_stage_shader_variables;
1620                let mut max_vertex_shader_output_location = max_vertex_shader_output_variables - 1;
1621
1622                let point_list_deduction = if topology == wgt::PrimitiveTopology::PointList {
1623                    Some(MaxVertexShaderOutputDeduction::PointListPrimitiveTopology)
1624                } else {
1625                    None
1626                };
1627
1628                let clip_distance_deductions = entry_point.outputs.iter().filter_map(|output| {
1629                    if let &Varying::BuiltIn(BuiltIn::ClipDistances { array_size }) = output {
1630                        Some(MaxVertexShaderOutputDeduction::ClipDistances { array_size })
1631                    } else {
1632                        None
1633                    }
1634                });
1635                debug_assert!(
1636                    clip_distance_deductions.clone().count() <= 1,
1637                    "multiple `clip_distances` outputs found"
1638                );
1639
1640                let deductions = point_list_deduction
1641                    .into_iter()
1642                    .chain(clip_distance_deductions);
1643
1644                for deduction in deductions.clone() {
1645                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1646                    // ever exceed the minimum variables available.
1647                    max_vertex_shader_output_variables = max_vertex_shader_output_variables
1648                        .checked_sub(deduction.for_variables())
1649                        .unwrap();
1650                    max_vertex_shader_output_location = max_vertex_shader_output_location
1651                        .checked_sub(deduction.for_location())
1652                        .unwrap();
1653                }
1654
1655                let mut num_user_defined_outputs = 0;
1656
1657                for output in entry_point.outputs.iter() {
1658                    match *output {
1659                        Varying::Local { ref iv, location } => {
1660                            if location > max_vertex_shader_output_location {
1661                                return Err(StageError::VertexOutputLocationTooLarge {
1662                                    location,
1663                                    var: iv.clone(),
1664                                    limit: self.limits.max_inter_stage_shader_variables,
1665                                    deductions: deductions.collect(),
1666                                });
1667                            }
1668                            num_user_defined_outputs += 1;
1669                        }
1670                        Varying::BuiltIn(_) => {}
1671                    };
1672
1673                    if let Some(
1674                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
1675                    ) = compare_function
1676                    {
1677                        if let Varying::BuiltIn(BuiltIn::Position { invariant: false }) = *output {
1678                            log::warn!(
1679                                concat!(
1680                                    "Vertex shader with entry point {} outputs a ",
1681                                    "@builtin(position) without the @invariant attribute and ",
1682                                    "is used in a pipeline with {cmp:?}. On some machines, ",
1683                                    "this can cause bad artifacting as {cmp:?} assumes the ",
1684                                    "values output from the vertex shader exactly match the ",
1685                                    "value in the depth buffer. The @invariant attribute on the ",
1686                                    "@builtin(position) vertex output ensures that the exact ",
1687                                    "same pixel depths are used every render."
1688                                ),
1689                                entry_point_name,
1690                                cmp = cmp
1691                            );
1692                        }
1693                    }
1694                }
1695
1696                if num_user_defined_outputs > max_vertex_shader_output_variables {
1697                    return Err(StageError::TooManyUserDefinedVertexOutputs {
1698                        num_found: num_user_defined_outputs,
1699                        limit: self.limits.max_inter_stage_shader_variables,
1700                        deductions: deductions.collect(),
1701                    });
1702                }
1703            }
1704            ShaderStageForValidation::Fragment {
1705                dual_source_blending,
1706                has_depth_attachment,
1707            } => {
1708                let mut max_fragment_shader_input_variables =
1709                    self.limits.max_inter_stage_shader_variables;
1710
1711                let deductions = entry_point.inputs.iter().filter_map(|output| match output {
1712                    Varying::Local { .. } => None,
1713                    Varying::BuiltIn(builtin) => {
1714                        MaxFragmentShaderInputDeduction::from_inter_stage_builtin(builtin.to_naga())
1715                            .or_else(|| {
1716                                unreachable!(
1717                                    concat!(
1718                                        "unexpected built-in provided; ",
1719                                        "{:?} is not used for fragment stage input",
1720                                    ),
1721                                    builtin
1722                                )
1723                            })
1724                    }
1725                });
1726
1727                for deduction in deductions.clone() {
1728                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1729                    // ever exceed the minimum variables available.
1730                    max_fragment_shader_input_variables = max_fragment_shader_input_variables
1731                        .checked_sub(deduction.for_variables())
1732                        .unwrap();
1733                }
1734
1735                let mut num_user_defined_inputs = 0;
1736
1737                for output in entry_point.inputs.iter() {
1738                    match *output {
1739                        Varying::Local { ref iv, location } => {
1740                            if location >= self.limits.max_inter_stage_shader_variables {
1741                                return Err(StageError::FragmentInputLocationTooLarge {
1742                                    location,
1743                                    var: iv.clone(),
1744                                    limit: self.limits.max_inter_stage_shader_variables,
1745                                    deductions: deductions.collect(),
1746                                });
1747                            }
1748                            num_user_defined_inputs += 1;
1749                        }
1750                        Varying::BuiltIn(_) => {}
1751                    };
1752                }
1753
1754                if num_user_defined_inputs > max_fragment_shader_input_variables {
1755                    return Err(StageError::TooManyUserDefinedFragmentInputs {
1756                        num_found: num_user_defined_inputs,
1757                        limit: self.limits.max_inter_stage_shader_variables,
1758                        deductions: deductions.collect(),
1759                    });
1760                }
1761
1762                for output in &entry_point.outputs {
1763                    let &Varying::Local { location, ref iv } = output else {
1764                        continue;
1765                    };
1766                    if location >= self.limits.max_color_attachments {
1767                        return Err(StageError::ColorAttachmentLocationTooLarge {
1768                            location,
1769                            var: iv.clone(),
1770                            limit: self.limits.max_color_attachments,
1771                        });
1772                    }
1773                }
1774
1775                // If the pipeline uses dual-source blending, then the shader
1776                // must configure appropriate I/O, but it is not an error to
1777                // use a shader that defines the I/O in a pipeline that only
1778                // uses one blend source.
1779                if dual_source_blending && !entry_point.dual_source_blending {
1780                    return Err(StageError::InvalidDualSourceBlending);
1781                }
1782
1783                if entry_point
1784                    .outputs
1785                    .contains(&Varying::BuiltIn(BuiltIn::FragDepth))
1786                    && !has_depth_attachment
1787                {
1788                    return Err(StageError::MissingFragDepthAttachment);
1789                }
1790            }
1791            ShaderStageForValidation::Mesh => {
1792                for output in &entry_point.outputs {
1793                    if matches!(output, Varying::BuiltIn(BuiltIn::PrimitiveIndex)) {
1794                        this_stage_primitive_index = true;
1795                    }
1796                }
1797            }
1798            _ => (),
1799        }
1800
1801        if let Some(ref mesh_info) = entry_point.mesh_info {
1802            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
1803                return Err(StageError::TooManyMeshVertices {
1804                    limit: self.limits.max_mesh_output_vertices,
1805                    value: mesh_info.max_vertices,
1806                });
1807            }
1808            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
1809                return Err(StageError::TooManyMeshPrimitives {
1810                    limit: self.limits.max_mesh_output_primitives,
1811                    value: mesh_info.max_primitives,
1812                });
1813            }
1814        }
1815        if let Some(task_payload_size) = entry_point.task_payload_size {
1816            if task_payload_size > self.limits.max_task_payload_size {
1817                return Err(StageError::TaskPayloadTooLarge {
1818                    limit: self.limits.max_task_payload_size,
1819                    value: task_payload_size,
1820                });
1821            }
1822        }
1823        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1824            && entry_point.task_payload_size != inputs.task_payload_size
1825        {
1826            return Err(StageError::TaskPayloadMustMatch {
1827                input: inputs.task_payload_size,
1828                shader: entry_point.task_payload_size,
1829            });
1830        }
1831
1832        // Fragment shader primitive index is treated like a varying
1833        if shader_stage.to_naga() == naga::ShaderStage::Fragment
1834            && this_stage_primitive_index
1835            && inputs.primitive_index == Some(false)
1836        {
1837            return Err(StageError::InvalidPrimitiveIndex);
1838        } else if shader_stage.to_naga() == naga::ShaderStage::Fragment
1839            && !this_stage_primitive_index
1840            && inputs.primitive_index == Some(true)
1841        {
1842            return Err(StageError::MissingPrimitiveIndex);
1843        }
1844        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1845            && inputs.task_payload_size.is_some()
1846            && has_draw_id
1847        {
1848            return Err(StageError::DrawIdError);
1849        }
1850
1851        let outputs = entry_point
1852            .outputs
1853            .iter()
1854            .filter_map(|output| match *output {
1855                Varying::Local { location, ref iv } => Some((location, iv.clone())),
1856                Varying::BuiltIn(_) => None,
1857            })
1858            .collect();
1859
1860        Ok(StageIo {
1861            task_payload_size: entry_point.task_payload_size,
1862            varyings: outputs,
1863            primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
1864                Some(this_stage_primitive_index)
1865            } else {
1866                None
1867            },
1868        })
1869    }
1870
1871    pub fn fragment_uses_dual_source_blending(
1872        &self,
1873        entry_point_name: &str,
1874    ) -> Result<bool, StageError> {
1875        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
1876        self.entry_points
1877            .get(&pair)
1878            .ok_or(StageError::MissingEntryPoint(pair.1))
1879            .map(|ep| ep.dual_source_blending)
1880    }
1881}
1882
1883/// Validate a list of color attachment formats against `maxColorAttachmentBytesPerSample`.
1884///
1885/// The color attachments can be from a render pass descriptor or a pipeline descriptor.
1886///
1887/// Implements <https://gpuweb.github.io/gpuweb/#abstract-opdef-calculating-color-attachment-bytes-per-sample>.
1888pub fn validate_color_attachment_bytes_per_sample(
1889    attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1890    limit: u32,
1891) -> Result<(), crate::command::ColorAttachmentError> {
1892    let mut total_bytes_per_sample: u32 = 0;
1893    for format in attachment_formats {
1894        let byte_cost = format.target_pixel_byte_cost().unwrap();
1895        let alignment = format.target_component_alignment().unwrap();
1896
1897        total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1898        total_bytes_per_sample += byte_cost;
1899    }
1900
1901    if total_bytes_per_sample > limit {
1902        return Err(
1903            crate::command::ColorAttachmentError::TooManyBytesPerSample {
1904                total: total_bytes_per_sample,
1905                limit,
1906            },
1907        );
1908    }
1909
1910    Ok(())
1911}
1912
1913pub enum ShaderStageForValidation {
1914    Vertex {
1915        topology: wgt::PrimitiveTopology,
1916        compare_function: Option<wgt::CompareFunction>,
1917    },
1918    Mesh,
1919    Fragment {
1920        dual_source_blending: bool,
1921        has_depth_attachment: bool,
1922    },
1923    Compute,
1924    Task,
1925}
1926
1927impl ShaderStageForValidation {
1928    pub fn to_naga(&self) -> naga::ShaderStage {
1929        match self {
1930            Self::Vertex { .. } => naga::ShaderStage::Vertex,
1931            Self::Mesh => naga::ShaderStage::Mesh,
1932            Self::Fragment { .. } => naga::ShaderStage::Fragment,
1933            Self::Compute => naga::ShaderStage::Compute,
1934            Self::Task => naga::ShaderStage::Task,
1935        }
1936    }
1937
1938    pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
1939        match self {
1940            Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
1941            Self::Mesh => wgt::ShaderStages::MESH,
1942            Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
1943            Self::Compute => wgt::ShaderStages::COMPUTE,
1944            Self::Task => wgt::ShaderStages::TASK,
1945        }
1946    }
1947}