wgpu_core/validation/
shader_io_deductions.rs

1use core::fmt::{self, Debug, Display, Formatter};
2
3#[cfg(doc)]
4#[expect(unused_imports)]
5use crate::validation::StageError;
6
/// Deductions applied both to the maximum number of vertex shader output
/// variables and to the largest permitted vertex shader output location.
///
/// Referenced by [`StageError::TooManyUserDefinedVertexOutputs`] and
/// [`StageError::VertexOutputLocationTooLarge`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaxVertexShaderOutputDeduction {
    /// Applies when the topology of a pipeline's
    /// [`crate::pipeline::RenderPipelineDescriptor::primitive`] is
    /// [`wgt::PrimitiveTopology::PointList`].
    PointListPrimitiveTopology,
    /// Applies when an output uses a clip distances array with `array_size`
    /// elements.
    ClipDistances { array_size: u32 },
}

impl MaxVertexShaderOutputDeduction {
    /// How many output *variables* this deduction removes from the budget.
    ///
    /// A point-list topology costs exactly one variable; clip distances cost
    /// one variable per (possibly partial) group of four array elements.
    pub fn for_variables(self) -> u32 {
        match self {
            Self::ClipDistances { array_size: slots } => {
                Self::variables_from_clip_distance_slot(slots)
            }
            Self::PointListPrimitiveTopology => 1,
        }
    }

    /// How much this deduction lowers the largest usable output *location*.
    ///
    /// A point-list topology leaves locations untouched; clip distances lower
    /// the limit by the same amount as [`Self::for_variables`].
    pub fn for_location(self) -> u32 {
        match self {
            Self::ClipDistances { array_size: slots } => {
                Self::variables_from_clip_distance_slot(slots)
            }
            Self::PointListPrimitiveTopology => 0,
        }
    }

