wgpu_types/
shader.rs

1use alloc::borrow::Cow;
2
/// Describes which runtime safety checks should be performed for shaders.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderRuntimeChecks {
    /// Enforce bounds checks in shaders, even if the underlying driver doesn't
    /// support doing so natively.
    ///
    /// When this is `true`, `wgpu` promises that shaders can only read or
    /// write the accessible region of a bindgroup's buffer bindings. If
    /// the underlying graphics platform cannot implement these bounds checks
    /// itself, `wgpu` will inject bounds checks before presenting the
    /// shader to the platform.
    ///
    /// When this is `false`, `wgpu` only enforces such bounds checks if the
    /// underlying platform provides a way to do so itself. `wgpu` does not
    /// itself add any bounds checks to generated shader code.
    ///
    /// Note that `wgpu` users may try to initialize only those portions of
    /// buffers that they anticipate might be read from. Passing `false` here
    /// may allow shaders to see wider regions of the buffers than expected,
    /// making such deferred initialization visible to the application.
    pub bounds_checks: bool,
    /// If false, the caller **MUST** ensure that none of the passed shaders contains an
    /// infinite loop.
    ///
    /// If one does, backend compilers MAY treat such a loop as unreachable code and draw
    /// conclusions about other safety-critical code paths. This option SHOULD NOT be disabled
    /// when running untrusted code.
    pub force_loop_bounding: bool,
    /// If false, the caller **MUST** ensure that in all passed shaders every function operating
    /// on a ray query obeys these rules (function names use WGSL naming):
    /// - `rayQueryInitialize` must have been called before `rayQueryProceed`
    /// - `rayQueryProceed` must have been called, returned true and have hit an AABB before
    ///   `rayQueryGenerateIntersection` is called
    /// - `rayQueryProceed` must have been called, returned true and have hit a triangle before
    ///   `rayQueryConfirmIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned true before `rayQueryTerminate`,
    ///   `getCandidateHitVertexPositions` or `rayQueryGetCandidateIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned false before
    ///   `rayQueryGetCommittedIntersection` or `getCommittedHitVertexPositions` are called
    ///
    /// The aim is that these cases will not cause UB when this is set to true, but currently
    /// they still can on DX12 and Metal.
    pub ray_query_initialization_tracking: bool,

    /// If false, task shaders will not validate that the mesh shader grid they dispatch is within
    /// legal limits.
    pub task_shader_dispatch_tracking: bool,

    /// If false, mesh shaders won't clamp the output primitives' vertex indices, which can lead to
    /// undefined behavior and arbitrary memory access.
    pub mesh_shader_primitive_indices_clamp: bool,

    /// If false, integer division and modulo operations will use raw instructions
    /// without guards against division by zero or signed integer overflow
    /// (`INT_MIN / -1`). The caller **MUST** ensure that all divisors are non-zero
    /// and that no signed overflow occurs.
    pub int_div_checks: bool,
}

impl ShaderRuntimeChecks {
    /// Creates a new configuration where the shader is fully checked.
    ///
    /// This is the configuration returned by [`Default::default`].
    #[must_use]
    pub const fn checked() -> Self {
        // SAFETY: enabling every check places no obligations on the caller; the
        // requirements referenced by `all`'s contract only apply to disabled checks.
        unsafe { Self::all(true) }
    }

    /// Creates a new configuration where none of the checks are performed.
    ///
    /// # Safety
    ///
    /// Constructing this value is not itself unsafe, but every disabled check shifts an
    /// obligation onto the caller: see the documentation on each field of
    /// [`ShaderRuntimeChecks`] for the requirements that apply when it is `false`.
    #[must_use]
    pub const fn unchecked() -> Self {
        // SAFETY: producing the value is sound on its own, since every field is a public
        // `bool` that safe code can already set to `false`. The per-field requirements
        // apply to the shaders that are later passed alongside this configuration.
        unsafe { Self::all(false) }
    }

    /// Creates a new configuration where all checks are enabled or disabled. To safely
    /// create a configuration with all checks enabled, use [`ShaderRuntimeChecks::checked`].
    ///
    /// # Safety
    ///
    /// When `all_checks` is `false`, the caller must uphold the requirements documented on
    /// each field of [`ShaderRuntimeChecks`] for the case where that check is disabled.
    #[must_use]
    pub const unsafe fn all(all_checks: bool) -> Self {
        Self {
            bounds_checks: all_checks,
            force_loop_bounding: all_checks,
            ray_query_initialization_tracking: all_checks,
            task_shader_dispatch_tracking: all_checks,
            mesh_shader_primitive_indices_clamp: all_checks,
            int_div_checks: all_checks,
        }
    }
}

