wgpu_types/shader.rs
1use alloc::{borrow::Cow, string::String};
2
/// Describes which runtime safety checks should be applied to shaders.
///
/// Each field enables one kind of check. Disabling a check moves the
/// responsibility for the corresponding guarantee onto the caller; the
/// requirements are documented on each field.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderRuntimeChecks {
    /// Enforce bounds checks in shaders, even if the underlying driver doesn't
    /// support doing so natively.
    ///
    /// When this is `true`, `wgpu` promises that shaders can only read or
    /// write the accessible region of a bindgroup's buffer bindings. If
    /// the underlying graphics platform cannot implement these bounds checks
    /// itself, `wgpu` will inject bounds checks before presenting the
    /// shader to the platform.
    ///
    /// When this is `false`, `wgpu` only enforces such bounds checks if the
    /// underlying platform provides a way to do so itself. `wgpu` does not
    /// itself add any bounds checks to generated shader code.
    ///
    /// Note that `wgpu` users may try to initialize only those portions of
    /// buffers that they anticipate might be read from. Passing `false` here
    /// may allow shaders to see wider regions of the buffers than expected,
    /// making such deferred initialization visible to the application.
    pub bounds_checks: bool,
    /// Force loops in shaders to be bounded.
    ///
    /// If `false`, the caller **MUST** ensure that none of the passed shaders
    /// contain an infinite loop.
    ///
    /// If a shader does, backend compilers MAY treat such a loop as unreachable
    /// code and draw conclusions about other safety-critical code paths. This
    /// option SHOULD NOT be disabled when running untrusted code.
    pub force_loop_bounding: bool,
    /// Track the state of ray queries in shaders to guard against misuse of
    /// the ray query functions.
    ///
    /// If `false`, the caller **MUST** ensure that in all passed shaders every
    /// function operating on a ray query obeys these rules (functions use WGSL
    /// naming):
    /// - `rayQueryInitialize` must have been called before `rayQueryProceed`
    /// - `rayQueryProceed` must have been called, returned true and have hit an AABB before
    ///   `rayQueryGenerateIntersection` is called
    /// - `rayQueryProceed` must have been called, returned true and have hit a triangle before
    ///   `rayQueryConfirmIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned true before `rayQueryTerminate`,
    ///   `getCandidateHitVertexPositions` or `rayQueryGetCandidateIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned false before
    ///   `rayQueryGetCommittedIntersection` or `getCommittedHitVertexPositions` is called
    ///
    /// When this is `true`, the aim is that violating these rules does not cause
    /// undefined behavior. However, it currently still does on DX12 and Metal.
    pub ray_query_initialization_tracking: bool,
}
47
48impl ShaderRuntimeChecks {
49 /// Creates a new configuration where the shader is fully checked.
50 #[must_use]
51 pub const fn checked() -> Self {
52 unsafe { Self::all(true) }
53 }
54
55 /// Creates a new configuration where none of the checks are performed.
56 ///
57 /// # Safety
58 ///
59 /// See the documentation for the `set_*` methods for the safety requirements
60 /// of each sub-configuration.
61 #[must_use]
62 pub const fn unchecked() -> Self {
63 unsafe { Self::all(false) }
64 }
65
66 /// Creates a new configuration where all checks are enabled or disabled. To safely
67 /// create a configuration with all checks enabled, use [`ShaderRuntimeChecks::checked`].
68 ///
69 /// # Safety
70 ///
71 /// See the documentation for the `set_*` methods for the safety requirements
72 /// of each sub-configuration.
73 #[must_use]
74 pub const unsafe fn all(all_checks: bool) -> Self {
75 Self {
76 bounds_checks: all_checks,
77 force_loop_bounding: all_checks,
78 ray_query_initialization_tracking: all_checks,
79 }
80 }
81}
82
83impl Default for ShaderRuntimeChecks {
84 fn default() -> Self {
85 Self::checked()
86 }
87}
88
89/// Descriptor for a shader module given by any of several sources.
90/// These shaders are passed through directly to the underlying api.
91/// At least one shader type that may be used by the backend must be `Some` or a panic is raised.
92#[derive(Debug, Clone)]
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
95 /// Entrypoint. Unused for Spir-V.
96 pub entry_point: String,
97 /// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
98 pub label: L,
99 /// Number of workgroups in each dimension x, y and z. Unused for Spir-V.
100 pub num_workgroups: (u32, u32, u32),
101 /// Runtime checks that should be enabled.
102 pub runtime_checks: ShaderRuntimeChecks,
103
104 /// Binary SPIR-V data, in 4-byte words.
105 pub spirv: Option<Cow<'a, [u32]>>,
106 /// Shader DXIL source.
107 pub dxil: Option<Cow<'a, [u8]>>,
108 /// Shader MSL source.
109 pub msl: Option<Cow<'a, str>>,
110 /// Shader HLSL source.
111 pub hlsl: Option<Cow<'a, str>>,
112 /// Shader GLSL source (currently unused).
113 pub glsl: Option<Cow<'a, str>>,
114 /// Shader WGSL source.
115 pub wgsl: Option<Cow<'a, str>>,
116}
117
118// This is so people don't have to fill in fields they don't use, like num_workgroups,
119// entry_point, or other shader languages they didn't compile for
120impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> {
121 fn default() -> Self {
122 Self {
123 entry_point: "".into(),
124 label: Default::default(),
125 num_workgroups: (0, 0, 0),
126 runtime_checks: ShaderRuntimeChecks::unchecked(),
127 spirv: None,
128 dxil: None,
129 msl: None,
130 hlsl: None,
131 glsl: None,
132 wgsl: None,
133 }
134 }
135}
136
137impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
138 /// Takes a closure and maps the label of the shader module descriptor into another.
139 pub fn map_label<K>(
140 &self,
141 fun: impl FnOnce(&L) -> K,
142 ) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
143 CreateShaderModuleDescriptorPassthrough {
144 entry_point: self.entry_point.clone(),
145 label: fun(&self.label),
146 num_workgroups: self.num_workgroups,
147 runtime_checks: self.runtime_checks,
148 spirv: self.spirv.clone(),
149 dxil: self.dxil.clone(),
150 msl: self.msl.clone(),
151 hlsl: self.hlsl.clone(),
152 glsl: self.glsl.clone(),
153 wgsl: self.wgsl.clone(),
154 }
155 }
156
157 #[cfg(feature = "trace")]
158 /// Returns the source data for tracing purpose.
159 pub fn trace_data(&self) -> &[u8] {
160 if let Some(spirv) = &self.spirv {
161 bytemuck::cast_slice(spirv)
162 } else if let Some(msl) = &self.msl {
163 msl.as_bytes()
164 } else if let Some(dxil) = &self.dxil {
165 dxil
166 } else {
167 panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
168 }
169 }
170
171 #[cfg(feature = "trace")]
172 /// Returns the binary file extension for tracing purpose.
173 pub fn trace_binary_ext(&self) -> &'static str {
174 if self.spirv.is_some() {
175 "spv"
176 } else if self.msl.is_some() {
177 "msl"
178 } else if self.dxil.is_some() {
179 "dxil"
180 } else {
181 panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
182 }
183 }
184}