    /// Groups `num_slots` clip distance values four at a time, rounding up
    /// (i.e. `ceil(num_slots / 4)`).
    ///
    /// Written as quotient plus "has remainder" so it cannot overflow, even
    /// for `u32::MAX`.
    fn variables_from_clip_distance_slot(num_slots: u32) -> u32 {
        let (whole_groups, leftover) = (num_slots / 4, num_slots % 4);
        whole_groups + u32::from(leftover != 0)
    }
}
44
45/// Max shader I/O variable deductions for vertex shader output. Used by
46/// [`StageError::TooManyUserDefinedFragmentInputs`] and
47/// [`StageError::FragmentInputLocationTooLarge`].
48#[derive(Clone, Copy, Debug, Eq, PartialEq)]
49pub enum MaxFragmentShaderInputDeduction {
50    InterStageBuiltIn(InterStageBuiltIn),
51}
52
53impl MaxFragmentShaderInputDeduction {
54    pub fn for_variables(self) -> u32 {
55        match self {
56            Self::InterStageBuiltIn(builtin) => match builtin {
57                InterStageBuiltIn::FrontFacing
58                | InterStageBuiltIn::SampleIndex
59                | InterStageBuiltIn::SampleMask
60                | InterStageBuiltIn::PrimitiveIndex
61                | InterStageBuiltIn::SubgroupInvocationId
62                | InterStageBuiltIn::SubgroupSize
63                | InterStageBuiltIn::ViewIndex
64                | InterStageBuiltIn::PointCoord => 1,
65                InterStageBuiltIn::Barycentric => 3,
66                InterStageBuiltIn::Position => 0,
67            },
68        }
69    }
70
71    pub fn from_inter_stage_builtin(builtin: naga::BuiltIn) -> Option<Self> {
72        use naga::BuiltIn;
73
74        Some(Self::InterStageBuiltIn(match builtin {
75            BuiltIn::Position { .. } => InterStageBuiltIn::Position,
76            BuiltIn::FrontFacing => InterStageBuiltIn::FrontFacing,
77            BuiltIn::SampleIndex => InterStageBuiltIn::SampleIndex,
78            BuiltIn::SampleMask => InterStageBuiltIn::SampleMask,
79            BuiltIn::PrimitiveIndex => InterStageBuiltIn::PrimitiveIndex,
80            BuiltIn::SubgroupSize => InterStageBuiltIn::SubgroupSize,
81            BuiltIn::SubgroupInvocationId => InterStageBuiltIn::SubgroupInvocationId,
82            BuiltIn::PointCoord => InterStageBuiltIn::PointCoord,
83            BuiltIn::Barycentric { .. } => InterStageBuiltIn::Barycentric,
84            BuiltIn::ViewIndex => InterStageBuiltIn::ViewIndex,
85            BuiltIn::BaseInstance
86            | BuiltIn::BaseVertex
87            | BuiltIn::ClipDistances
88            | BuiltIn::CullDistance
89            | BuiltIn::InstanceIndex
90            | BuiltIn::PointSize
91            | BuiltIn::VertexIndex
92            | BuiltIn::DrawIndex
93            | BuiltIn::FragDepth
94            | BuiltIn::GlobalInvocationId
95            | BuiltIn::LocalInvocationId
96            | BuiltIn::LocalInvocationIndex
97            | BuiltIn::WorkGroupId
98            | BuiltIn::WorkGroupSize
99            | BuiltIn::NumWorkGroups
100            | BuiltIn::NumSubgroups
101            | BuiltIn::SubgroupId
102            | BuiltIn::MeshTaskSize
103            | BuiltIn::CullPrimitive
104            | BuiltIn::PointIndex
105            | BuiltIn::LineIndices
106            | BuiltIn::TriangleIndices
107            | BuiltIn::VertexCount
108            | BuiltIn::Vertices
109            | BuiltIn::PrimitiveCount
110            | BuiltIn::Primitives
111            | BuiltIn::RayInvocationId
112            | BuiltIn::NumRayInvocations
113            | BuiltIn::InstanceCustomData
114            | BuiltIn::GeometryIndex
115            | BuiltIn::WorldRayOrigin
116            | BuiltIn::WorldRayDirection
117            | BuiltIn::ObjectRayOrigin
118            | BuiltIn::ObjectRayDirection
119            | BuiltIn::RayTmin
120            | BuiltIn::RayTCurrentMax
121            | BuiltIn::ObjectToWorld
122            | BuiltIn::WorldToObject
123            | BuiltIn::HitKind => return None,
124        }))
125    }
126}
127
/// A [`naga::BuiltIn`] that is tracked as a
/// [`MaxFragmentShaderInputDeduction::InterStageBuiltIn`].
///
/// Not every variant actually deducts from the budget (`Position` deducts
/// zero); see `MaxFragmentShaderInputDeduction::for_variables`.
///
/// See also <https://www.w3.org/TR/webgpu/#inter-stage-builtins>.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterStageBuiltIn {
    // --- Built-ins defined by the WebGPU specification ---
    /// Corresponds to `naga::BuiltIn::Position`.
    Position,
    /// Corresponds to `naga::BuiltIn::FrontFacing`.
    FrontFacing,
    /// Corresponds to `naga::BuiltIn::SampleIndex`.
    SampleIndex,
    /// Corresponds to `naga::BuiltIn::SampleMask`.
    SampleMask,
    /// Corresponds to `naga::BuiltIn::PrimitiveIndex`.
    PrimitiveIndex,
    /// Corresponds to `naga::BuiltIn::SubgroupInvocationId`.
    SubgroupInvocationId,
    /// Corresponds to `naga::BuiltIn::SubgroupSize`.
    SubgroupSize,

    // --- Built-ins outside the WebGPU specification ---
    /// Corresponds to `naga::BuiltIn::PointCoord`.
    PointCoord,
    /// Corresponds to `naga::BuiltIn::Barycentric`.
    Barycentric,
    /// Corresponds to `naga::BuiltIn::ViewIndex`.
    ViewIndex,
}
148
149pub(in crate::validation) fn display_deductions_as_optional_list<T>(
150    deductions: &[T],
151    accessor: fn(&T) -> u32,
152) -> impl Display + '_
153where
154    T: Debug,
155{
156    struct DisplayFromFn<F>(F);
157
158    impl<F> Display for DisplayFromFn<F>
159    where
160        F: Fn(&mut Formatter<'_>) -> fmt::Result,
161    {
162        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
163            let Self(inner) = self;
164            inner(f)
165        }
166    }
167
168    DisplayFromFn(move |f: &mut Formatter<'_>| {
169        let relevant_deductions = deductions
170            .iter()
171            .map(|deduction| (deduction, accessor(deduction)))
172            .filter(|(_, effective_deduction)| *effective_deduction > 0);
173        if relevant_deductions.clone().next().is_some() {
174            writeln!(f, "; note that some deductions apply during validation:")?;
175            let mut wrote_something = false;
176            for deduction in deductions {
177                let deducted_amount = accessor(deduction);
178                if deducted_amount > 0 {
179                    writeln!(f, "\n- {deduction:?}: {}", accessor(deduction))?;
180                    wrote_something = true;
181                }
182            }
183            debug_assert!(
184                wrote_something,
185                "no substantial deductions found in error display"
186            );
187        }
188        Ok(())
189    })
190}