naga_test/
lib.rs

1// A lot of the code can be unused based on configuration flags,
2// the corresponding warnings aren't helpful.
3#![allow(dead_code, unused_imports)]
4
5use core::fmt::Write;
6
7use std::{
8    fs,
9    path::{Path, PathBuf},
10};
11
12use naga::compact::KeepUnused;
13use ron::de;
14
15bitflags::bitflags! {
16    #[derive(Clone, Copy, serde::Deserialize)]
17    #[serde(transparent)]
18    #[derive(Debug, Eq, PartialEq)]
19    pub struct Targets: u32 {
20        /// A serialization of the `naga::Module`, in RON format.
21        const IR = 1;
22
23        /// A serialization of the `naga::valid::ModuleInfo`, in RON format.
24        const ANALYSIS = 1 << 1;
25
26        const SPIRV = 1 << 2;
27        const METAL = 1 << 3;
28        const GLSL = 1 << 4;
29        const DOT = 1 << 5;
30        const HLSL = 1 << 6;
31        const WGSL = 1 << 7;
32        const NO_VALIDATION = 1 << 8;
33    }
34}
35
36impl Targets {
37    /// Defaults for `spv` and `glsl` snapshots.
38    pub fn non_wgsl_default() -> Self {
39        Targets::WGSL
40    }
41
42    /// Defaults for `wgsl` snapshots.
43    pub fn wgsl_default() -> Self {
44        Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
45    }
46}
47
48#[derive(serde::Deserialize)]
49pub struct SpvOutVersion(pub u8, pub u8);
50impl Default for SpvOutVersion {
51    fn default() -> Self {
52        SpvOutVersion(1, 1)
53    }
54}
55
56#[derive(serde::Deserialize)]
57pub struct BindingMapSerialization {
58    pub resource_binding: naga::ResourceBinding,
59    pub bind_target: naga::back::spv::BindingInfo,
60}
61
62pub fn deserialize_binding_map<'de, D>(
63    deserializer: D,
64) -> Result<naga::back::spv::BindingMap, D::Error>
65where
66    D: serde::Deserializer<'de>,
67{
68    use serde::Deserialize;
69
70    let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
71    let mut map = naga::back::spv::BindingMap::default();
72    for item in vec {
73        map.insert(item.resource_binding, item.bind_target);
74    }
75    Ok(map)
76}
77
78#[derive(Default, serde::Deserialize)]
79#[serde(default)]
80pub struct WgslInParameters {
81    pub parse_doc_comments: bool,
82}
83impl From<&WgslInParameters> for naga::front::wgsl::Options {
84    fn from(value: &WgslInParameters) -> Self {
85        Self {
86            parse_doc_comments: value.parse_doc_comments,
87        }
88    }
89}
90
91#[derive(Default, serde::Deserialize)]
92#[serde(default)]
93pub struct SpirvInParameters {
94    pub adjust_coordinate_space: bool,
95}
96impl From<&SpirvInParameters> for naga::front::spv::Options {
97    fn from(value: &SpirvInParameters) -> Self {
98        Self {
99            adjust_coordinate_space: value.adjust_coordinate_space,
100            ..Default::default()
101        }
102    }
103}
104
105#[derive(serde::Deserialize)]
106#[serde(default)]
107pub struct SpirvOutParameters {
108    pub version: SpvOutVersion,
109    pub capabilities: naga::FastHashSet<spirv::Capability>,
110    pub debug: bool,
111    pub adjust_coordinate_space: bool,
112    pub force_point_size: bool,
113    pub clamp_frag_depth: bool,
114    pub separate_entry_points: bool,
115    #[serde(deserialize_with = "deserialize_binding_map")]
116    pub binding_map: naga::back::spv::BindingMap,
117    pub use_storage_input_output_16: bool,
118}
119impl Default for SpirvOutParameters {
120    fn default() -> Self {
121        Self {
122            version: SpvOutVersion::default(),
123            capabilities: naga::FastHashSet::default(),
124            debug: false,
125            adjust_coordinate_space: false,
126            force_point_size: false,
127            clamp_frag_depth: false,
128            separate_entry_points: false,
129            use_storage_input_output_16: true,
130            binding_map: naga::back::spv::BindingMap::default(),
131        }
132    }
133}
134impl SpirvOutParameters {
135    pub fn to_options<'a>(
136        &'a self,
137        bounds_check_policies: naga::proc::BoundsCheckPolicies,
138        debug_info: Option<naga::back::spv::DebugInfo<'a>>,
139    ) -> naga::back::spv::Options<'a> {
140        use naga::back::spv;
141        let mut flags = spv::WriterFlags::LABEL_VARYINGS;
142        flags.set(spv::WriterFlags::DEBUG, self.debug);
143        flags.set(
144            spv::WriterFlags::ADJUST_COORDINATE_SPACE,
145            self.adjust_coordinate_space,
146        );
147        flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
148        flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
149        naga::back::spv::Options {
150            lang_version: (self.version.0, self.version.1),
151            flags,
152            capabilities: if self.capabilities.is_empty() {
153                None
154            } else {
155                Some(self.capabilities.clone())
156            },
157            bounds_check_policies,
158            fake_missing_bindings: true,
159            binding_map: self.binding_map.clone(),
160            zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
161            force_loop_bounding: true,
162            debug_info,
163            use_storage_input_output_16: self.use_storage_input_output_16,
164        }
165    }
166}
167
168#[derive(Default, serde::Deserialize)]
169#[serde(default)]
170pub struct WgslOutParameters {
171    pub explicit_types: bool,
172}
173impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
174    fn from(value: &WgslOutParameters) -> Self {
175        let mut flags = Self::empty();
176        flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
177        flags
178    }
179}
180
181#[derive(Default, serde::Deserialize)]
182pub struct FragmentModule {
183    pub path: String,
184    pub entry_point: String,
185}
186
187#[derive(Default, serde::Deserialize)]
188#[serde(default)]
189pub struct Parameters {
190    // -- GOD MODE --
191    pub god_mode: bool,
192
193    // -- wgsl-in options --
194    #[serde(rename = "wgsl-in")]
195    pub wgsl_in: WgslInParameters,
196
197    // -- spirv-in options --
198    #[serde(rename = "spv-in")]
199    pub spv_in: SpirvInParameters,
200
201    // -- SPIR-V options --
202    pub spv: SpirvOutParameters,
203
204    /// Defaults to [`Targets::non_wgsl_default()`] for `spv` and `glsl` snapshots,
205    /// and [`Targets::wgsl_default()`] for `wgsl` snapshots.
206    pub targets: Option<Targets>,
207
208    // -- MSL options --
209    pub msl: naga::back::msl::Options,
210    #[serde(default)]
211    pub msl_pipeline: naga::back::msl::PipelineOptions,
212
213    // -- GLSL options --
214    pub glsl: naga::back::glsl::Options,
215    pub glsl_exclude_list: naga::FastHashSet<String>,
216    pub glsl_multiview: Option<core::num::NonZeroU32>,
217
218    // -- HLSL options --
219    pub hlsl: naga::back::hlsl::Options,
220
221    // -- WGSL options --
222    pub wgsl: WgslOutParameters,
223
224    // -- General options --
225
226    // Allow backends to be aware of the fragment module.
227    // Is the name of a WGSL file in the same directory as the test file.
228    pub fragment_module: Option<FragmentModule>,
229
230    pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
231    pub pipeline_constants: naga::back::PipelineConstants,
232}
233
/// Information about a shader input file.
#[derive(Debug)]
pub struct Input {
    /// The subdirectory of the input directory (`dir_in`) that holds this
    /// file.
    ///
    /// It is also used as the `{subdirectory}-` prefix of output file names
    /// (see `Input::output_path`).
    pub subdirectory: PathBuf,

