wgpu/util/
spirv.rs

1//! Utilities for loading SPIR-V module data.
2
3use alloc::borrow::Cow;
4use core::mem;
5
6#[cfg_attr(not(any(feature = "spirv", doc)), expect(unused_imports))]
7use crate::ShaderSource;
8
9#[cfg(doc)]
10use crate::Device;
11
/// Magic number that every SPIR-V module starts with.
///
/// Spelled out here as the byte sequence found at the start of a big-endian module.
const SPIRV_MAGIC_NUMBER: u32 = u32::from_be_bytes([0x07, 0x23, 0x02, 0x03]);
13
/// Interprets `data` as a binary SPIR-V module and wraps it in a [`ShaderSource`].
///
/// The module may be stored in either byte order; the words are converted to native byte
/// order by [`make_spirv_raw()`].
///
/// # Panics
///
/// This function panics if:
///
/// - `data.len()` is not a multiple of 4
/// - `data` does not begin with the SPIR-V magic number
///
/// It does not check that the data is a valid SPIR-V module in any other way.
#[cfg(feature = "spirv")] // ShaderSource::SpirV only exists in this case
pub fn make_spirv(data: &[u8]) -> ShaderSource<'_> {
    let words = make_spirv_raw(data);
    ShaderSource::SpirV(words)
}
28
29/// Check whether the byte slice has the SPIR-V magic number (in either byte order) and of an
30/// appropriate size, and panic with a suitable message when it is not.
31///
32/// Returns whether the endianness is opposite of native endianness (i.e. whether
33/// [`u32::swap_bytes()`] should be called.)
34///
35/// Note: this function’s checks are relied upon for the soundness of [`make_spirv_const()`].
36/// Undefined behavior will result if it does not panic when `bytes.len()` is not a multiple of 4.
37#[track_caller]
38const fn assert_has_spirv_magic_number_and_length(bytes: &[u8]) -> bool {
39    // First, check the magic number.
40    // This way we give the best error for wrong formats.
41    // (Plus a special case for the empty slice.)
42    let found_magic_number: Option<bool> = match *bytes {
43        [] => panic!("byte slice is empty, not SPIR-V"),
44        // This would be simpler as slice::starts_with(), but that isn't a const fn yet.
45        [b1, b2, b3, b4, ..] => {
46            let prefix = u32::from_ne_bytes([b1, b2, b3, b4]);
47            if prefix == SPIRV_MAGIC_NUMBER {
48                Some(false)
49            } else if prefix == const { SPIRV_MAGIC_NUMBER.swap_bytes() } {
50                // needs swapping
51                Some(true)
52            } else {
53                None
54            }
55        }
56        _ => None, // fallthrough case = between 1 and 3 bytes
57    };
58
59    match found_magic_number {
60        Some(needs_byte_swap) => {
61            // Note: this assertion is relied upon for the soundness of `make_spirv_const()`.
62            assert!(
63                bytes.len() % size_of::<u32>() == 0,
64                "SPIR-V data must be a multiple of 4 bytes long"
65            );
66
67            needs_byte_swap
68        }
69        None => {
70            panic!(
71                "byte slice does not start with SPIR-V magic number. \
72            Make sure you are using a binary SPIR-V file."
73            );
74        }
75    }
76}
77
78/// Version of [`make_spirv()`] intended for use with
79/// [`Device::create_shader_module_passthrough()`].
80///
81/// Returns a raw slice instead of [`ShaderSource`].
82///
83/// # Panics
84///
85/// This function panics if:
86///
87/// - `data.len()` is not a multiple of 4
88/// - `data` does not begin with the SPIR-V magic number
89///
90/// It does not check that the data is a valid SPIR-V module in any other way.
91pub fn make_spirv_raw(bytes: &[u8]) -> Cow<'_, [u32]> {
92    let needs_byte_swap = assert_has_spirv_magic_number_and_length(bytes);
93
94    // If the data happens to be aligned, directly use the byte array,
95    // otherwise copy the byte array in an owned vector and use that instead.
96    let mut words: Cow<'_, [u32]> = match bytemuck::try_cast_slice(bytes) {
97        Ok(words) => Cow::Borrowed(words),
98        // We already checked the length, so if this fails, it fails due to lack of alignment only.
99        Err(_) => Cow::Owned(bytemuck::pod_collect_to_vec(bytes)),
100    };
101
102    // If necessary, swap bytes to native endianness.
103    if needs_byte_swap {
104        for word in Cow::to_mut(&mut words) {
105            *word = word.swap_bytes();
106        }
107    }
108
109    assert!(
110        words[0] == SPIRV_MAGIC_NUMBER,
111        "can't happen: wrong magic number after swap_bytes"
112    );
113    words
114}
115
116/// Version of `make_spirv_raw` used for implementing [`include_spirv!`] and [`include_spirv_raw!`] macros.
117///
118/// Not public API. Also, don't even try calling at runtime; you'll get a stack overflow.
119///
120/// [`include_spirv!`]: crate::include_spirv
121#[doc(hidden)]
122pub const fn make_spirv_const<const IN: usize, const OUT: usize>(bytes: [u8; IN]) -> [u32; OUT] {
123    let needs_byte_swap = assert_has_spirv_magic_number_and_length(&bytes);
124
125    // NOTE: to get around lack of generic const expressions, the input and output lengths must
126    // be specified separately.
127    // Check that they are consistent with each other.
128    assert!(mem::size_of_val(&bytes) == mem::size_of::<[u32; OUT]>());
129
130    // Can't use `bytemuck` in `const fn` (yet), so do it unsafely.
131    // SAFETY:
132    // * The previous assertion checked that the byte sizes of `bytes` and `words` are equal.
133    // * `transmute_copy` doesn't care that the alignment might be wrong.
134    let mut words: [u32; OUT] = unsafe { mem::transmute_copy(&bytes) };
135
136    // If necessary, swap bytes to native endianness.
137    if needs_byte_swap {
138        let mut idx = 0;
139        while idx < words.len() {
140            words[idx] = words[idx].swap_bytes();
141            idx += 1;
142        }
143    }
144
145    assert!(
146        words[0] == SPIRV_MAGIC_NUMBER,
147        "can't happen: wrong magic number after swap_bytes"
148    );
149
150    words
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Runs both `make_spirv_raw()` and `make_spirv_const()` on copies of `input` placed at
    /// every offset 0..4 of one buffer, and checks that each produces `expected`.
    fn test_success_with_misalignments<const IN: usize, const OUT: usize>(
        input: &[u8; IN],
        expected: [u32; OUT],
    ) {
        // We don't know which 3 out of 4 offsets will produce an actually misaligned slice,
        // but they always will. (Note that it is necessary to reuse the same allocation for all 4
        // tests, or we could (in theory) get unlucky and not test any misalignments.)
        let mut buffer = vec![0; input.len() + 4];
        for offset in 0..4 {
            let misaligned_slice: &mut [u8; IN] =
                (&mut buffer[offset..][..input.len()]).try_into().unwrap();

            misaligned_slice.copy_from_slice(input);
            assert_eq!(*make_spirv_raw(misaligned_slice), expected);
            assert_eq!(make_spirv_const(*misaligned_slice), expected);
        }
    }

    #[test]
    fn success_be() {
        // magic number followed by dummy data to see the endianness handling
        let input = b"\x07\x23\x02\x03\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF1F2F3F4];
        test_success_with_misalignments(input, expected);
    }

    #[test]
    fn success_le() {
        let input = b"\x03\x02\x23\x07\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF4F3F2F1];
        test_success_with_misalignments(input, expected);
    }

    /// Aligned data that is already in native byte order should not be copied.
    #[test]
    fn aligned_native_endian_input_is_borrowed() {
        let words: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0x1234_5678];
        let bytes: &[u8] = bytemuck::cast_slice(&words[..]);
        let result = make_spirv_raw(bytes);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(*result, words);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_le_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_be_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_le_fail() {
        let _: [u32; 1] = make_spirv_const([0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_be_fail() {
        let _: [u32; 1] = make_spirv_const([0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn nonconst_wrong_magic_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn const_wrong_magic_fail() {
        let _: [u32; 1] = make_spirv_const([0xDE, 0xAD, 0xBE, 0xEF]);
    }

    /// A completely wrong format should be reported as such, even if the length is also wrong.
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn wrong_magic_reported_before_length() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0xDE, 0xAD, 0xBE, 0xEF, 0x00]);
    }

    /// Between 1 and 3 bytes is too short to even contain the magic number.
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn too_short_for_magic_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn make_spirv_raw_empty() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn make_spirv_empty() {
        let _: [u32; 0] = make_spirv_const([]);
    }
}