wgpu/util/
spirv.rs

1//! Utilities for loading SPIR-V module data.
2
3use alloc::borrow::Cow;
4use core::mem;
5
6#[cfg_attr(not(any(feature = "spirv", doc)), expect(unused_imports))]
7use crate::ShaderSource;
8
9#[cfg(doc)]
10use crate::Device;
11
/// The magic number that begins every SPIR-V module, as defined by the SPIR-V specification.
///
/// A module may be stored in either byte order, so the first word of the input is compared
/// against both this value and its byte-swapped form.
const SPIRV_MAGIC_NUMBER: u32 = u32::from_be_bytes([0x07, 0x23, 0x02, 0x03]);
13
/// Interprets the byte slice `data` as a binary SPIR-V module and wraps the resulting words in
/// [`ShaderSource::SpirV`].
///
/// The words are produced by [`make_spirv_raw()`], which converts them to native byte order
/// when the module was stored in the opposite one.
///
/// # Panics
///
/// This function panics if:
///
/// - `data.len()` is not a multiple of 4
/// - `data` does not begin with the SPIR-V magic number (this includes an empty `data`)
///
/// No other validation of the module is performed.
#[cfg(feature = "spirv")] // ShaderSource::SpirV only exists in this case
pub fn make_spirv(data: &[u8]) -> ShaderSource<'_> {
    let words = make_spirv_raw(data);
    ShaderSource::SpirV(words)
}
28
29/// Check whether the byte slice has the SPIR-V magic number (in either byte order) and of an
30/// appropriate size, and panic with a suitable message when it is not.
31///
32/// Returns whether the endianness is opposite of native endianness (i.e. whether
33/// [`u32::swap_bytes()`] should be called.)
34///
35/// Note: this function’s checks are relied upon for the soundness of [`make_spirv_const()`].
36/// Undefined behavior will result if it does not panic when `bytes.len()` is not a multiple of 4.
37#[track_caller]
38const fn assert_has_spirv_magic_number_and_length(bytes: &[u8]) -> bool {
39    // First, check the magic number.
40    // This way we give the best error for wrong formats.
41    // (Plus a special case for the empty slice.)
42    let found_magic_number: Option<bool> = match *bytes {
43        [] => panic!("byte slice is empty, not SPIR-V"),
44        // This would be simpler as slice::starts_with(), but that isn't a const fn yet.
45        [b1, b2, b3, b4, ..] => {
46            let prefix = u32::from_ne_bytes([b1, b2, b3, b4]);
47            if prefix == SPIRV_MAGIC_NUMBER {
48                Some(false)
49            } else if prefix == const { SPIRV_MAGIC_NUMBER.swap_bytes() } {
50                // needs swapping
51                Some(true)
52            } else {
53                None
54            }
55        }
56        _ => None, // fallthrough case = between 1 and 3 bytes
57    };
58
59    match found_magic_number {
60        Some(needs_byte_swap) => {
61            // Note: this assertion is relied upon for the soundness of `make_spirv_const()`.
62            assert!(
63                bytes.len().is_multiple_of(mem::size_of::<u32>()),
64                "SPIR-V data must be a multiple of 4 bytes long"
65            );
66
67            needs_byte_swap
68        }
69        None => {
70            panic!(
71                "byte slice does not start with SPIR-V magic number. \
72            Make sure you are using a binary SPIR-V file."
73            );
74        }
75    }
76}
77
78#[cfg_attr(not(feature = "spirv"), expect(rustdoc::broken_intra_doc_links))]
79/// Version of [`make_spirv()`] intended for use with
80/// [`Device::create_shader_module_passthrough()`].
81///
82/// Returns a raw slice instead of [`ShaderSource`].
83///
84/// # Panics
85///
86/// This function panics if:
87///
88/// - `data.len()` is not a multiple of 4
89/// - `data` does not begin with the SPIR-V magic number
90///
91/// It does not check that the data is a valid SPIR-V module in any other way.
92pub fn make_spirv_raw(bytes: &[u8]) -> Cow<'_, [u32]> {
93    let needs_byte_swap = assert_has_spirv_magic_number_and_length(bytes);
94
95    // If the data happens to be aligned, directly use the byte array,
96    // otherwise copy the byte array in an owned vector and use that instead.
97    let mut words: Cow<'_, [u32]> = match bytemuck::try_cast_slice(bytes) {
98        Ok(words) => Cow::Borrowed(words),
99        // We already checked the length, so if this fails, it fails due to lack of alignment only.
100        Err(_) => Cow::Owned(bytemuck::pod_collect_to_vec(bytes)),
101    };
102
103    // If necessary, swap bytes to native endianness.
104    if needs_byte_swap {
105        for word in Cow::to_mut(&mut words) {
106            *word = word.swap_bytes();
107        }
108    }
109
110    assert!(
111        words[0] == SPIRV_MAGIC_NUMBER,
112        "can't happen: wrong magic number after swap_bytes"
113    );
114    words
115}
116
117/// Version of `make_spirv_raw` used for implementing [`include_spirv!`] and [`include_spirv_raw!`] macros.
118///
119/// Not public API. Also, don't even try calling at runtime; you'll get a stack overflow.
120///
121/// [`include_spirv!`]: crate::include_spirv
122#[doc(hidden)]
123pub const fn make_spirv_const<const IN: usize, const OUT: usize>(bytes: [u8; IN]) -> [u32; OUT] {
124    let needs_byte_swap = assert_has_spirv_magic_number_and_length(&bytes);
125
126    // NOTE: to get around lack of generic const expressions, the input and output lengths must
127    // be specified separately.
128    // Check that they are consistent with each other.
129    assert!(mem::size_of_val(&bytes) == mem::size_of::<[u32; OUT]>());
130
131    // Can't use `bytemuck` in `const fn` (yet), so do it unsafely.
132    // SAFETY:
133    // * The previous assertion checked that the byte sizes of `bytes` and `words` are equal.
134    // * `transmute_copy` doesn't care that the alignment might be wrong.
135    let mut words: [u32; OUT] = unsafe { mem::transmute_copy(&bytes) };
136
137    // If necessary, swap bytes to native endianness.
138    if needs_byte_swap {
139        let mut idx = 0;
140        while idx < words.len() {
141            words[idx] = words[idx].swap_bytes();
142            idx += 1;
143        }
144    }
145
146    assert!(
147        words[0] == SPIRV_MAGIC_NUMBER,
148        "can't happen: wrong magic number after swap_bytes"
149    );
150
151    words
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Checks that both `make_spirv_raw()` and `make_spirv_const()` turn `input` into `expected`.
    /// The input is placed at each of 4 consecutive byte offsets in one buffer.
    fn test_success_with_misalignments<const IN: usize, const OUT: usize>(
        input: &[u8; IN],
        expected: [u32; OUT],
    ) {
        // We don't know which 3 out of 4 offsets will produce an actually misaligned slice,
        // but they always will. (Note that it is necessary to reuse the same allocation for all 4
        // tests, or we could (in theory) get unlucky and not test any misalignments.)
        let mut buffer = vec![0; input.len() + 4];
        for offset in 0..4 {
            let misaligned_slice: &mut [u8; IN] =
                (&mut buffer[offset..][..input.len()]).try_into().unwrap();

            misaligned_slice.copy_from_slice(input);
            assert_eq!(*make_spirv_raw(misaligned_slice), expected);
            assert_eq!(make_spirv_const(*misaligned_slice), expected);
        }
    }

    #[test]
    fn success_be() {
        // magic number followed by dummy data to see the endianness handling
        let input = b"\x07\x23\x02\x03\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF1F2F3F4];
        test_success_with_misalignments(input, expected);
    }

    #[test]
    fn success_le() {
        let input = b"\x03\x02\x23\x07\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF4F3F2F1];
        test_success_with_misalignments(input, expected);
    }

    #[test]
    fn const_context_evaluation() {
        // `make_spirv_const()` exists to be evaluated at compile time, so check that this works.
        const WORDS: [u32; 2] = make_spirv_const(*b"\x07\x23\x02\x03\xF1\xF2\xF3\xF4");
        assert_eq!(WORDS, [SPIRV_MAGIC_NUMBER, 0xF1F2F3F4]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_le_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_be_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_le_fail() {
        let _: [u32; 1] = make_spirv_const([0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_be_fail() {
        let _: [u32; 1] = make_spirv_const([0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    // A WebAssembly header: a realistic example of "some other binary format".
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn nonconst_wrong_magic_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(b"\0asm\x01\0\0\0");
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn const_wrong_magic_fail() {
        let _: [u32; 2] = make_spirv_const(*b"\0asm\x01\0\0\0");
    }

    // Between 1 and 3 bytes cannot even hold the magic number.
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn nonconst_too_short_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23]);
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn const_too_short_fail() {
        let _: [u32; 0] = make_spirv_const([0x03, 0x02, 0x23]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn nonconst_empty_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn make_spirv_empty() {
        let _: [u32; 0] = make_spirv_const([]);
    }
}