impl Default for ShaderRuntimeChecks {
    /// Returns the fully checked configuration, identical to [`ShaderRuntimeChecks::checked`].
    fn default() -> Self {
        Self::checked()
    }
}
104
/// Describes a single entry point in a passthrough shader descriptor.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PassthroughShaderEntryPoint<'a> {
    /// The name of the entry point. Only used in validation and for GLSL or DXIL.
    pub name: Cow<'a, str>,
    /// Size of a single workgroup in each dimension x, y and z (not the number of
    /// workgroups dispatched). Only used for Metal with compute-like shader stages.
    pub workgroup_size: (u32, u32, u32),
}
115
116/// Descriptor for a shader module given by any of several sources.
117/// These shaders are passed through directly to the underlying api.
118/// At least one shader type that may be used by the backend must be `Some` or a panic is raised.
119///
120/// Note that you shouldn't expect this to work with bindings except on SPIR-V, and even on SPIR-V
121/// there will be some caveats.
122#[derive(Debug, Clone)]
123#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
125    /// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
126    pub label: L,
127    /// The list of entry points and their corresponding workgroup sizes.
128    pub entry_points: Cow<'a, [PassthroughShaderEntryPoint<'a>]>,
129
130    /// Binary SPIR-V data, in 4-byte words.
131    pub spirv: Option<Cow<'a, [u32]>>,
132    /// Shader DXIL source.
133    pub dxil: Option<Cow<'a, [u8]>>,
134    /// Shader HLSL source.
135    pub hlsl: Option<Cow<'a, str>>,
136    /// Shader MetalLib source.
137    pub metallib: Option<Cow<'a, [u8]>>,
138    /// Shader MSL source.
139    pub msl: Option<Cow<'a, str>>,
140    /// Shader GLSL source (currently unused).
141    pub glsl: Option<Cow<'a, str>>,
142    /// Shader WGSL source.
143    pub wgsl: Option<Cow<'a, str>>,
144}
145
146// This is so people don't have to fill in fields they don't use, like num_workgroups,
147// entry_point, or other shader languages they didn't compile for
148impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> {
149    fn default() -> Self {
150        Self {
151            label: Default::default(),
152            entry_points: Cow::Borrowed(&[]),
153            spirv: None,
154            dxil: None,
155            metallib: None,
156            msl: None,
157            hlsl: None,
158            glsl: None,
159            wgsl: None,
160        }
161    }
162}
163
164impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
165    /// Takes a closure and maps the label of the shader module descriptor into another.
166    pub fn map_label<K>(
167        &self,
168        fun: impl FnOnce(&L) -> K,
169    ) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
170        CreateShaderModuleDescriptorPassthrough {
171            label: fun(&self.label),
172            entry_points: self.entry_points.clone(),
173            spirv: self.spirv.clone(),
174            metallib: self.metallib.clone(),
175            dxil: self.dxil.clone(),
176            msl: self.msl.clone(),
177            hlsl: self.hlsl.clone(),
178            glsl: self.glsl.clone(),
179            wgsl: self.wgsl.clone(),
180        }
181    }
182
183    #[cfg(feature = "trace")]
184    /// Returns the source data for tracing purpose.
185    pub fn trace_data(&self) -> &[u8] {
186        if let Some(spirv) = &self.spirv {
187            bytemuck::cast_slice(spirv)
188        } else if let Some(metallib) = &self.metallib {
189            metallib
190        } else if let Some(msl) = &self.msl {
191            msl.as_bytes()
192        } else if let Some(dxil) = &self.dxil {
193            dxil
194        } else if let Some(hlsl) = &self.hlsl {
195            hlsl.as_bytes()
196        } else if let Some(glsl) = &self.glsl {
197            glsl.as_bytes()
198        } else if let Some(wgsl) = &self.wgsl {
199            wgsl.as_bytes()
200        } else {
201            panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
202        }
203    }
204
205    #[cfg(feature = "trace")]
206    /// Returns the binary file extension for tracing purpose.
207    pub fn trace_binary_ext(&self) -> &'static str {
208        if self.spirv.is_some() {
209            "spv"
210        } else if self.metallib.is_some() {
211            "metallib"
212        } else if self.msl.is_some() {
213            "metal"
214        } else if self.dxil.is_some() {
215            "dxil"
216        } else if self.hlsl.is_some() {
217            "hlsl"
218        } else if self.glsl.is_some() {
219            "glsl"
220        } else if self.wgsl.is_some() {
221            "wgsl"
222        } else {
223            panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
224        }
225    }
226}