wgpu_core/
validation.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString as _},
4    sync::Arc,
5    vec::Vec,
6};
7use core::fmt;
8
9use arrayvec::ArrayVec;
10use hashbrown::hash_map::Entry;
11use shader_io_deductions::{display_deductions_as_optional_list, MaxVertexShaderOutputDeduction};
12use thiserror::Error;
13use wgt::{
14    error::{ErrorType, WebGpuError},
15    BindGroupLayoutEntry, BindingType,
16};
17
18use crate::{
19    device::bgl, resource::InvalidResourceError,
20    validation::shader_io_deductions::MaxFragmentShaderInputDeduction, FastHashMap, FastHashSet,
21};
22
23pub mod shader_io_deductions;
24
25#[derive(Debug)]
26enum ResourceType {
27    Buffer {
28        size: wgt::BufferSize,
29    },
30    Texture {
31        dim: naga::ImageDimension,
32        arrayed: bool,
33        class: naga::ImageClass,
34    },
35    Sampler {
36        comparison: bool,
37    },
38    AccelerationStructure {
39        vertex_return: bool,
40    },
41}
42
/// Coarse resource category, used to describe both sides of a
/// [`BindingError::WrongType`] mismatch.
#[derive(Clone, Debug)]
pub enum BindingTypeName {
    Buffer,
    /// Sampled or storage texture.
    Texture,
    Sampler,
    AccelerationStructure,
    ExternalTexture,
}
51
52impl From<&ResourceType> for BindingTypeName {
53    fn from(ty: &ResourceType) -> BindingTypeName {
54        match ty {
55            ResourceType::Buffer { .. } => BindingTypeName::Buffer,
56            ResourceType::Texture {
57                class: naga::ImageClass::External,
58                ..
59            } => BindingTypeName::ExternalTexture,
60            ResourceType::Texture { .. } => BindingTypeName::Texture,
61            ResourceType::Sampler { .. } => BindingTypeName::Sampler,
62            ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
63        }
64    }
65}
66
67impl From<&BindingType> for BindingTypeName {
68    fn from(ty: &BindingType) -> BindingTypeName {
69        match ty {
70            BindingType::Buffer { .. } => BindingTypeName::Buffer,
71            BindingType::Texture { .. } => BindingTypeName::Texture,
72            BindingType::StorageTexture { .. } => BindingTypeName::Texture,
73            BindingType::Sampler { .. } => BindingTypeName::Sampler,
74            BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
75            BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
76        }
77    }
78}
79
80#[derive(Debug)]
81struct Resource {
82    #[allow(unused)]
83    name: Option<String>,
84    bind: naga::ResourceBinding,
85    ty: ResourceType,
86    class: naga::AddressSpace,
87}
88
89#[derive(Clone, Copy, Debug, Eq, PartialEq)]
90enum NumericDimension {
91    Scalar,
92    Vector(naga::VectorSize),
93    Matrix(naga::VectorSize, naga::VectorSize),
94}
95
96impl fmt::Display for NumericDimension {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        match *self {
99            Self::Scalar => write!(f, ""),
100            Self::Vector(size) => write!(f, "x{}", size as u8),
101            Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
102        }
103    }
104}
105
106#[derive(Clone, Copy, Debug, Eq, PartialEq)]
107pub struct NumericType {
108    dim: NumericDimension,
109    scalar: naga::Scalar,
110}
111
112impl fmt::Display for NumericType {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(
115            f,
116            "{:?}{}{}",
117            self.scalar.kind,
118            self.scalar.width * 8,
119            self.dim
120        )
121    }
122}
123
124#[derive(Clone, Debug, Eq, PartialEq)]
125pub struct InterfaceVar {
126    pub ty: NumericType,
127    interpolation: Option<naga::Interpolation>,
128    sampling: Option<naga::Sampling>,
129    per_primitive: bool,
130}
131
132impl InterfaceVar {
133    pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
134        InterfaceVar {
135            ty: NumericType::from_vertex_format(format),
136            interpolation: None,
137            sampling: None,
138            per_primitive: false,
139        }
140    }
141}
142
143impl fmt::Display for InterfaceVar {
144    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145        write!(
146            f,
147            "{} interpolated as {:?} with sampling {:?}",
148            self.ty, self.interpolation, self.sampling
149        )
150    }
151}
152
153#[derive(Debug, Eq, PartialEq)]
154enum Varying {
155    Local { location: u32, iv: InterfaceVar },
156    BuiltIn(naga::BuiltIn),
157}
158
159#[allow(unused)]
160#[derive(Debug)]
161struct SpecializationConstant {
162    id: u32,
163    ty: NumericType,
164}
165
166#[derive(Debug)]
167struct EntryPointMeshInfo {
168    max_vertices: u32,
169    max_primitives: u32,
170}
171
172#[derive(Debug, Default)]
173struct EntryPoint {
174    inputs: Vec<Varying>,
175    outputs: Vec<Varying>,
176    resources: Vec<naga::Handle<Resource>>,
177    #[allow(unused)]
178    spec_constants: Vec<SpecializationConstant>,
179    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
180    workgroup_size: [u32; 3],
181    dual_source_blending: bool,
182    task_payload_size: Option<u32>,
183    mesh_info: Option<EntryPointMeshInfo>,
184}
185
186#[derive(Debug)]
187pub struct Interface {
188    limits: wgt::Limits,
189    resources: naga::Arena<Resource>,
190    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
191}
192
/// Ways a shader resource global can be incompatible with its bind group layout entry.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum BindingError {
    /// The pipeline layout has no entry for the shader's binding.
    #[error("Binding is missing from the pipeline layout")]
    Missing,
    /// The layout entry's visibility does not include the stage that uses it.
    #[error("Visibility flags don't include the shader stage")]
    Invisible,
    /// Shader and layout disagree on the kind of resource.
    #[error(
        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
    )]
    WrongType {
        binding: BindingTypeName,
        shader: BindingTypeName,
    },
    /// A buffer's address space (uniform vs. storage, or storage access) differs.
    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
    WrongAddressSpace {
        binding: naga::AddressSpace,
        shader: naga::AddressSpace,
    },
    /// A buffer global was declared in an address space that cannot back a buffer binding.
    #[error("Address space {space:?} is not a valid Buffer address space")]
    WrongBufferAddressSpace { space: naga::AddressSpace },
    /// The layout's `min_binding_size` is smaller than the shader's buffer size.
    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
    WrongBufferSize {
        buffer_size: wgt::BufferSize,
        min_binding_size: wgt::BufferSize,
    },
    /// The texture's dimension or arrayness doesn't match the layout's view dimension.
    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
    WrongTextureViewDimension {
        dim: naga::ImageDimension,
        is_array: bool,
        binding: BindingType,
    },
    /// The texture class (sample type, multisampling, storage format/access, external)
    /// differs between shader and layout.
    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
    WrongTextureClass {
        binding: naga::ImageClass,
        shader: naga::ImageClass,
    },
    /// Exactly one side is a comparison sampler.
    #[error("Comparison flag doesn't match the shader")]
    WrongSamplerComparison,
    /// Deriving a layout produced different types for one binding in different stages.
    #[error("Derived bind group layout type is not consistent between stages")]
    InconsistentlyDerivedType,
    /// The storage texture format has no naga storage-format counterpart.
    #[error("Texture format {0:?} is not supported for storage use")]
    BadStorageFormat(wgt::TextureFormat),
}

impl WebGpuError for BindingError {
    /// Every binding error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
243
/// Why a texture cannot be sampled through the sampler it is paired with.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum FilteringError {
    /// An integer texture is used with a filtering sampler.
    #[error("Integer textures can't be sampled with a filtering sampler")]
    Integer,
    /// A non-filterable float texture is used with a filtering sampler.
    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
    Float,
}

