wgpu_types/shader.rs
1use alloc::borrow::Cow;
2
/// Describes which runtime checks `wgpu` performs on shader code.
///
/// Each field controls one independent check; the documentation of each field
/// describes the consequences of disabling it.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderRuntimeChecks {
    /// Enforce bounds checks in shaders, even if the underlying driver doesn't
    /// support doing so natively.
    ///
    /// When this is `true`, `wgpu` promises that shaders can only read or
    /// write the accessible region of a bindgroup's buffer bindings. If
    /// the underlying graphics platform cannot implement these bounds checks
    /// itself, `wgpu` will inject bounds checks before presenting the
    /// shader to the platform.
    ///
    /// When this is `false`, `wgpu` only enforces such bounds checks if the
    /// underlying platform provides a way to do so itself. `wgpu` does not
    /// itself add any bounds checks to generated shader code.
    ///
    /// Note that `wgpu` users may try to initialize only those portions of
    /// buffers that they anticipate might be read from. Passing `false` here
    /// may allow shaders to see wider regions of the buffers than expected,
    /// making such deferred initialization visible to the application.
    pub bounds_checks: bool,

    /// Force loops in shaders to be bounded.
    ///
    /// If false, the caller MUST ensure that all passed shaders do not contain any infinite loops.
    ///
    /// If they do, backend compilers MAY treat such a loop as unreachable code and draw
    /// conclusions about other safety-critical code paths. This option SHOULD NOT be disabled
    /// when running untrusted code.
    pub force_loop_bounding: bool,

    /// Track the initialization state of ray queries at runtime.
    ///
    /// If false, the caller **MUST** ensure that in all passed shaders every function operating
    /// on a ray query obeys these rules (function names use WGSL naming):
    /// - `rayQueryInitialize` must have been called before `rayQueryProceed`
    /// - `rayQueryProceed` must have been called, returned true and have hit an AABB before
    ///   `rayQueryGenerateIntersection` is called
    /// - `rayQueryProceed` must have been called, returned true and have hit a triangle before
    ///   `rayQueryConfirmIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned true before `rayQueryTerminate`,
    ///   `getCandidateHitVertexPositions` or `rayQueryGetCandidateIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned false before
    ///   `rayQueryGetCommittedIntersection` or `getCommittedHitVertexPositions` are called
    ///
    /// The aim is that these cases will not cause UB if this is set to true, but currently
    /// they still cause UB on DX12 and Metal.
    pub ray_query_initialization_tracking: bool,

    /// If false, task shaders will not validate that the mesh shader grid they dispatch is
    /// within legal limits.
    pub task_shader_dispatch_tracking: bool,

    /// If false, mesh shaders won't clamp the output primitives' vertex indices, which can lead to
    /// undefined behavior and arbitrary memory access.
    pub mesh_shader_primitive_indices_clamp: bool,

    /// If false, integer division and modulo operations will use raw instructions
    /// without guards against division by zero or signed integer overflow
    /// (`INT_MIN / -1`). The caller **MUST** ensure that all divisors are non-zero
    /// and that no signed overflow occurs.
    pub int_div_checks: bool,
}
60
61impl ShaderRuntimeChecks {
62 /// Creates a new configuration where the shader is fully checked.
63 #[must_use]
64 pub const fn checked() -> Self {
65 unsafe { Self::all(true) }
66 }
67
68 /// Creates a new configuration where none of the checks are performed.
69 ///
70 /// # Safety
71 ///
72 /// See the documentation for the `set_*` methods for the safety requirements
73 /// of each sub-configuration.
74 #[must_use]
75 pub const fn unchecked() -> Self {
76 unsafe { Self::all(false) }
77 }
78
79 /// Creates a new configuration where all checks are enabled or disabled. To safely
80 /// create a configuration with all checks enabled, use [`ShaderRuntimeChecks::checked`].
81 ///
82 /// # Safety
83 ///
84 /// See the documentation for the `set_*` methods for the safety requirements
85 /// of each sub-configuration.
86 #[must_use]
87 pub const unsafe fn all(all_checks: bool) -> Self {
88 Self {
89 bounds_checks: all_checks,
90 force_loop_bounding: all_checks,
91 ray_query_initialization_tracking: all_checks,
92 task_shader_dispatch_tracking: all_checks,
93 mesh_shader_primitive_indices_clamp: all_checks,
94 int_div_checks: all_checks,
95 }
96 }
97}
98
99impl Default for ShaderRuntimeChecks {
100 fn default() -> Self {
101 Self::checked()
102 }
103}
104
/// Describes a single entry point in a passthrough shader descriptor.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PassthroughShaderEntryPoint<'a> {
    /// The name of the entry point. Only used in validation and for GLSL or DXIL.
    pub name: Cow<'a, str>,
    /// Size of a single workgroup in the x, y and z dimensions (not the number of
    /// workgroups dispatched). Only used for Metal with compute-like shader stages.
    pub workgroup_size: (u32, u32, u32),
}
115
116/// Descriptor for a shader module given by any of several sources.
117/// These shaders are passed through directly to the underlying api.
118/// At least one shader type that may be used by the backend must be `Some` or a panic is raised.
119///
120/// Note that you shouldn't expect this to work with bindings except on SPIR-V, and even on SPIR-V
121/// there will be some caveats.
122#[derive(Debug, Clone)]
123#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
125 /// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
126 pub label: L,
127 /// The list of entry points and their corresponding workgroup sizes.
128 pub entry_points: Cow<'a, [PassthroughShaderEntryPoint<'a>]>,
129
130 /// Binary SPIR-V data, in 4-byte words.
131 pub spirv: Option<Cow<'a, [u32]>>,
132 /// Shader DXIL source.
133 pub dxil: Option<Cow<'a, [u8]>>,
134 /// Shader HLSL source.
135 pub hlsl: Option<Cow<'a, str>>,
136 /// Shader MetalLib source.
137 pub metallib: Option<Cow<'a, [u8]>>,
138 /// Shader MSL source.
139 pub msl: Option<Cow<'a, str>>,
140 /// Shader GLSL source (currently unused).
141 pub glsl: Option<Cow<'a, str>>,
142 /// Shader WGSL source.
143 pub wgsl: Option<Cow<'a, str>>,
144}
145
146// This is so people don't have to fill in fields they don't use, like num_workgroups,
147// entry_point, or other shader languages they didn't compile for
148impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> {
149 fn default() -> Self {
150 Self {
151 label: Default::default(),
152 entry_points: Cow::Borrowed(&[]),
153 spirv: None,
154 dxil: None,
155 metallib: None,
156 msl: None,
157 hlsl: None,
158 glsl: None,
159 wgsl: None,
160 }
161 }
162}
163
164impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
165 /// Takes a closure and maps the label of the shader module descriptor into another.
166 pub fn map_label<K>(
167 &self,
168 fun: impl FnOnce(&L) -> K,
169 ) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
170 CreateShaderModuleDescriptorPassthrough {
171 label: fun(&self.label),
172 entry_points: self.entry_points.clone(),
173 spirv: self.spirv.clone(),
174 metallib: self.metallib.clone(),
175 dxil: self.dxil.clone(),
176 msl: self.msl.clone(),
177 hlsl: self.hlsl.clone(),
178 glsl: self.glsl.clone(),
179 wgsl: self.wgsl.clone(),
180 }
181 }
182
183 #[cfg(feature = "trace")]
184 /// Returns the source data for tracing purpose.
185 pub fn trace_data(&self) -> &[u8] {
186 if let Some(spirv) = &self.spirv {
187 bytemuck::cast_slice(spirv)
188 } else if let Some(metallib) = &self.metallib {
189 metallib
190 } else if let Some(msl) = &self.msl {
191 msl.as_bytes()
192 } else if let Some(dxil) = &self.dxil {
193 dxil
194 } else if let Some(hlsl) = &self.hlsl {
195 hlsl.as_bytes()
196 } else if let Some(glsl) = &self.glsl {
197 glsl.as_bytes()
198 } else if let Some(wgsl) = &self.wgsl {
199 wgsl.as_bytes()
200 } else {
201 panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
202 }
203 }
204
205 #[cfg(feature = "trace")]
206 /// Returns the binary file extension for tracing purpose.
207 pub fn trace_binary_ext(&self) -> &'static str {
208 if self.spirv.is_some() {
209 "spv"
210 } else if self.metallib.is_some() {
211 "metallib"
212 } else if self.msl.is_some() {
213 "metal"
214 } else if self.dxil.is_some() {
215 "dxil"
216 } else if self.hlsl.is_some() {
217 "hlsl"
218 } else if self.glsl.is_some() {
219 "glsl"
220 } else if self.wgsl.is_some() {
221 "wgsl"
222 } else {
223 panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
224 }
225 }
226}