wgpu_types/
shader.rs

1use alloc::{borrow::Cow, string::String};
2
/// Describes which runtime safety checks should be performed on shaders.
///
/// Each field enables one kind of check. Setting a field to `false` moves the
/// responsibility documented on that field onto the caller.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderRuntimeChecks {
    /// Enforce bounds checks in shaders, even if the underlying driver doesn't
    /// support doing so natively.
    ///
    /// When this is `true`, `wgpu` promises that shaders can only read or
    /// write the accessible region of a bindgroup's buffer bindings. If
    /// the underlying graphics platform cannot implement these bounds checks
    /// itself, `wgpu` will inject bounds checks before presenting the
    /// shader to the platform.
    ///
    /// When this is `false`, `wgpu` only enforces such bounds checks if the
    /// underlying platform provides a way to do so itself. `wgpu` does not
    /// itself add any bounds checks to generated shader code.
    ///
    /// Note that `wgpu` users may try to initialize only those portions of
    /// buffers that they anticipate might be read from. Passing `false` here
    /// may allow shaders to see wider regions of the buffers than expected,
    /// making such deferred initialization visible to the application.
    pub bounds_checks: bool,
    /// Force loops in shaders to be bounded, so that no loop can run forever.
    ///
    /// If false, the caller MUST ensure that all passed shaders do not contain any infinite loops.
    ///
    /// If a shader does contain one, backend compilers MAY treat such a loop as unreachable
    /// code and draw conclusions about other safety-critical code paths. This option SHOULD NOT
    /// be disabled when running untrusted code.
    pub force_loop_bounding: bool,
    /// Track the initialization and traversal state of ray queries.
    ///
    /// If false, the caller **MUST** ensure that in all passed shaders every function operating
    /// on a ray query obeys these rules (function names use WGSL naming):
    /// - `rayQueryInitialize` must have been called before `rayQueryProceed`
    /// - `rayQueryProceed` must have been called, returned true and have hit an AABB before
    ///   `rayQueryGenerateIntersection` is called
    /// - `rayQueryProceed` must have been called, returned true and have hit a triangle before
    ///   `rayQueryConfirmIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned true before `rayQueryTerminate`,
    ///   `getCandidateHitVertexPositions` or `rayQueryGetCandidateIntersection` is called
    /// - `rayQueryProceed` must have been called and have returned false before `rayQueryGetCommittedIntersection`
    ///   or `getCommittedHitVertexPositions` are called
    ///
    /// The aim is that these cases will not cause UB when this is set to true, but currently
    /// they still can on DX12 and Metal.
    pub ray_query_initialization_tracking: bool,
}
47
48impl ShaderRuntimeChecks {
49    /// Creates a new configuration where the shader is fully checked.
50    #[must_use]
51    pub const fn checked() -> Self {
52        unsafe { Self::all(true) }
53    }
54
55    /// Creates a new configuration where none of the checks are performed.
56    ///
57    /// # Safety
58    ///
59    /// See the documentation for the `set_*` methods for the safety requirements
60    /// of each sub-configuration.
61    #[must_use]
62    pub const fn unchecked() -> Self {
63        unsafe { Self::all(false) }
64    }
65
66    /// Creates a new configuration where all checks are enabled or disabled. To safely
67    /// create a configuration with all checks enabled, use [`ShaderRuntimeChecks::checked`].
68    ///
69    /// # Safety
70    ///
71    /// See the documentation for the `set_*` methods for the safety requirements
72    /// of each sub-configuration.
73    #[must_use]
74    pub const unsafe fn all(all_checks: bool) -> Self {
75        Self {
76            bounds_checks: all_checks,
77            force_loop_bounding: all_checks,
78            ray_query_initialization_tracking: all_checks,
79        }
80    }
81}
82
83impl Default for ShaderRuntimeChecks {
84    fn default() -> Self {
85        Self::checked()
86    }
87}
88
89/// Descriptor for a shader module given by any of several sources.
90/// These shaders are passed through directly to the underlying api.
91/// At least one shader type that may be used by the backend must be `Some` or a panic is raised.
92#[derive(Debug, Clone)]
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
95    /// Entrypoint. Unused for Spir-V.
96    pub entry_point: String,
97    /// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
98    pub label: L,
99    /// Number of workgroups in each dimension x, y and z. Unused for Spir-V.
100    pub num_workgroups: (u32, u32, u32),
101    /// Runtime checks that should be enabled.
102    pub runtime_checks: ShaderRuntimeChecks,
103
104    /// Binary SPIR-V data, in 4-byte words.
105    pub spirv: Option<Cow<'a, [u32]>>,
106    /// Shader DXIL source.
107    pub dxil: Option<Cow<'a, [u8]>>,
108    /// Shader MSL source.
109    pub msl: Option<Cow<'a, str>>,
110    /// Shader HLSL source.
111    pub hlsl: Option<Cow<'a, str>>,
112    /// Shader GLSL source (currently unused).
113    pub glsl: Option<Cow<'a, str>>,
114    /// Shader WGSL source.
115    pub wgsl: Option<Cow<'a, str>>,
116}
117
118// This is so people don't have to fill in fields they don't use, like num_workgroups,
119// entry_point, or other shader languages they didn't compile for
120impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> {
121    fn default() -> Self {
122        Self {
123            entry_point: "".into(),
124            label: Default::default(),
125            num_workgroups: (0, 0, 0),
126            runtime_checks: ShaderRuntimeChecks::unchecked(),
127            spirv: None,
128            dxil: None,
129            msl: None,
130            hlsl: None,
131            glsl: None,
132            wgsl: None,
133        }
134    }
135}
136
137impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
138    /// Takes a closure and maps the label of the shader module descriptor into another.
139    pub fn map_label<K>(
140        &self,
141        fun: impl FnOnce(&L) -> K,
142    ) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
143        CreateShaderModuleDescriptorPassthrough {
144            entry_point: self.entry_point.clone(),
145            label: fun(&self.label),
146            num_workgroups: self.num_workgroups,
147            runtime_checks: self.runtime_checks,
148            spirv: self.spirv.clone(),
149            dxil: self.dxil.clone(),
150            msl: self.msl.clone(),
151            hlsl: self.hlsl.clone(),
152            glsl: self.glsl.clone(),
153            wgsl: self.wgsl.clone(),
154        }
155    }
156
157    #[cfg(feature = "trace")]
158    /// Returns the source data for tracing purpose.
159    pub fn trace_data(&self) -> &[u8] {
160        if let Some(spirv) = &self.spirv {
161            bytemuck::cast_slice(spirv)
162        } else if let Some(msl) = &self.msl {
163            msl.as_bytes()
164        } else if let Some(dxil) = &self.dxil {
165            dxil
166        } else {
167            panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
168        }
169    }
170
171    #[cfg(feature = "trace")]
172    /// Returns the binary file extension for tracing purpose.
173    pub fn trace_binary_ext(&self) -> &'static str {
174        if self.spirv.is_some() {
175            "spv"
176        } else if self.msl.is_some() {
177            "msl"
178        } else if self.dxil.is_some() {
179            "dxil"
180        } else {
181            panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
182        }
183    }
184}