impl WebGpuError for FilteringError {
    /// Every filtering error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
258
/// Why a stage input does not line up with what the previous stage provides.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum InputError {
    /// Nothing is written to this input by the earlier stage.
    #[error("Input is not provided by the earlier stage in the pipeline")]
    Missing,
    /// The provided value's numeric type (carried here) is incompatible.
    #[error("Input type is not compatible with the provided {0}")]
    WrongType(NumericType),
    /// The interpolation qualifier differs from the provided one (carried here).
    #[error("Input interpolation doesn't match provided {0:?}")]
    InterpolationMismatch(Option<naga::Interpolation>),
    /// The sampling qualifier differs from the provided one (carried here).
    #[error("Input sampling doesn't match provided {0:?}")]
    SamplingMismatch(Option<naga::Sampling>),
    /// The `per_primitive` flag differs between pipeline input and shader.
    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
    WrongPerPrimitive { pipeline_input: bool, shader: bool },
}

impl WebGpuError for InputError {
    /// Every input error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
279
280/// Errors produced when validating a programmable stage of a pipeline.
281#[derive(Clone, Debug, Error)]
282#[non_exhaustive]
283pub enum StageError {
284    #[error(
285        "Shader entry point's workgroup size {current:?} ({current_total} total invocations) must be less or equal to the per-dimension
286        limit `Limits::{per_dimension_limit}` of {limit:?} and the total invocation limit `Limits::{total_limit}` of {total}"
287    )]
288    InvalidWorkgroupSize {
289        current: [u32; 3],
290        current_total: u32,
291        limit: [u32; 3],
292        total: u32,
293        per_dimension_limit: &'static str,
294        total_limit: &'static str,
295    },
296    #[error("Unable to find entry point '{0}'")]
297    MissingEntryPoint(String),
298    #[error("Shader global {0:?} is not available in the pipeline layout")]
299    Binding(naga::ResourceBinding, #[source] BindingError),
300    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
301    Filtering {
302        texture: naga::ResourceBinding,
303        sampler: naga::ResourceBinding,
304        #[source]
305        error: FilteringError,
306    },
307    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
308    Input {
309        location: wgt::ShaderLocation,
310        var: InterfaceVar,
311        #[source]
312        error: InputError,
313    },
314    #[error(
315        "Unable to select an entry point: no entry point was found in the provided shader module"
316    )]
317    NoEntryPointFound,
318    #[error(
319        "Unable to select an entry point: \
320        multiple entry points were found in the provided shader module, \
321        but no entry point was specified"
322    )]
323    MultipleEntryPointsFound,
324    #[error(transparent)]
325    InvalidResource(#[from] InvalidResourceError),
326    #[error(
327        "vertex shader output location Location[{location}] ({var}) exceeds the \
328        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
329        // NOTE: Remember: the limit is 0-based for indices.
330        limit - 1,
331        display_deductions_as_optional_list(deductions, |d| d.for_location())
332    )]
333    VertexOutputLocationTooLarge {
334        location: u32,
335        var: InterfaceVar,
336        limit: u32,
337        deductions: Vec<MaxVertexShaderOutputDeduction>,
338    },
339    #[error(
340        "found {num_found} user-defined vertex shader output variables, which exceeds the \
341        `max_inter_stage_shader_variables` limit ({limit}){}",
342        display_deductions_as_optional_list(deductions, |d| d.for_variables())
343    )]
344    TooManyUserDefinedVertexOutputs {
345        num_found: u32,
346        limit: u32,
347        deductions: Vec<MaxVertexShaderOutputDeduction>,
348    },
349    #[error(
350        "fragment shader input location Location[{location}] ({var}) exceeds the \
351        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
352        // NOTE: Remember: the limit is 0-based for indices.
353        limit - 1,
354        // NOTE: WebGPU spec. validation for fragment inputs is expressed in terms of variables
355        // (unlike vertex outputs), so we use `MaxFragmentShaderInputDeduction::for_variables` here
356        // (and not a non-existent `for_locations`).
357        display_deductions_as_optional_list(deductions, |d| d.for_variables())
358    )]
359    FragmentInputLocationTooLarge {
360        location: u32,
361        var: InterfaceVar,
362        limit: u32,
363        deductions: Vec<MaxFragmentShaderInputDeduction>,
364    },
365    #[error(
366        "found {num_found} user-defined fragment shader input variables, which exceeds the \
367        `max_inter_stage_shader_variables` limit ({limit}){}",
368        display_deductions_as_optional_list(deductions, |d| d.for_variables())
369    )]
370    TooManyUserDefinedFragmentInputs {
371        num_found: u32,
372        limit: u32,
373        deductions: Vec<MaxFragmentShaderInputDeduction>,
374    },
375    #[error(
376        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
377    )]
378    ColorAttachmentLocationTooLarge {
379        location: u32,
380        var: InterfaceVar,
381        limit: u32,
382    },
383    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
384    TooManyMeshVertices { limit: u32, value: u32 },
385    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
386    TooManyMeshPrimitives { limit: u32, value: u32 },
387    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
388    TaskPayloadTooLarge { limit: u32, value: u32 },
389    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
390    TaskPayloadMustMatch {
391        input: Option<u32>,
392        shader: Option<u32>,
393    },
394    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
395    InvalidPrimitiveIndex,
396    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
397    MissingPrimitiveIndex,
398    #[error("DrawId cannot be used in the same pipeline as a task shader")]
399    DrawIdError,
400    #[error("Pipeline uses dual-source blending, but the shader does not support it")]
401    InvalidDualSourceBlending,
402    #[error("Fragment shader writes depth, but pipeline does not have a depth attachment")]
403    MissingFragDepthAttachment,
404}
405
406impl WebGpuError for StageError {
407    fn webgpu_error_type(&self) -> ErrorType {
408        let e: &dyn WebGpuError = match self {
409            Self::Binding(_, e) => e,
410            Self::InvalidResource(e) => e,
411            Self::Filtering {
412                texture: _,
413                sampler: _,
414                error,
415            } => error,
416            Self::Input {
417                location: _,
418                var: _,
419                error,
420            } => error,
421            Self::InvalidWorkgroupSize { .. }
422            | Self::MissingEntryPoint(..)
423            | Self::NoEntryPointFound
424            | Self::MultipleEntryPointsFound
425            | Self::VertexOutputLocationTooLarge { .. }
426            | Self::TooManyUserDefinedVertexOutputs { .. }
427            | Self::FragmentInputLocationTooLarge { .. }
428            | Self::TooManyUserDefinedFragmentInputs { .. }
429            | Self::ColorAttachmentLocationTooLarge { .. }
430            | Self::TooManyMeshVertices { .. }
431            | Self::TooManyMeshPrimitives { .. }
432            | Self::TaskPayloadTooLarge { .. }
433            | Self::TaskPayloadMustMatch { .. }
434            | Self::InvalidPrimitiveIndex
435            | Self::MissingPrimitiveIndex
436            | Self::DrawIdError
437            | Self::InvalidDualSourceBlending
438            | Self::MissingFragDepthAttachment => return ErrorType::Validation,
439        };
440        e.webgpu_error_type()
441    }
442}
443
444pub fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> {
445    use naga::StorageFormat as Sf;
446    use wgt::TextureFormat as Tf;
447
448    Some(match format {
449        Tf::R8Unorm => Sf::R8Unorm,
450        Tf::R8Snorm => Sf::R8Snorm,
451        Tf::R8Uint => Sf::R8Uint,
452        Tf::R8Sint => Sf::R8Sint,
453
454        Tf::R16Uint => Sf::R16Uint,
455        Tf::R16Sint => Sf::R16Sint,
456        Tf::R16Float => Sf::R16Float,
457        Tf::Rg8Unorm => Sf::Rg8Unorm,
458        Tf::Rg8Snorm => Sf::Rg8Snorm,
459        Tf::Rg8Uint => Sf::Rg8Uint,
460        Tf::Rg8Sint => Sf::Rg8Sint,
461
462        Tf::R32Uint => Sf::R32Uint,
463        Tf::R32Sint => Sf::R32Sint,
464        Tf::R32Float => Sf::R32Float,
465        Tf::Rg16Uint => Sf::Rg16Uint,
466        Tf::Rg16Sint => Sf::Rg16Sint,
467        Tf::Rg16Float => Sf::Rg16Float,
468        Tf::Rgba8Unorm => Sf::Rgba8Unorm,
469        Tf::Rgba8Snorm => Sf::Rgba8Snorm,
470        Tf::Rgba8Uint => Sf::Rgba8Uint,
471        Tf::Rgba8Sint => Sf::Rgba8Sint,
472        Tf::Bgra8Unorm => Sf::Bgra8Unorm,
473
474        Tf::Rgb10a2Uint => Sf::Rgb10a2Uint,
475        Tf::Rgb10a2Unorm => Sf::Rgb10a2Unorm,
476        Tf::Rg11b10Ufloat => Sf::Rg11b10Ufloat,
477
478        Tf::R64Uint => Sf::R64Uint,
479        Tf::Rg32Uint => Sf::Rg32Uint,
480        Tf::Rg32Sint => Sf::Rg32Sint,
481        Tf::Rg32Float => Sf::Rg32Float,
482        Tf::Rgba16Uint => Sf::Rgba16Uint,
483        Tf::Rgba16Sint => Sf::Rgba16Sint,
484        Tf::Rgba16Float => Sf::Rgba16Float,
485
486        Tf::Rgba32Uint => Sf::Rgba32Uint,
487        Tf::Rgba32Sint => Sf::Rgba32Sint,
488        Tf::Rgba32Float => Sf::Rgba32Float,
489
490        Tf::R16Unorm => Sf::R16Unorm,
491        Tf::R16Snorm => Sf::R16Snorm,
492        Tf::Rg16Unorm => Sf::Rg16Unorm,
493        Tf::Rg16Snorm => Sf::Rg16Snorm,
494        Tf::Rgba16Unorm => Sf::Rgba16Unorm,
495        Tf::Rgba16Snorm => Sf::Rgba16Snorm,
496
497        _ => return None,
498    })
499}
500
501pub fn map_storage_format_from_naga(format: naga::StorageFormat) -> wgt::TextureFormat {
502    use naga::StorageFormat as Sf;
503    use wgt::TextureFormat as Tf;
504
505    match format {
506        Sf::R8Unorm => Tf::R8Unorm,
507        Sf::R8Snorm => Tf::R8Snorm,
508        Sf::R8Uint => Tf::R8Uint,
509        Sf::R8Sint => Tf::R8Sint,
510
511        Sf::R16Uint => Tf::R16Uint,
512        Sf::R16Sint => Tf::R16Sint,
513        Sf::R16Float => Tf::R16Float,
514        Sf::Rg8Unorm => Tf::Rg8Unorm,
515        Sf::Rg8Snorm => Tf::Rg8Snorm,
516        Sf::Rg8Uint => Tf::Rg8Uint,
517        Sf::Rg8Sint => Tf::Rg8Sint,
518
519        Sf::R32Uint => Tf::R32Uint,
520        Sf::R32Sint => Tf::R32Sint,
521        Sf::R32Float => Tf::R32Float,
522        Sf::Rg16Uint => Tf::Rg16Uint,
523        Sf::Rg16Sint => Tf::Rg16Sint,
524        Sf::Rg16Float => Tf::Rg16Float,
525        Sf::Rgba8Unorm => Tf::Rgba8Unorm,
526        Sf::Rgba8Snorm => Tf::Rgba8Snorm,
527        Sf::Rgba8Uint => Tf::Rgba8Uint,
528        Sf::Rgba8Sint => Tf::Rgba8Sint,
529        Sf::Bgra8Unorm => Tf::Bgra8Unorm,
530
531        Sf::Rgb10a2Uint => Tf::Rgb10a2Uint,
532        Sf::Rgb10a2Unorm => Tf::Rgb10a2Unorm,
533        Sf::Rg11b10Ufloat => Tf::Rg11b10Ufloat,
534
535        Sf::R64Uint => Tf::R64Uint,
536        Sf::Rg32Uint => Tf::Rg32Uint,
537        Sf::Rg32Sint => Tf::Rg32Sint,
538        Sf::Rg32Float => Tf::Rg32Float,
539        Sf::Rgba16Uint => Tf::Rgba16Uint,
540        Sf::Rgba16Sint => Tf::Rgba16Sint,
541        Sf::Rgba16Float => Tf::Rgba16Float,
542
543        Sf::Rgba32Uint => Tf::Rgba32Uint,
544        Sf::Rgba32Sint => Tf::Rgba32Sint,
545        Sf::Rgba32Float => Tf::Rgba32Float,
546
547        Sf::R16Unorm => Tf::R16Unorm,
548        Sf::R16Snorm => Tf::R16Snorm,
549        Sf::Rg16Unorm => Tf::Rg16Unorm,
550        Sf::Rg16Snorm => Tf::Rg16Snorm,
551        Sf::Rgba16Unorm => Tf::Rgba16Unorm,
552        Sf::Rgba16Snorm => Tf::Rgba16Snorm,
553    }
554}
555
556impl Resource {
557    fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
558        match self.ty {
559            ResourceType::Buffer { size } => {
560                let min_size = match entry.ty {
561                    BindingType::Buffer {
562                        ty,
563                        has_dynamic_offset: _,
564                        min_binding_size,
565                    } => {
566                        let class = match ty {
567                            wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
568                            wgt::BufferBindingType::Storage { read_only } => {
569                                let mut naga_access = naga::StorageAccess::LOAD;
570                                naga_access.set(naga::StorageAccess::STORE, !read_only);
571                                naga::AddressSpace::Storage {
572                                    access: naga_access,
573                                }
574                            }
575                        };
576                        if self.class != class {
577                            return Err(BindingError::WrongAddressSpace {
578                                binding: class,
579                                shader: self.class,
580                            });
581                        }
582                        min_binding_size
583                    }
584                    _ => {
585                        return Err(BindingError::WrongType {
586                            binding: (&entry.ty).into(),
587                            shader: (&self.ty).into(),
588                        })
589                    }
590                };
591                match min_size {
592                    Some(non_zero) if non_zero < size => {
593                        return Err(BindingError::WrongBufferSize {
594                            buffer_size: size,
595                            min_binding_size: non_zero,
596                        })
597                    }
598                    _ => (),
599                }
600            }
601            ResourceType::Sampler { comparison } => match entry.ty {
602                BindingType::Sampler(ty) => {
603                    if (ty == wgt::SamplerBindingType::Comparison) != comparison {
604                        return Err(BindingError::WrongSamplerComparison);
605                    }
606                }
607                _ => {
608                    return Err(BindingError::WrongType {
609                        binding: (&entry.ty).into(),
610                        shader: (&self.ty).into(),
611                    })
612                }
613            },
614            ResourceType::Texture {
615                dim,
616                arrayed,
617                class: shader_class,
618            } => {
619                let view_dimension = match entry.ty {
620                    BindingType::Texture { view_dimension, .. }
621                    | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
622                    BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
623                    _ => {
624                        return Err(BindingError::WrongTextureViewDimension {
625                            dim,
626                            is_array: false,
627                            binding: entry.ty,
628                        })
629                    }
630                };
631                if arrayed {
632                    match (dim, view_dimension) {
633                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
634                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
635                        _ => {
636                            return Err(BindingError::WrongTextureViewDimension {
637                                dim,
638                                is_array: true,
639                                binding: entry.ty,
640                            })
641                        }
642                    }
643                } else {
644                    match (dim, view_dimension) {
645                        (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
646                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
647                        (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
648                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
649                        _ => {
650                            return Err(BindingError::WrongTextureViewDimension {
651                                dim,
652                                is_array: false,
653                                binding: entry.ty,
654                            })
655                        }
656                    }
657                }
658                match entry.ty {
659                    BindingType::Texture {
660                        sample_type,
661                        view_dimension: _,
662                        multisampled: multi,
663                    } => {
664                        let binding_class = match sample_type {
665                            wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
666                                kind: naga::ScalarKind::Float,
667                                multi,
668                            },
669                            wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
670                                kind: naga::ScalarKind::Sint,
671                                multi,
672                            },
673                            wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
674                                kind: naga::ScalarKind::Uint,
675                                multi,
676                            },
677                            wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
678                        };
679                        if shader_class == binding_class {
680                            Ok(())
681                        } else {
682                            Err(binding_class)
683                        }
684                    }
685                    BindingType::StorageTexture {
686                        access: wgt_binding_access,
687                        format: wgt_binding_format,
688                        view_dimension: _,
689                    } => {
690                        const LOAD_STORE: naga::StorageAccess =
691                            naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
692                        let binding_format = map_storage_format_to_naga(wgt_binding_format)
693                            .ok_or(BindingError::BadStorageFormat(wgt_binding_format))?;
694                        let binding_access = match wgt_binding_access {
695                            wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
696                            wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
697                            wgt::StorageTextureAccess::ReadWrite => LOAD_STORE,
698                            wgt::StorageTextureAccess::Atomic => {
699                                naga::StorageAccess::ATOMIC | LOAD_STORE
700                            }
701                        };
702                        match shader_class {
703                            // Formats must match exactly. A write-only shader (but not a
704                            // read-only shader) is compatible with a read-write binding.
705                            naga::ImageClass::Storage {
706                                format: shader_format,
707                                access: shader_access,
708                            } if shader_format == binding_format
709                                && (shader_access == binding_access
710                                    || shader_access == naga::StorageAccess::STORE
711                                        && binding_access == LOAD_STORE) =>
712                            {
713                                Ok(())
714                            }
715                            _ => Err(naga::ImageClass::Storage {
716                                format: binding_format,
717                                access: binding_access,
718                            }),
719                        }
720                    }
721                    BindingType::ExternalTexture => {
722                        let binding_class = naga::ImageClass::External;
723                        if shader_class == binding_class {
724                            Ok(())
725                        } else {
726                            Err(binding_class)
727                        }
728                    }
729                    _ => {
730                        return Err(BindingError::WrongType {
731                            binding: (&entry.ty).into(),
732                            shader: (&self.ty).into(),
733                        })
734                    }
735                }
736                .map_err(|binding_class| BindingError::WrongTextureClass {
737                    binding: binding_class,
738                    shader: shader_class,
739                })?;
740            }
741            ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
742                BindingType::AccelerationStructure {
743                    vertex_return: entry_vertex_return,
744                } if vertex_return == entry_vertex_return => (),
745                _ => {
746                    return Err(BindingError::WrongType {
747                        binding: (&entry.ty).into(),
748                        shader: (&self.ty).into(),
749                    })
750                }
751            },
752        };
753
754        Ok(())
755    }
756
757    fn derive_binding_type(
758        &self,
759        is_reffed_by_sampler_in_entrypoint: bool,
760    ) -> Result<BindingType, BindingError> {
761        Ok(match self.ty {
762            ResourceType::Buffer { size } => BindingType::Buffer {
763                ty: match self.class {
764                    naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
765                    naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
766                        read_only: access == naga::StorageAccess::LOAD,
767                    },
768                    _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
769                },
770                has_dynamic_offset: false,
771                min_binding_size: Some(size),
772            },
773            ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
774                wgt::SamplerBindingType::Comparison
775            } else {
776                wgt::SamplerBindingType::Filtering
777            }),
778            ResourceType::Texture {
779                dim,
780                arrayed,
781                class,
782            } => {
783                let view_dimension = match dim {
784                    naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
785                    naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
786                    naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
787                    naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
788                    naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
789                    naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
790                };
791                match class {
792                    naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
793                        sample_type: match kind {
794                            naga::ScalarKind::Float => wgt::TextureSampleType::Float {
795                                filterable: is_reffed_by_sampler_in_entrypoint,
796                            },
797                            naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
798                            naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
799                            naga::ScalarKind::AbstractInt
800                            | naga::ScalarKind::AbstractFloat
801                            | naga::ScalarKind::Bool => unreachable!(),
802                        },
803                        view_dimension,
804                        multisampled: multi,
805                    },
806                    naga::ImageClass::Depth { multi } => BindingType::Texture {
807                        sample_type: wgt::TextureSampleType::Depth,
808                        view_dimension,
809                        multisampled: multi,
810                    },
811                    naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
812                        access: {
813                            const LOAD_STORE: naga::StorageAccess =
814                                naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
815                            match access {
816                                naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
817                                naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
818                                LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
819                                _ if access.contains(naga::StorageAccess::ATOMIC) => {
820                                    wgt::StorageTextureAccess::Atomic
821                                }
822                                _ => unreachable!(),
823                            }
824                        },
825                        view_dimension,
826                        format: {
827                            let f = map_storage_format_from_naga(format);
828                            let original = map_storage_format_to_naga(f)
829                                .ok_or(BindingError::BadStorageFormat(f))?;
830                            debug_assert_eq!(format, original);
831                            f
832                        },
833                    },
834                    naga::ImageClass::External => BindingType::ExternalTexture,
835                }
836            }
837            ResourceType::AccelerationStructure { vertex_return } => {
838                BindingType::AccelerationStructure { vertex_return }
839            }
840        })
841    }
842}
843
844impl NumericType {
845    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
846        use naga::{Scalar, VectorSize as Vs};
847        use wgt::VertexFormat as Vf;
848
849        let (dim, scalar) = match format {
850            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
851            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
852                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
853            }
854            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
855            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
856                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
857            }
858            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
859            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
860                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
861            }
862            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
863            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
864                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
865            }
866            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
867                (NumericDimension::Scalar, Scalar::F32)
868            }
869            Vf::Unorm8x2
870            | Vf::Snorm8x2
871            | Vf::Unorm16x2
872            | Vf::Snorm16x2
873            | Vf::Float16x2
874            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
875            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
876            Vf::Unorm8x4
877            | Vf::Snorm8x4
878            | Vf::Unorm16x4
879            | Vf::Snorm16x4
880            | Vf::Float16x4
881            | Vf::Float32x4
882            | Vf::Unorm10_10_10_2
883            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
884            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
885            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
886            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
887            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
888        };
889
890        NumericType {
891            dim,
892            //Note: Shader always sees data as int, uint, or float.
893            // It doesn't know if the original is normalized in a tighter form.
894            scalar,
895        }
896    }
897
898    fn from_texture_format(format: wgt::TextureFormat) -> Self {
899        use naga::{Scalar, VectorSize as Vs};
900        use wgt::TextureFormat as Tf;
901
902        let (dim, scalar) = match format {
903            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
904                (NumericDimension::Scalar, Scalar::F32)
905            }
906            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
907            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
908            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
909                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
910            }
911            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
912            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
913                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
914            }
915            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
916                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
917            }
918            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
919            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
920            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
921            Tf::Rgba8Unorm
922            | Tf::Rgba8UnormSrgb
923            | Tf::Rgba8Snorm
924            | Tf::Bgra8Unorm
925            | Tf::Bgra8UnormSrgb
926            | Tf::Rgb10a2Unorm
927            | Tf::Rgba16Float
928            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
929            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
930                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
931            }
932            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
933                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
934            }
935            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
936            Tf::Stencil8
937            | Tf::Depth16Unorm
938            | Tf::Depth32Float
939            | Tf::Depth32FloatStencil8
940            | Tf::Depth24Plus
941            | Tf::Depth24PlusStencil8 => {
942                panic!("Unexpected depth format")
943            }
944            Tf::NV12 => panic!("Unexpected nv12 format"),
945            Tf::P010 => panic!("Unexpected p010 format"),
946            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
947            Tf::Bc1RgbaUnorm
948            | Tf::Bc1RgbaUnormSrgb
949            | Tf::Bc2RgbaUnorm
950            | Tf::Bc2RgbaUnormSrgb
951            | Tf::Bc3RgbaUnorm
952            | Tf::Bc3RgbaUnormSrgb
953            | Tf::Bc7RgbaUnorm
954            | Tf::Bc7RgbaUnormSrgb
955            | Tf::Etc2Rgb8A1Unorm
956            | Tf::Etc2Rgb8A1UnormSrgb
957            | Tf::Etc2Rgba8Unorm
958            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
959            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
960                (NumericDimension::Scalar, Scalar::F32)
961            }
962            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
963                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
964            }
965            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
966                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
967            }
968            Tf::Astc {
969                block: _,
970                channel: _,
971            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
972        };
973
974        NumericType {
975            dim,
976            //Note: Shader always sees data as int, uint, or float.
977            // It doesn't know if the original is normalized in a tighter form.
978            scalar,
979        }
980    }
981
982    fn is_subtype_of(&self, other: &NumericType) -> bool {
983        if self.scalar.width > other.scalar.width {
984            return false;
985        }
986        if self.scalar.kind != other.scalar.kind {
987            return false;
988        }
989        match (self.dim, other.dim) {
990            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
991            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
992            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
993            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
994                c0 == c1 && r0 == r1
995            }
996            _ => false,
997        }
998    }
999}
1000
1001/// Return true if the fragment `format` is covered by the provided `output`.
1002pub fn check_texture_format(
1003    format: wgt::TextureFormat,
1004    output: &NumericType,
1005) -> Result<(), NumericType> {
1006    let nt = NumericType::from_texture_format(format);
1007    if nt.is_subtype_of(output) {
1008        Ok(())
1009    } else {
1010        Err(nt)
1011    }
1012}
1013
1014pub enum BindingLayoutSource {
1015    /// The binding layout is derived from the pipeline layout.
1016    ///
1017    /// This will be filled in by the shader binding validation, as it iterates the shader's interfaces.
1018    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
1019    /// The binding layout is provided by the user in BGLs.
1020    ///
1021    /// This will be validated against the shader's interfaces.
1022    Provided(Arc<crate::binding_model::PipelineLayout>),
1023}
1024
1025impl BindingLayoutSource {
1026    pub fn new_derived(limits: &wgt::Limits) -> Self {
1027        let mut array = ArrayVec::new();
1028        for _ in 0..limits.max_bind_groups {
1029            array.push(Default::default());
1030        }
1031        BindingLayoutSource::Derived(Box::new(array))
1032    }
1033}
1034
1035#[derive(Debug, Clone, Default)]
1036pub struct StageIo {
1037    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
1038    /// This must match between mesh & task shaders
1039    pub task_payload_size: Option<u32>,
1040    /// Fragment shaders cannot input primitive index on mesh shaders that don't output it on DX12.
1041    /// Therefore, we track between shader stages if primitive index is written (or if vertex shader
1042    /// is used).
1043    ///
1044    /// This is Some if it was a mesh shader.
1045    pub primitive_index: Option<bool>,
1046}
1047
1048impl Interface {
1049    fn populate(
1050        list: &mut Vec<Varying>,
1051        binding: Option<&naga::Binding>,
1052        ty: naga::Handle<naga::Type>,
1053        arena: &naga::UniqueArena<naga::Type>,
1054    ) {
1055        let numeric_ty = match arena[ty].inner {
1056            naga::TypeInner::Scalar(scalar) => NumericType {
1057                dim: NumericDimension::Scalar,
1058                scalar,
1059            },
1060            naga::TypeInner::Vector { size, scalar } => NumericType {
1061                dim: NumericDimension::Vector(size),
1062                scalar,
1063            },
1064            naga::TypeInner::Matrix {
1065                columns,
1066                rows,
1067                scalar,
1068            } => NumericType {
1069                dim: NumericDimension::Matrix(columns, rows),
1070                scalar,
1071            },
1072            naga::TypeInner::Struct { ref members, .. } => {
1073                for member in members {
1074                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
1075                }
1076                return;
1077            }
1078            ref other => {
1079                //Note: technically this should be at least `log::error`, but
1080                // the reality is - every shader coming from `glslc` outputs an array
1081                // of clip distances and hits this path :(
1082                // So we lower it to `log::debug` to be less annoying as
1083                // there's nothing the user can do about it.
1084                log::debug!("Unexpected varying type: {other:?}");
1085                return;
1086            }
1087        };
1088
1089        let varying = match binding {
1090            Some(&naga::Binding::Location {
1091                location,
1092                interpolation,
1093                sampling,
1094                per_primitive,
1095                blend_src: _,
1096            }) => Varying::Local {
1097                location,
1098                iv: InterfaceVar {
1099                    ty: numeric_ty,
1100                    interpolation,
1101                    sampling,
1102                    per_primitive,
1103                },
1104            },
1105            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(built_in),
1106            None => {
1107                log::error!("Missing binding for a varying");
1108                return;
1109            }
1110        };
1111        list.push(varying);
1112    }
1113
1114    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
1115        let mut resources = naga::Arena::new();
1116        let mut resource_mapping = FastHashMap::default();
1117        for (var_handle, var) in module.global_variables.iter() {
1118            let bind = match var.binding {
1119                Some(br) => br,
1120                _ => continue,
1121            };
1122            let naga_ty = &module.types[var.ty].inner;
1123
1124            let inner_ty = match *naga_ty {
1125                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
1126                ref ty => ty,
1127            };
1128
1129            let ty = match *inner_ty {
1130                naga::TypeInner::Image {
1131                    dim,
1132                    arrayed,
1133                    class,
1134                } => ResourceType::Texture {
1135                    dim,
1136                    arrayed,
1137                    class,
1138                },
1139                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
1140                naga::TypeInner::AccelerationStructure { vertex_return } => {
1141                    ResourceType::AccelerationStructure { vertex_return }
1142                }
1143                ref other => ResourceType::Buffer {
1144                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
1145                },
1146            };
1147            let handle = resources.append(
1148                Resource {
1149                    name: var.name.clone(),
1150                    bind,
1151                    ty,
1152                    class: var.space,
1153                },
1154                Default::default(),
1155            );
1156            resource_mapping.insert(var_handle, handle);
1157        }
1158
1159        let mut entry_points = FastHashMap::default();
1160        entry_points.reserve(module.entry_points.len());
1161        for (index, entry_point) in module.entry_points.iter().enumerate() {
1162            let info = info.get_entry_point(index);
1163            let mut ep = EntryPoint::default();
1164            for arg in entry_point.function.arguments.iter() {
1165                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
1166            }
1167            if let Some(ref result) = entry_point.function.result {
1168                Self::populate(
1169                    &mut ep.outputs,
1170                    result.binding.as_ref(),
1171                    result.ty,
1172                    &module.types,
1173                );
1174            }
1175
1176            for (var_handle, var) in module.global_variables.iter() {
1177                let usage = info[var_handle];
1178                if !usage.is_empty() && var.binding.is_some() {
1179                    ep.resources.push(resource_mapping[&var_handle]);
1180                }
1181            }
1182
1183            for key in info.sampling_set.iter() {
1184                ep.sampling_pairs
1185                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
1186            }
1187            ep.dual_source_blending = info.dual_source_blending;
1188            ep.workgroup_size = entry_point.workgroup_size;
1189
1190            if let Some(task_payload) = entry_point.task_payload {
1191                ep.task_payload_size = Some(
1192                    module.types[module.global_variables[task_payload].ty]
1193                        .inner
1194                        .size(module.to_ctx()),
1195                );
1196            }
1197            if let Some(ref mesh_info) = entry_point.mesh_info {
1198                ep.mesh_info = Some(EntryPointMeshInfo {
1199                    max_vertices: mesh_info.max_vertices,
1200                    max_primitives: mesh_info.max_primitives,
1201                });
1202                Self::populate(
1203                    &mut ep.outputs,
1204                    None,
1205                    mesh_info.vertex_output_type,
1206                    &module.types,
1207                );
1208                Self::populate(
1209                    &mut ep.outputs,
1210                    None,
1211                    mesh_info.primitive_output_type,
1212                    &module.types,
1213                );
1214            }
1215
1216            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
1217        }
1218
1219        Self {
1220            limits,
1221            resources,
1222            entry_points,
1223        }
1224    }
1225
1226    pub fn finalize_entry_point_name(
1227        &self,
1228        stage: naga::ShaderStage,
1229        entry_point_name: Option<&str>,
1230    ) -> Result<String, StageError> {
1231        entry_point_name
1232            .map(|ep| ep.to_string())
1233            .map(Ok)
1234            .unwrap_or_else(|| {
1235                let mut entry_points = self
1236                    .entry_points
1237                    .keys()
1238                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
1239                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1240                if entry_points.next().is_some() {
1241                    return Err(StageError::MultipleEntryPointsFound);
1242                }
1243                Ok(first.clone())
1244            })
1245    }
1246
1247    /// Among other things, this implements some validation logic defined by the WebGPU spec. at
1248    /// <https://www.w3.org/TR/webgpu/#abstract-opdef-validating-inter-stage-interfaces>.
1249    pub fn check_stage(
1250        &self,
1251        layouts: &mut BindingLayoutSource,
1252        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
1253        entry_point_name: &str,
1254        shader_stage: ShaderStageForValidation,
1255        inputs: StageIo,
1256    ) -> Result<StageIo, StageError> {
1257        // Since a shader module can have multiple entry points with the same name,
1258        // we need to look for one with the right execution model.
1259        let pair = (shader_stage.to_naga(), entry_point_name.to_string());
1260        let entry_point = match self.entry_points.get(&pair) {
1261            Some(some) => some,
1262            None => return Err(StageError::MissingEntryPoint(pair.1)),
1263        };
1264        let (_, entry_point_name) = pair;
1265
1266        let stage_bit = shader_stage.to_wgt_bit();
1267
1268        // check resources visibility
1269        for &handle in entry_point.resources.iter() {
1270            let res = &self.resources[handle];
1271            let result = 'err: {
1272                match layouts {
1273                    BindingLayoutSource::Provided(pipeline_layout) => {
1274                        // update the required binding size for this buffer
1275                        if let ResourceType::Buffer { size } = res.ty {
1276                            match shader_binding_sizes.entry(res.bind) {
1277                                Entry::Occupied(e) => {
1278                                    *e.into_mut() = size.max(*e.get());
1279                                }
1280                                Entry::Vacant(e) => {
1281                                    e.insert(size);
1282                                }
1283                            }
1284                        }
1285
1286                        let Some(entry) =
1287                            pipeline_layout.get_bgl_entry(res.bind.group, res.bind.binding)
1288                        else {
1289                            break 'err Err(BindingError::Missing);
1290                        };
1291
1292                        if !entry.visibility.contains(stage_bit) {
1293                            break 'err Err(BindingError::Invisible);
1294                        }
1295
1296                        res.check_binding_use(entry)
1297                    }
1298                    BindingLayoutSource::Derived(layouts) => {
1299                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
1300                            break 'err Err(BindingError::Missing);
1301                        };
1302
1303                        let ty = match res.derive_binding_type(
1304                            entry_point
1305                                .sampling_pairs
1306                                .iter()
1307                                .any(|&(im, _samp)| im == handle),
1308                        ) {
1309                            Ok(ty) => ty,
1310                            Err(error) => break 'err Err(error),
1311                        };
1312
1313                        match map.entry(res.bind.binding) {
1314                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
1315                                break 'err Err(BindingError::InconsistentlyDerivedType)
1316                            }
1317                            indexmap::map::Entry::Occupied(e) => {
1318                                e.into_mut().visibility |= stage_bit;
1319                            }
1320                            indexmap::map::Entry::Vacant(e) => {
1321                                e.insert(BindGroupLayoutEntry {
1322                                    binding: res.bind.binding,
1323                                    ty,
1324                                    visibility: stage_bit,
1325                                    count: None,
1326                                });
1327                            }
1328                        }
1329                        Ok(())
1330                    }
1331                }
1332            };
1333            if let Err(error) = result {
1334                return Err(StageError::Binding(res.bind, error));
1335            }
1336        }
1337
1338        // Check the compatibility between textures and samplers
1339        //
1340        // We only need to do this if the binding layout is provided by the user, as derived
1341        // layouts will inherently be correctly tagged.
1342        if let BindingLayoutSource::Provided(pipeline_layout) = layouts {
1343            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
1344                let texture_bind = &self.resources[texture_handle].bind;
1345                let sampler_bind = &self.resources[sampler_handle].bind;
1346                let texture_layout = pipeline_layout
1347                    .get_bgl_entry(texture_bind.group, texture_bind.binding)
1348                    .unwrap();
1349                let sampler_layout = pipeline_layout
1350                    .get_bgl_entry(sampler_bind.group, sampler_bind.binding)
1351                    .unwrap();
1352                assert!(texture_layout.visibility.contains(stage_bit));
1353                assert!(sampler_layout.visibility.contains(stage_bit));
1354
1355                let sampler_filtering = matches!(
1356                    sampler_layout.ty,
1357                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
1358                );
1359                let texture_sample_type = match texture_layout.ty {
1360                    BindingType::Texture { sample_type, .. } => sample_type,
1361                    BindingType::ExternalTexture => {
1362                        wgt::TextureSampleType::Float { filterable: true }
1363                    }
1364                    _ => unreachable!(),
1365                };
1366
1367                let error = match (sampler_filtering, texture_sample_type) {
1368                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
1369                        Some(FilteringError::Float)
1370                    }
1371                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
1372                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
1373                    _ => None,
1374                };
1375
1376                if let Some(error) = error {
1377                    return Err(StageError::Filtering {
1378                        texture: *texture_bind,
1379                        sampler: *sampler_bind,
1380                        error,
1381                    });
1382                }
1383            }
1384        }
1385
1386        // check workgroup size limits
1387        if shader_stage.to_naga().compute_like() {
1388            let (
1389                max_workgroup_size_limits,
1390                max_workgroup_size_total,
1391                per_dimension_limit,
1392                total_limit,
1393            ) = match shader_stage.to_naga() {
1394                naga::ShaderStage::Compute => (
1395                    [
1396                        self.limits.max_compute_workgroup_size_x,
1397                        self.limits.max_compute_workgroup_size_y,
1398                        self.limits.max_compute_workgroup_size_z,
1399                    ],
1400                    self.limits.max_compute_invocations_per_workgroup,
1401                    "max_compute_workgroup_size_*",
1402                    "max_compute_invocations_per_workgroup",
1403                ),
1404                naga::ShaderStage::Task => (
1405                    [
1406                        self.limits.max_task_invocations_per_dimension,
1407                        self.limits.max_task_invocations_per_dimension,
1408                        self.limits.max_task_invocations_per_dimension,
1409                    ],
1410                    self.limits.max_task_invocations_per_workgroup,
1411                    "max_task_invocations_per_dimension",
1412                    "max_task_invocations_per_workgroup",
1413                ),
1414                naga::ShaderStage::Mesh => (
1415                    [
1416                        self.limits.max_mesh_invocations_per_dimension,
1417                        self.limits.max_mesh_invocations_per_dimension,
1418                        self.limits.max_mesh_invocations_per_dimension,
1419                    ],
1420                    self.limits.max_mesh_invocations_per_workgroup,
1421                    "max_mesh_invocations_per_dimension",
1422                    "max_mesh_invocations_per_workgroup",
1423                ),
1424                _ => unreachable!(),
1425            };
1426            let total_invocations = entry_point.workgroup_size.iter().product::<u32>();
1427
1428            let workgroup_size_is_zero = entry_point.workgroup_size.contains(&0);
1429            let too_many_invocations = total_invocations > max_workgroup_size_total;
1430            let dimension_too_large = entry_point.workgroup_size[0] > max_workgroup_size_limits[0]
1431                || entry_point.workgroup_size[1] > max_workgroup_size_limits[1]
1432                || entry_point.workgroup_size[2] > max_workgroup_size_limits[2];
1433            if workgroup_size_is_zero || too_many_invocations || dimension_too_large {
1434                return Err(StageError::InvalidWorkgroupSize {
1435                    current: entry_point.workgroup_size,
1436                    current_total: total_invocations,
1437                    limit: max_workgroup_size_limits,
1438                    total: max_workgroup_size_total,
1439                    per_dimension_limit,
1440                    total_limit,
1441                });
1442            }
1443        }
1444
1445        let mut this_stage_primitive_index = false;
1446        let mut has_draw_id = false;
1447
1448        // check inputs compatibility
1449        for input in entry_point.inputs.iter() {
1450            match *input {
1451                Varying::Local { location, ref iv } => {
1452                    let result = inputs
1453                        .varyings
1454                        .get(&location)
1455                        .ok_or(InputError::Missing)
1456                        .and_then(|provided| {
1457                            let (compatible, per_primitive_correct) = match shader_stage.to_naga() {
1458                                // For vertex attributes, there are defaults filled out
1459                                // by the driver if data is not provided.
1460                                naga::ShaderStage::Vertex => {
1461                                    let is_compatible =
1462                                        iv.ty.scalar.kind == provided.ty.scalar.kind;
1463                                    // vertex inputs don't count towards inter-stage
1464                                    (is_compatible, !iv.per_primitive)
1465                                }
1466                                naga::ShaderStage::Fragment => {
1467                                    if iv.interpolation != provided.interpolation {
1468                                        return Err(InputError::InterpolationMismatch(
1469                                            provided.interpolation,
1470                                        ));
1471                                    }
1472                                    if iv.sampling != provided.sampling {
1473                                        return Err(InputError::SamplingMismatch(
1474                                            provided.sampling,
1475                                        ));
1476                                    }
1477                                    (
1478                                        iv.ty.is_subtype_of(&provided.ty),
1479                                        iv.per_primitive == provided.per_primitive,
1480                                    )
1481                                }
1482                                // These can't have varying inputs
1483                                naga::ShaderStage::Compute
1484                                | naga::ShaderStage::Task
1485                                | naga::ShaderStage::Mesh => (false, false),
1486                                naga::ShaderStage::RayGeneration
1487                                | naga::ShaderStage::AnyHit
1488                                | naga::ShaderStage::ClosestHit
1489                                | naga::ShaderStage::Miss => {
1490                                    unreachable!()
1491                                }
1492                            };
1493                            if !compatible {
1494                                return Err(InputError::WrongType(provided.ty));
1495                            } else if !per_primitive_correct {
1496                                return Err(InputError::WrongPerPrimitive {
1497                                    pipeline_input: provided.per_primitive,
1498                                    shader: iv.per_primitive,
1499                                });
1500                            }
1501                            Ok(())
1502                        });
1503
1504                    if let Err(error) = result {
1505                        return Err(StageError::Input {
1506                            location,
1507                            var: iv.clone(),
1508                            error,
1509                        });
1510                    }
1511                }
1512                Varying::BuiltIn(naga::BuiltIn::PrimitiveIndex) => {
1513                    this_stage_primitive_index = true;
1514                }
1515                Varying::BuiltIn(naga::BuiltIn::DrawIndex) => {
1516                    has_draw_id = true;
1517                }
1518                Varying::BuiltIn(_) => {}
1519            }
1520        }
1521
1522        match shader_stage {
1523            ShaderStageForValidation::Vertex {
1524                topology,
1525                compare_function,
1526            } => {
1527                let mut max_vertex_shader_output_variables =
1528                    self.limits.max_inter_stage_shader_variables;
1529                let mut max_vertex_shader_output_location = max_vertex_shader_output_variables - 1;
1530
1531                let point_list_deduction = if topology == wgt::PrimitiveTopology::PointList {
1532                    Some(MaxVertexShaderOutputDeduction::PointListPrimitiveTopology)
1533                } else {
1534                    None
1535                };
1536
1537                let deductions = point_list_deduction.into_iter();
1538
1539                for deduction in deductions.clone() {
1540                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1541                    // ever exceed the minimum variables available.
1542                    max_vertex_shader_output_variables = max_vertex_shader_output_variables
1543                        .checked_sub(deduction.for_variables())
1544                        .unwrap();
1545                    max_vertex_shader_output_location = max_vertex_shader_output_location
1546                        .checked_sub(deduction.for_location())
1547                        .unwrap();
1548                }
1549
1550                let mut num_user_defined_outputs = 0;
1551
1552                for output in entry_point.outputs.iter() {
1553                    match *output {
1554                        Varying::Local { ref iv, location } => {
1555                            if location > max_vertex_shader_output_location {
1556                                return Err(StageError::VertexOutputLocationTooLarge {
1557                                    location,
1558                                    var: iv.clone(),
1559                                    limit: self.limits.max_inter_stage_shader_variables,
1560                                    deductions: deductions.collect(),
1561                                });
1562                            }
1563                            num_user_defined_outputs += 1;
1564                        }
1565                        Varying::BuiltIn(_) => {}
1566                    };
1567
1568                    if let Some(
1569                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
1570                    ) = compare_function
1571                    {
1572                        if let Varying::BuiltIn(naga::BuiltIn::Position { invariant: false }) =
1573                            *output
1574                        {
1575                            log::warn!(
1576                                concat!(
1577                                    "Vertex shader with entry point {} outputs a ",
1578                                    "@builtin(position) without the @invariant attribute and ",
1579                                    "is used in a pipeline with {cmp:?}. On some machines, ",
1580                                    "this can cause bad artifacting as {cmp:?} assumes the ",
1581                                    "values output from the vertex shader exactly match the ",
1582                                    "value in the depth buffer. The @invariant attribute on the ",
1583                                    "@builtin(position) vertex output ensures that the exact ",
1584                                    "same pixel depths are used every render."
1585                                ),
1586                                entry_point_name,
1587                                cmp = cmp
1588                            );
1589                        }
1590                    }
1591                }
1592
1593                if num_user_defined_outputs > max_vertex_shader_output_variables {
1594                    return Err(StageError::TooManyUserDefinedVertexOutputs {
1595                        num_found: num_user_defined_outputs,
1596                        limit: self.limits.max_inter_stage_shader_variables,
1597                        deductions: deductions.collect(),
1598                    });
1599                }
1600            }
1601            ShaderStageForValidation::Fragment {
1602                dual_source_blending,
1603                has_depth_attachment,
1604            } => {
1605                let mut max_fragment_shader_input_variables =
1606                    self.limits.max_inter_stage_shader_variables;
1607
1608                let deductions = entry_point.inputs.iter().filter_map(|output| match output {
1609                    Varying::Local { .. } => None,
1610                    Varying::BuiltIn(builtin) => {
1611                        MaxFragmentShaderInputDeduction::from_inter_stage_builtin(*builtin).or_else(
1612                            || {
1613                                unreachable!(
1614                                    concat!(
1615                                        "unexpected built-in provided; ",
1616                                        "{:?} is not used for fragment stage input",
1617                                    ),
1618                                    builtin
1619                                )
1620                            },
1621                        )
1622                    }
1623                });
1624
1625                for deduction in deductions.clone() {
1626                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1627                    // ever exceed the minimum variables available.
1628                    max_fragment_shader_input_variables = max_fragment_shader_input_variables
1629                        .checked_sub(deduction.for_variables())
1630                        .unwrap();
1631                }
1632
1633                let mut num_user_defined_inputs = 0;
1634
1635                for output in entry_point.inputs.iter() {
1636                    match *output {
1637                        Varying::Local { ref iv, location } => {
1638                            if location >= self.limits.max_inter_stage_shader_variables {
1639                                return Err(StageError::FragmentInputLocationTooLarge {
1640                                    location,
1641                                    var: iv.clone(),
1642                                    limit: self.limits.max_inter_stage_shader_variables,
1643                                    deductions: deductions.collect(),
1644                                });
1645                            }
1646                            num_user_defined_inputs += 1;
1647                        }
1648                        Varying::BuiltIn(_) => {}
1649                    };
1650                }
1651
1652                if num_user_defined_inputs > max_fragment_shader_input_variables {
1653                    return Err(StageError::TooManyUserDefinedFragmentInputs {
1654                        num_found: num_user_defined_inputs,
1655                        limit: self.limits.max_inter_stage_shader_variables,
1656                        deductions: deductions.collect(),
1657                    });
1658                }
1659
1660                for output in &entry_point.outputs {
1661                    let &Varying::Local { location, ref iv } = output else {
1662                        continue;
1663                    };
1664                    if location >= self.limits.max_color_attachments {
1665                        return Err(StageError::ColorAttachmentLocationTooLarge {
1666                            location,
1667                            var: iv.clone(),
1668                            limit: self.limits.max_color_attachments,
1669                        });
1670                    }
1671                }
1672
1673                // If the pipeline uses dual-source blending, then the shader
1674                // must configure appropriate I/O, but it is not an error to
1675                // use a shader that defines the I/O in a pipeline that only
1676                // uses one blend source.
1677                if dual_source_blending && !entry_point.dual_source_blending {
1678                    return Err(StageError::InvalidDualSourceBlending);
1679                }
1680
1681                if entry_point
1682                    .outputs
1683                    .contains(&Varying::BuiltIn(naga::BuiltIn::FragDepth))
1684                    && !has_depth_attachment
1685                {
1686                    return Err(StageError::MissingFragDepthAttachment);
1687                }
1688            }
1689            ShaderStageForValidation::Mesh => {
1690                for output in &entry_point.outputs {
1691                    if matches!(output, Varying::BuiltIn(naga::BuiltIn::PrimitiveIndex)) {
1692                        this_stage_primitive_index = true;
1693                    }
1694                }
1695            }
1696            _ => (),
1697        }
1698
1699        if let Some(ref mesh_info) = entry_point.mesh_info {
1700            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
1701                return Err(StageError::TooManyMeshVertices {
1702                    limit: self.limits.max_mesh_output_vertices,
1703                    value: mesh_info.max_vertices,
1704                });
1705            }
1706            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
1707                return Err(StageError::TooManyMeshPrimitives {
1708                    limit: self.limits.max_mesh_output_primitives,
1709                    value: mesh_info.max_primitives,
1710                });
1711            }
1712        }
1713        if let Some(task_payload_size) = entry_point.task_payload_size {
1714            if task_payload_size > self.limits.max_task_payload_size {
1715                return Err(StageError::TaskPayloadTooLarge {
1716                    limit: self.limits.max_task_payload_size,
1717                    value: task_payload_size,
1718                });
1719            }
1720        }
1721        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1722            && entry_point.task_payload_size != inputs.task_payload_size
1723        {
1724            return Err(StageError::TaskPayloadMustMatch {
1725                input: inputs.task_payload_size,
1726                shader: entry_point.task_payload_size,
1727            });
1728        }
1729
1730        // Fragment shader primitive index is treated like a varying
1731        if shader_stage.to_naga() == naga::ShaderStage::Fragment
1732            && this_stage_primitive_index
1733            && inputs.primitive_index == Some(false)
1734        {
1735            return Err(StageError::InvalidPrimitiveIndex);
1736        } else if shader_stage.to_naga() == naga::ShaderStage::Fragment
1737            && !this_stage_primitive_index
1738            && inputs.primitive_index == Some(true)
1739        {
1740            return Err(StageError::MissingPrimitiveIndex);
1741        }
1742        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1743            && inputs.task_payload_size.is_some()
1744            && has_draw_id
1745        {
1746            return Err(StageError::DrawIdError);
1747        }
1748
1749        let outputs = entry_point
1750            .outputs
1751            .iter()
1752            .filter_map(|output| match *output {
1753                Varying::Local { location, ref iv } => Some((location, iv.clone())),
1754                Varying::BuiltIn(_) => None,
1755            })
1756            .collect();
1757
1758        Ok(StageIo {
1759            task_payload_size: entry_point.task_payload_size,
1760            varyings: outputs,
1761            primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
1762                Some(this_stage_primitive_index)
1763            } else {
1764                None
1765            },
1766        })
1767    }
1768
1769    pub fn fragment_uses_dual_source_blending(
1770        &self,
1771        entry_point_name: &str,
1772    ) -> Result<bool, StageError> {
1773        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
1774        self.entry_points
1775            .get(&pair)
1776            .ok_or(StageError::MissingEntryPoint(pair.1))
1777            .map(|ep| ep.dual_source_blending)
1778    }
1779}
1780
1781/// Validate a list of color attachment formats against `maxColorAttachmentBytesPerSample`.
1782///
1783/// The color attachments can be from a render pass descriptor or a pipeline descriptor.
1784///
1785/// Implements <https://gpuweb.github.io/gpuweb/#abstract-opdef-calculating-color-attachment-bytes-per-sample>.
1786pub fn validate_color_attachment_bytes_per_sample(
1787    attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1788    limit: u32,
1789) -> Result<(), crate::command::ColorAttachmentError> {
1790    let mut total_bytes_per_sample: u32 = 0;
1791    for format in attachment_formats {
1792        let byte_cost = format.target_pixel_byte_cost().unwrap();
1793        let alignment = format.target_component_alignment().unwrap();
1794
1795        total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1796        total_bytes_per_sample += byte_cost;
1797    }
1798
1799    if total_bytes_per_sample > limit {
1800        return Err(
1801            crate::command::ColorAttachmentError::TooManyBytesPerSample {
1802                total: total_bytes_per_sample,
1803                limit,
1804            },
1805        );
1806    }
1807
1808    Ok(())
1809}
1810
1811pub enum ShaderStageForValidation {
1812    Vertex {
1813        topology: wgt::PrimitiveTopology,
1814        compare_function: Option<wgt::CompareFunction>,
1815    },
1816    Mesh,
1817    Fragment {
1818        dual_source_blending: bool,
1819        has_depth_attachment: bool,
1820    },
1821    Compute,
1822    Task,
1823}
1824
1825impl ShaderStageForValidation {
1826    pub fn to_naga(&self) -> naga::ShaderStage {
1827        match self {
1828            Self::Vertex { .. } => naga::ShaderStage::Vertex,
1829            Self::Mesh => naga::ShaderStage::Mesh,
1830            Self::Fragment { .. } => naga::ShaderStage::Fragment,
1831            Self::Compute => naga::ShaderStage::Compute,
1832            Self::Task => naga::ShaderStage::Task,
1833        }
1834    }
1835
1836    pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
1837        match self {
1838            Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
1839            Self::Mesh => wgt::ShaderStages::MESH,
1840            Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
1841            Self::Compute => wgt::ShaderStages::COMPUTE,
1842            Self::Task => wgt::ShaderStages::TASK,
1843        }
1844    }
1845}