    /// The input file's name, including its extension but without any
    /// directory components.
    pub file_name: PathBuf,

    /// True if output filenames should add the output extension on top of
    /// `file_name`'s existing extension, rather than replacing it.
    ///
    /// This is used by `convert_snapshots_glsl`, which wants to take input files
    /// like `210-bevy-2d-shader.frag` and just add `.wgsl` to it, producing
    /// `210-bevy-2d-shader.frag.wgsl`. This is before the `{subdirectory}-`
    /// prefix that `output_path` adds.
    pub keep_input_extension: bool,
}
254
255impl Input {
256    /// Read an input file and its corresponding parameters file.
257    ///
258    /// Given `input`, the relative path of a shader input file, return
259    /// a `Source` value containing its path, code, and parameters.
260    ///
261    /// The `input` path is interpreted relative to the `BASE_DIR_IN`
262    /// subdirectory of the directory given by the `CARGO_MANIFEST_DIR`
263    /// environment variable.
264    pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
265        Input {
266            subdirectory: PathBuf::from(subdirectory),
267            // Don't wipe out any extensions on `name`, as
268            // `with_extension` would do.
269            file_name: PathBuf::from(format!("{name}.{extension}")),
270            keep_input_extension: false,
271        }
272    }
273
274    /// Return an iterator that produces an `Input` for each entry in `subdirectory`.
275    pub fn files_in_dir<'a>(
276        subdirectory: &'a str,
277        file_extensions: &'a [&'a str],
278        dir_in: &str,
279    ) -> impl Iterator<Item = Input> + 'a {
280        let input_directory = Path::new(dir_in).join(subdirectory);
281
282        let entries = match std::fs::read_dir(&input_directory) {
283            Ok(entries) => entries,
284            Err(err) => panic!(
285                "Error opening directory '{}': {}",
286                input_directory.display(),
287                err
288            ),
289        };
290
291        entries.filter_map(move |result| {
292            let entry = result.expect("error reading directory");
293            if !entry.file_type().unwrap().is_file() {
294                return None;
295            }
296
297            let file_name = PathBuf::from(entry.file_name());
298            let extension = file_name
299                .extension()
300                .expect("all files in snapshot input directory should have extensions");
301
302            if !file_extensions.contains(&extension.to_str().unwrap()) {
303                return None;
304            }
305
306            if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
307                if !file_name.to_string_lossy().contains(&pat) {
308                    return None;
309                }
310            }
311
312            let input = Input::new(
313                subdirectory,
314                file_name.file_stem().unwrap().to_str().unwrap(),
315                extension.to_str().unwrap(),
316            );
317            Some(input)
318        })
319    }
320
321    /// Return the path to the input directory.
322    pub fn input_directory(&self, dir_in: &str) -> PathBuf {
323        Path::new(dir_in).join(&self.subdirectory)
324    }
325
326    /// Return the path to the output directory.
327    pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
328        Path::new(dir_out).join(subdirectory)
329    }
330
331    /// Return the path to the input file.
332    pub fn input_path(&self, dir_in: &str) -> PathBuf {
333        let mut input = self.input_directory(dir_in);
334        input.push(&self.file_name);
335        input
336    }
337
338    pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
339        let mut output = Self::output_directory(subdirectory, dir_out);
340        if self.keep_input_extension {
341            let file_name = format!(
342                "{}-{}.{}",
343                self.subdirectory.display(),
344                self.file_name.display(),
345                extension
346            );
347
348            output.push(&file_name);
349        } else {
350            let file_name = format!(
351                "{}-{}",
352                self.subdirectory.display(),
353                self.file_name.display()
354            );
355
356            output.push(&file_name);
357            output.set_extension(extension);
358        }
359        output
360    }
361
362    /// Return the contents of the input file as a string.
363    pub fn read_source(&self, dir_in: &str, print: bool) -> String {
364        if print {
365            println!("Processing '{}'", self.file_name.display());
366        }
367        let input_path = self.input_path(dir_in);
368        match fs::read_to_string(&input_path) {
369            Ok(source) => source,
370            Err(err) => {
371                panic!(
372                    "Couldn't read shader input file `{}`: {}",
373                    input_path.display(),
374                    err
375                );
376            }
377        }
378    }
379
380    /// Return the contents of the input file as a vector of bytes.
381    pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
382        if print {
383            println!("Processing '{}'", self.file_name.display());
384        }
385        let input_path = self.input_path(dir_in);
386        match fs::read(&input_path) {
387            Ok(bytes) => bytes,
388            Err(err) => {
389                panic!(
390                    "Couldn't read shader input file `{}`: {}",
391                    input_path.display(),
392                    err
393                );
394            }
395        }
396    }
397
398    pub fn bytes(&self, dir_in: &str) -> u64 {
399        let input_path = self.input_path(dir_in);
400        std::fs::metadata(input_path).unwrap().len()
401    }
402
403    /// Return this input's parameter file, parsed.
404    pub fn read_parameters(&self, dir_in: &str) -> Parameters {
405        let mut param_path = self.input_path(dir_in);
406        param_path.set_extension("toml");
407        let mut params = match fs::read_to_string(&param_path) {
408            Ok(string) => match toml::de::from_str(&string) {
409                Ok(params) => params,
410                Err(e) => panic!(
411                    "Couldn't parse param file: {} due to: {e}",
412                    param_path.display()
413                ),
414            },
415            Err(_) => Parameters::default(),
416        };
417
418        if params.targets.is_none() {
419            match self
420                .input_path(dir_in)
421                .extension()
422                .unwrap()
423                .to_str()
424                .unwrap()
425            {
426                "wgsl" => params.targets = Some(Targets::wgsl_default()),
427                "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
428                "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
429                e => {
430                    panic!("Unknown extension: {e}");
431                }
432            }
433        }
434
435        params
436    }
437
438    /// Write `data` to a file corresponding to this input file in
439    /// `subdirectory`, with `extension`.
440    pub fn write_output_file(
441        &self,
442        subdirectory: &str,
443        extension: &str,
444        data: impl AsRef<[u8]>,
445        dir_out: &str,
446    ) {
447        let output_path = self.output_path(subdirectory, extension, dir_out);
448        fs::create_dir_all(output_path.parent().unwrap()).unwrap();
449        if let Err(err) = fs::write(&output_path, data) {
450            panic!("Error writing {}: {}", output_path.display(), err);
451        }
452    }
453}