naga_test/
lib.rs

1// A lot of the code can be unused based on configuration flags,
2// the corresponding warnings aren't helpful.
3#![allow(dead_code, unused_imports)]
4
5use core::fmt::Write;
6
7use std::{
8    fs,
9    path::{Path, PathBuf},
10};
11
12use naga::compact::KeepUnused;
13use ron::de;
14
15bitflags::bitflags! {
16    #[derive(Clone, Copy, serde::Deserialize)]
17    #[serde(transparent)]
18    #[derive(Debug, Eq, PartialEq)]
19    pub struct Targets: u32 {
20        /// A serialization of the `naga::Module`, in RON format.
21        const IR = 1;
22
23        /// A serialization of the `naga::valid::ModuleInfo`, in RON format.
24        const ANALYSIS = 1 << 1;
25
26        const SPIRV = 1 << 2;
27        const METAL = 1 << 3;
28        const GLSL = 1 << 4;
29        const DOT = 1 << 5;
30        const HLSL = 1 << 6;
31        const WGSL = 1 << 7;
32        const NO_VALIDATION = 1 << 8;
33    }
34}
35
36impl Targets {
37    /// Defaults for `spv` and `glsl` snapshots.
38    pub fn non_wgsl_default() -> Self {
39        Targets::WGSL
40    }
41
42    /// Defaults for `wgsl` snapshots.
43    pub fn wgsl_default() -> Self {
44        Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
45    }
46}
47
48#[derive(serde::Deserialize)]
49pub struct SpvOutVersion(pub u8, pub u8);
50impl Default for SpvOutVersion {
51    fn default() -> Self {
52        SpvOutVersion(1, 1)
53    }
54}
55
/// One entry of the SPIR-V `binding_map` parameter as written in a test's
/// parameter file; a list of these is collected into a
/// `naga::back::spv::BindingMap` by [`deserialize_binding_map`].
#[derive(serde::Deserialize)]
pub struct BindingMapSerialization {
    /// The resource binding, used as the map key.
    pub resource_binding: naga::ResourceBinding,
    /// The SPIR-V binding info, used as the map value.
    pub bind_target: naga::back::spv::BindingInfo,
}
61
62pub fn deserialize_binding_map<'de, D>(
63    deserializer: D,
64) -> Result<naga::back::spv::BindingMap, D::Error>
65where
66    D: serde::Deserializer<'de>,
67{
68    use serde::Deserialize;
69
70    let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
71    let mut map = naga::back::spv::BindingMap::default();
72    for item in vec {
73        map.insert(item.resource_binding, item.bind_target);
74    }
75    Ok(map)
76}
77
78#[derive(Default, serde::Deserialize)]
79#[serde(default)]
80pub struct WgslInParameters {
81    pub parse_doc_comments: bool,
82}
83impl From<&WgslInParameters> for naga::front::wgsl::Options {
84    fn from(value: &WgslInParameters) -> Self {
85        Self {
86            parse_doc_comments: value.parse_doc_comments,
87        }
88    }
89}
90
91#[derive(Default, serde::Deserialize)]
92#[serde(default)]
93pub struct SpirvInParameters {
94    pub adjust_coordinate_space: bool,
95}
96impl From<&SpirvInParameters> for naga::front::spv::Options {
97    fn from(value: &SpirvInParameters) -> Self {
98        Self {
99            adjust_coordinate_space: value.adjust_coordinate_space,
100            ..Default::default()
101        }
102    }
103}
104
105#[derive(serde::Deserialize)]
106#[serde(default)]
107pub struct SpirvOutParameters {
108    pub version: SpvOutVersion,
109    pub capabilities: naga::FastHashSet<spirv::Capability>,
110    pub debug: bool,
111    pub adjust_coordinate_space: bool,
112    pub force_point_size: bool,
113    pub clamp_frag_depth: bool,
114    pub separate_entry_points: bool,
115    #[serde(deserialize_with = "deserialize_binding_map")]
116    pub binding_map: naga::back::spv::BindingMap,
117    pub use_storage_input_output_16: bool,
118}
119impl Default for SpirvOutParameters {
120    fn default() -> Self {
121        Self {
122            version: SpvOutVersion::default(),
123            capabilities: naga::FastHashSet::default(),
124            debug: false,
125            adjust_coordinate_space: false,
126            force_point_size: false,
127            clamp_frag_depth: false,
128            separate_entry_points: false,
129            use_storage_input_output_16: true,
130            binding_map: naga::back::spv::BindingMap::default(),
131        }
132    }
133}
134impl SpirvOutParameters {
135    pub fn to_options<'a>(
136        &'a self,
137        bounds_check_policies: naga::proc::BoundsCheckPolicies,
138        debug_info: Option<naga::back::spv::DebugInfo<'a>>,
139    ) -> naga::back::spv::Options<'a> {
140        use naga::back::spv;
141        let mut flags = spv::WriterFlags::LABEL_VARYINGS;
142        flags.set(spv::WriterFlags::DEBUG, self.debug);
143        flags.set(
144            spv::WriterFlags::ADJUST_COORDINATE_SPACE,
145            self.adjust_coordinate_space,
146        );
147        flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
148        flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
149        naga::back::spv::Options {
150            lang_version: (self.version.0, self.version.1),
151            flags,
152            capabilities: if self.capabilities.is_empty() {
153                None
154            } else {
155                Some(self.capabilities.clone())
156            },
157            bounds_check_policies,
158            binding_map: self.binding_map.clone(),
159            zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
160            force_loop_bounding: true,
161            debug_info,
162            use_storage_input_output_16: self.use_storage_input_output_16,
163        }
164    }
165}
166
167#[derive(Default, serde::Deserialize)]
168#[serde(default)]
169pub struct WgslOutParameters {
170    pub explicit_types: bool,
171}
172impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
173    fn from(value: &WgslOutParameters) -> Self {
174        let mut flags = Self::empty();
175        flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
176        flags
177    }
178}
179
/// A fragment module that backends may be made aware of (see the
/// `fragment_module` field of `Parameters`).
///
/// Unlike the other parameter structs there is no `#[serde(default)]` here,
/// so both keys must be present when this table is given.
#[derive(Default, serde::Deserialize)]
pub struct FragmentModule {
    /// Name of a WGSL file in the same directory as the test file.
    pub path: String,
    /// Entry point name within that module (presumably the fragment entry
    /// point to use — confirm against the caller).
    pub entry_point: String,
}
185
186#[derive(Default, serde::Deserialize)]
187#[serde(default)]
188pub struct Parameters {
189    // -- GOD MODE --
190    pub god_mode: bool,
191
192    // -- wgsl-in options --
193    #[serde(rename = "wgsl-in")]
194    pub wgsl_in: WgslInParameters,
195
196    // -- spirv-in options --
197    #[serde(rename = "spv-in")]
198    pub spv_in: SpirvInParameters,
199
200    // -- SPIR-V options --
201    pub spv: SpirvOutParameters,
202
203    /// Defaults to [`Targets::non_wgsl_default()`] for `spv` and `glsl` snapshots,
204    /// and [`Targets::wgsl_default()`] for `wgsl` snapshots.
205    pub targets: Option<Targets>,
206
207    // -- MSL options --
208    pub msl: naga::back::msl::Options,
209    #[serde(default)]
210    pub msl_pipeline: naga::back::msl::PipelineOptions,
211
212    // -- GLSL options --
213    pub glsl: naga::back::glsl::Options,
214    pub glsl_exclude_list: naga::FastHashSet<String>,
215    pub glsl_multiview: Option<core::num::NonZeroU32>,
216
217    // -- HLSL options --
218    pub hlsl: naga::back::hlsl::Options,
219
220    // -- WGSL options --
221    pub wgsl: WgslOutParameters,
222
223    // -- General options --
224
225    // Allow backends to be aware of the fragment module.
226    // Is the name of a WGSL file in the same directory as the test file.
227    pub fragment_module: Option<FragmentModule>,
228
229    pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
230    pub pipeline_constants: naga::back::PipelineConstants,
231}
232
233/// Information about a shader input file.
234#[derive(Debug)]
235pub struct Input {
236    /// The subdirectory of `tests/in` to which this input belongs, if any.
237    ///
238    /// If the subdirectory is omitted, we assume that the output goes
239    /// to "wgsl".
240    pub subdirectory: PathBuf,
241
242    /// The input filename name, without a directory.
243    pub file_name: PathBuf,
244
245    /// True if output filenames should add the output extension on top of
246    /// `file_name`'s existing extension, rather than replacing it.
247    ///
248    /// This is used by `convert_snapshots_glsl`, which wants to take input files
249    /// like `210-bevy-2d-shader.frag` and just add `.wgsl` to it, producing
250    /// `210-bevy-2d-shader.frag.wgsl`.
251    pub keep_input_extension: bool,
252}
253
254impl Input {
255    /// Read an input file and its corresponding parameters file.
256    ///
257    /// Given `input`, the relative path of a shader input file, return
258    /// a `Source` value containing its path, code, and parameters.
259    ///
260    /// The `input` path is interpreted relative to the `BASE_DIR_IN`
261    /// subdirectory of the directory given by the `CARGO_MANIFEST_DIR`
262    /// environment variable.
263    pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
264        Input {
265            subdirectory: PathBuf::from(subdirectory),
266            // Don't wipe out any extensions on `name`, as
267            // `with_extension` would do.
268            file_name: PathBuf::from(format!("{name}.{extension}")),
269            keep_input_extension: false,
270        }
271    }
272
273    /// Return an iterator that produces an `Input` for each entry in `subdirectory`.
274    pub fn files_in_dir<'a>(
275        subdirectory: &'a str,
276        file_extensions: &'a [&'a str],
277        dir_in: &str,
278    ) -> impl Iterator<Item = Input> + 'a {
279        let input_directory = Path::new(dir_in).join(subdirectory);
280
281        let entries = match std::fs::read_dir(&input_directory) {
282            Ok(entries) => entries,
283            Err(err) => panic!(
284                "Error opening directory '{}': {}",
285                input_directory.display(),
286                err
287            ),
288        };
289
290        entries.filter_map(move |result| {
291            let entry = result.expect("error reading directory");
292            if !entry.file_type().unwrap().is_file() {
293                return None;
294            }
295
296            let file_name = PathBuf::from(entry.file_name());
297            let extension = file_name
298                .extension()
299                .expect("all files in snapshot input directory should have extensions");
300
301            if !file_extensions.contains(&extension.to_str().unwrap()) {
302                return None;
303            }
304
305            if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
306                if !file_name.to_string_lossy().contains(&pat) {
307                    return None;
308                }
309            }
310
311            let input = Input::new(
312                subdirectory,
313                file_name.file_stem().unwrap().to_str().unwrap(),
314                extension.to_str().unwrap(),
315            );
316            Some(input)
317        })
318    }
319
320    /// Return the path to the input directory.
321    pub fn input_directory(&self, dir_in: &str) -> PathBuf {
322        Path::new(dir_in).join(&self.subdirectory)
323    }
324
325    /// Return the path to the output directory.
326    pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
327        Path::new(dir_out).join(subdirectory)
328    }
329
330    /// Return the path to the input file.
331    pub fn input_path(&self, dir_in: &str) -> PathBuf {
332        let mut input = self.input_directory(dir_in);
333        input.push(&self.file_name);
334        input
335    }
336
337    pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
338        let mut output = Self::output_directory(subdirectory, dir_out);
339        if self.keep_input_extension {
340            let file_name = format!(
341                "{}-{}.{}",
342                self.subdirectory.display(),
343                self.file_name.display(),
344                extension
345            );
346
347            output.push(&file_name);
348        } else {
349            let file_name = format!(
350                "{}-{}",
351                self.subdirectory.display(),
352                self.file_name.display()
353            );
354
355            output.push(&file_name);
356            output.set_extension(extension);
357        }
358        output
359    }
360
361    /// Return the contents of the input file as a string.
362    pub fn read_source(&self, dir_in: &str, print: bool) -> String {
363        if print {
364            println!("Processing '{}'", self.file_name.display());
365        }
366        let input_path = self.input_path(dir_in);
367        match fs::read_to_string(&input_path) {
368            Ok(source) => source,
369            Err(err) => {
370                panic!(
371                    "Couldn't read shader input file `{}`: {}",
372                    input_path.display(),
373                    err
374                );
375            }
376        }
377    }
378
379    /// Return the contents of the input file as a vector of bytes.
380    pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
381        if print {
382            println!("Processing '{}'", self.file_name.display());
383        }
384        let input_path = self.input_path(dir_in);
385        match fs::read(&input_path) {
386            Ok(bytes) => bytes,
387            Err(err) => {
388                panic!(
389                    "Couldn't read shader input file `{}`: {}",
390                    input_path.display(),
391                    err
392                );
393            }
394        }
395    }
396
397    pub fn bytes(&self, dir_in: &str) -> u64 {
398        let input_path = self.input_path(dir_in);
399        std::fs::metadata(input_path).unwrap().len()
400    }
401
402    /// Return this input's parameter file, parsed.
403    pub fn read_parameters(&self, dir_in: &str) -> Parameters {
404        let mut param_path = self.input_path(dir_in);
405        param_path.set_extension("toml");
406        let mut params = match fs::read_to_string(&param_path) {
407            Ok(string) => match toml::de::from_str(&string) {
408                Ok(params) => params,
409                Err(e) => panic!(
410                    "Couldn't parse param file: {} due to: {e}",
411                    param_path.display()
412                ),
413            },
414            Err(_) => Parameters::default(),
415        };
416
417        if params.targets.is_none() {
418            match self
419                .input_path(dir_in)
420                .extension()
421                .unwrap()
422                .to_str()
423                .unwrap()
424            {
425                "wgsl" => params.targets = Some(Targets::wgsl_default()),
426                "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
427                "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
428                e => {
429                    panic!("Unknown extension: {e}");
430                }
431            }
432        }
433
434        params
435    }
436
437    /// Write `data` to a file corresponding to this input file in
438    /// `subdirectory`, with `extension`.
439    pub fn write_output_file(
440        &self,
441        subdirectory: &str,
442        extension: &str,
443        data: impl AsRef<[u8]>,
444        dir_out: &str,
445    ) {
446        let output_path = self.output_path(subdirectory, extension, dir_out);
447        fs::create_dir_all(output_path.parent().unwrap()).unwrap();
448        if let Err(err) = fs::write(&output_path, data) {
449            panic!("Error writing {}: {}", output_path.display(), err);
450        }
451    }
452}