naga_test/
lib.rs

1#![allow(
2    dead_code,
3    unused_imports,
4    reason = "A lot of the code can be unused based on configuration flags; \
5        the corresponding warnings aren't helpful."
6)]
7
8use core::fmt::Write;
9
10use std::{
11    fs,
12    path::{Path, PathBuf},
13};
14
15use naga::compact::KeepUnused;
16use ron::de;
17
18bitflags::bitflags! {
19    #[derive(Clone, Copy, serde::Deserialize)]
20    #[serde(transparent)]
21    #[derive(Debug, Eq, PartialEq)]
22    pub struct Targets: u32 {
23        /// A serialization of the `naga::Module`, in RON format.
24        const IR = 1;
25
26        /// A serialization of the `naga::valid::ModuleInfo`, in RON format.
27        const ANALYSIS = 1 << 1;
28
29        const SPIRV = 1 << 2;
30        const METAL = 1 << 3;
31        const GLSL = 1 << 4;
32        const DOT = 1 << 5;
33        const HLSL = 1 << 6;
34        const WGSL = 1 << 7;
35        const NO_VALIDATION = 1 << 8;
36    }
37}
38
39impl Targets {
40    /// Defaults for `spv` and `glsl` snapshots.
41    pub fn non_wgsl_default() -> Self {
42        Targets::WGSL
43    }
44
45    /// Defaults for `wgsl` snapshots.
46    pub fn wgsl_default() -> Self {
47        Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
48    }
49}
50
51#[derive(serde::Deserialize)]
52pub struct SpvOutVersion(pub u8, pub u8);
53impl Default for SpvOutVersion {
54    fn default() -> Self {
55        SpvOutVersion(1, 1)
56    }
57}
58
59#[derive(serde::Deserialize)]
60pub struct BindingMapSerialization {
61    pub resource_binding: naga::ResourceBinding,
62    pub bind_target: naga::back::spv::BindingInfo,
63}
64
65pub fn deserialize_binding_map<'de, D>(
66    deserializer: D,
67) -> Result<naga::back::spv::BindingMap, D::Error>
68where
69    D: serde::Deserializer<'de>,
70{
71    use serde::Deserialize;
72
73    let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
74    let mut map = naga::back::spv::BindingMap::default();
75    for item in vec {
76        map.insert(item.resource_binding, item.bind_target);
77    }
78    Ok(map)
79}
80
81#[derive(Default, serde::Deserialize)]
82#[serde(default)]
83pub struct WriterSharedOptions {
84    pub mesh_output_validation: bool,
85    pub task_limits: Option<naga::back::TaskDispatchLimits>,
86    pub bounds_checks_policies: naga::proc::BoundsCheckPolicies,
87}
88
89#[derive(Default, serde::Deserialize)]
90#[serde(default)]
91pub struct WgslInParameters {
92    pub parse_doc_comments: bool,
93}
94impl From<&WgslInParameters> for naga::front::wgsl::Options {
95    fn from(value: &WgslInParameters) -> Self {
96        Self {
97            parse_doc_comments: value.parse_doc_comments,
98            capabilities: naga::valid::Capabilities::all(),
99        }
100    }
101}
102
103#[derive(Default, serde::Deserialize)]
104#[serde(default)]
105pub struct SpirvInParameters {
106    pub adjust_coordinate_space: bool,
107}
108impl From<&SpirvInParameters> for naga::front::spv::Options {
109    fn from(value: &SpirvInParameters) -> Self {
110        Self {
111            adjust_coordinate_space: value.adjust_coordinate_space,
112            ..Default::default()
113        }
114    }
115}
116
117#[derive(serde::Deserialize)]
118#[serde(default)]
119pub struct SpirvOutParameters {
120    pub version: SpvOutVersion,
121    pub capabilities: naga::FastHashSet<spirv::Capability>,
122    pub debug: bool,
123    pub adjust_coordinate_space: bool,
124    pub force_point_size: bool,
125    pub clamp_frag_depth: bool,
126    pub separate_entry_points: bool,
127    #[serde(deserialize_with = "deserialize_binding_map")]
128    pub binding_map: naga::back::spv::BindingMap,
129    pub ray_query_initialization_tracking: bool,
130    pub use_storage_input_output_16: bool,
131    pub emit_int_div_checks: bool,
132}
133impl Default for SpirvOutParameters {
134    fn default() -> Self {
135        Self {
136            version: SpvOutVersion::default(),
137            capabilities: naga::FastHashSet::default(),
138            debug: false,
139            adjust_coordinate_space: false,
140            force_point_size: false,
141            clamp_frag_depth: false,
142            separate_entry_points: false,
143            ray_query_initialization_tracking: true,
144            use_storage_input_output_16: true,
145            emit_int_div_checks: true,
146            binding_map: naga::back::spv::BindingMap::default(),
147        }
148    }
149}
150impl SpirvOutParameters {
151    pub fn to_options<'a>(
152        &'a self,
153        shared_info: &WriterSharedOptions,
154        debug_info: Option<naga::back::spv::DebugInfo<'a>>,
155    ) -> naga::back::spv::Options<'a> {
156        use naga::back::spv;
157        let mut flags = spv::WriterFlags::LABEL_VARYINGS;
158        flags.set(spv::WriterFlags::DEBUG, self.debug);
159        flags.set(
160            spv::WriterFlags::ADJUST_COORDINATE_SPACE,
161            self.adjust_coordinate_space,
162        );
163        flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
164        flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
165        naga::back::spv::Options {
166            lang_version: (self.version.0, self.version.1),
167            flags,
168            capabilities: if self.capabilities.is_empty() {
169                None
170            } else {
171                Some(self.capabilities.clone())
172            },
173            bounds_check_policies: shared_info.bounds_checks_policies,
174            fake_missing_bindings: true,
175            binding_map: self.binding_map.clone(),
176            zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
177            force_loop_bounding: true,
178            ray_query_initialization_tracking: true,
179            debug_info,
180            use_storage_input_output_16: self.use_storage_input_output_16,
181            task_dispatch_limits: shared_info.task_limits,
182            mesh_shader_primitive_indices_clamp: shared_info.mesh_output_validation,
183            trace_ray_argument_validation: true,
184            emit_int_div_checks: self.emit_int_div_checks,
185        }
186    }
187}
188
189#[derive(Default, serde::Deserialize)]
190#[serde(default)]
191pub struct WgslOutParameters {
192    pub explicit_types: bool,
193}
194impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
195    fn from(value: &WgslOutParameters) -> Self {
196        let mut flags = Self::empty();
197        flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
198        flags
199    }
200}
201
202#[derive(Default, serde::Deserialize)]
203pub struct FragmentModule {
204    pub path: String,
205    pub entry_point: String,
206}
207
208#[derive(Default, serde::Deserialize)]
209#[serde(default)]
210pub struct Parameters {
211    // -- validation options --
212    //
213    // Capabilities to enable. Defaults to `Capabilities::default()`.
214    pub capabilities: Option<naga::valid::Capabilities>,
215
216    // -- wgsl-in options --
217    #[serde(rename = "wgsl-in")]
218    pub wgsl_in: WgslInParameters,
219
220    // -- spirv-in options --
221    #[serde(rename = "spv-in")]
222    pub spv_in: SpirvInParameters,
223
224    // -- SPIR-V options --
225    pub spv: SpirvOutParameters,
226
227    /// Defaults to [`Targets::non_wgsl_default()`] for `spv` and `glsl` snapshots,
228    /// and [`Targets::wgsl_default()`] for `wgsl` snapshots.
229    pub targets: Option<Targets>,
230
231    // -- MSL options --
232    pub msl: naga::back::msl::Options,
233    #[serde(default)]
234    pub msl_pipeline: naga::back::msl::PipelineOptions,
235
236    // -- GLSL options --
237    pub glsl: naga::back::glsl::Options,
238    pub glsl_exclude_list: naga::FastHashSet<String>,
239    pub glsl_multiview: Option<core::num::NonZeroU32>,
240
241    // -- HLSL options --
242    pub hlsl: naga::back::hlsl::Options,
243
244    // -- WGSL options --
245    pub wgsl: WgslOutParameters,
246
247    // -- General options --
248
249    // Allow backends to be aware of the fragment module.
250    // Is the name of a WGSL file in the same directory as the test file.
251    pub fragment_module: Option<FragmentModule>,
252
253    pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
254    pub pipeline_constants: naga::back::PipelineConstants,
255
256    pub mesh_output_validation: bool,
257    #[serde(default = "default_task_limits")]
258    pub task_limits: Option<naga::back::TaskDispatchLimits>,
259}
260
261fn default_task_limits() -> Option<naga::back::TaskDispatchLimits> {
262    Some(naga::back::TaskDispatchLimits {
263        max_mesh_workgroups_per_dim: 256,
264        max_mesh_workgroups_total: 1024,
265    })
266}
267
/// Information about a shader input file.
#[derive(Debug)]
pub struct Input {
    /// The subdirectory of the snapshot input directory that contains this
    /// file.
    ///
    /// It is joined onto the `dir_in` path to locate the input. It is also
    /// used as the prefix of every output file name
    /// (`{subdirectory}-{file_name}`).
    pub subdirectory: PathBuf,

    /// The input file name, without a directory.
    pub file_name: PathBuf,

    /// True if output filenames should add the output extension on top of
    /// `file_name`'s existing extension, rather than replacing it.
    ///
    /// This is used by `convert_snapshots_glsl`, which wants to take input files
    /// like `210-bevy-2d-shader.frag` and just add `.wgsl` to it, producing
    /// `210-bevy-2d-shader.frag.wgsl` (with the usual subdirectory prefix).
    pub keep_input_extension: bool,
}
288
289impl Input {
290    /// Read an input file and its corresponding parameters file.
291    ///
292    /// Given `input`, the relative path of a shader input file, return
293    /// a `Source` value containing its path, code, and parameters.
294    ///
295    /// The `input` path is interpreted relative to the `BASE_DIR_IN`
296    /// subdirectory of the directory given by the `CARGO_MANIFEST_DIR`
297    /// environment variable.
298    pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
299        Input {
300            subdirectory: PathBuf::from(subdirectory),
301            // Don't wipe out any extensions on `name`, as
302            // `with_extension` would do.
303            file_name: PathBuf::from(format!("{name}.{extension}")),
304            keep_input_extension: false,
305        }
306    }
307
308    /// Return an iterator that produces an `Input` for each entry in `subdirectory`.
309    pub fn files_in_dir<'a>(
310        subdirectory: &'a str,
311        file_extensions: &'a [&'a str],
312        dir_in: &str,
313    ) -> impl Iterator<Item = Input> + 'a {
314        let input_directory = Path::new(dir_in).join(subdirectory);
315
316        let entries = match std::fs::read_dir(&input_directory) {
317            Ok(entries) => entries,
318            Err(err) => panic!(
319                "Error opening directory '{}': {}",
320                input_directory.display(),
321                err
322            ),
323        };
324
325        entries.filter_map(move |result| {
326            let entry = result.expect("error reading directory");
327            if !entry.file_type().unwrap().is_file() {
328                return None;
329            }
330
331            let file_name = PathBuf::from(entry.file_name());
332            let extension = file_name
333                .extension()
334                .expect("all files in snapshot input directory should have extensions");
335
336            if !file_extensions.contains(&extension.to_str().unwrap()) {
337                return None;
338            }
339
340            if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
341                if !file_name.to_string_lossy().contains(&pat) {
342                    return None;
343                }
344            }
345
346            let input = Input::new(
347                subdirectory,
348                file_name.file_stem().unwrap().to_str().unwrap(),
349                extension.to_str().unwrap(),
350            );
351            Some(input)
352        })
353    }
354
355    /// Return the path to the input directory.
356    pub fn input_directory(&self, dir_in: &str) -> PathBuf {
357        Path::new(dir_in).join(&self.subdirectory)
358    }
359
360    /// Return the path to the output directory.
361    pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
362        Path::new(dir_out).join(subdirectory)
363    }
364
365    /// Return the path to the input file.
366    pub fn input_path(&self, dir_in: &str) -> PathBuf {
367        let mut input = self.input_directory(dir_in);
368        input.push(&self.file_name);
369        input
370    }
371
372    pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
373        let mut output = Self::output_directory(subdirectory, dir_out);
374        if self.keep_input_extension {
375            let file_name = format!(
376                "{}-{}.{}",
377                self.subdirectory.display(),
378                self.file_name.display(),
379                extension
380            );
381
382            output.push(&file_name);
383        } else {
384            let file_name = format!(
385                "{}-{}",
386                self.subdirectory.display(),
387                self.file_name.display()
388            );
389
390            output.push(&file_name);
391            output.set_extension(extension);
392        }
393        output
394    }
395
396    /// Return the contents of the input file as a string.
397    pub fn read_source(&self, dir_in: &str, print: bool) -> String {
398        if print {
399            println!("Processing '{}'", self.file_name.display());
400        }
401        let input_path = self.input_path(dir_in);
402        match fs::read_to_string(&input_path) {
403            Ok(source) => source,
404            Err(err) => {
405                panic!(
406                    "Couldn't read shader input file `{}`: {}",
407                    input_path.display(),
408                    err
409                );
410            }
411        }
412    }
413
414    /// Return the contents of the input file as a vector of bytes.
415    pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
416        if print {
417            println!("Processing '{}'", self.file_name.display());
418        }
419        let input_path = self.input_path(dir_in);
420        match fs::read(&input_path) {
421            Ok(bytes) => bytes,
422            Err(err) => {
423                panic!(
424                    "Couldn't read shader input file `{}`: {}",
425                    input_path.display(),
426                    err
427                );
428            }
429        }
430    }
431
432    pub fn bytes(&self, dir_in: &str) -> u64 {
433        let input_path = self.input_path(dir_in);
434        std::fs::metadata(input_path).unwrap().len()
435    }
436
437    /// Return this input's parameter file, parsed.
438    pub fn read_parameters(&self, dir_in: &str) -> Parameters {
439        let mut param_path = self.input_path(dir_in);
440        param_path.set_extension("toml");
441        let mut params = match fs::read_to_string(&param_path) {
442            Ok(string) => match toml::de::from_str(&string) {
443                Ok(params) => params,
444                Err(e) => panic!(
445                    "Couldn't parse param file: {} due to: {e}",
446                    param_path.display()
447                ),
448            },
449            Err(_) => Parameters::default(),
450        };
451
452        if params.targets.is_none() {
453            match self
454                .input_path(dir_in)
455                .extension()
456                .unwrap()
457                .to_str()
458                .unwrap()
459            {
460                "wgsl" => params.targets = Some(Targets::wgsl_default()),
461                "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
462                "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
463                e => {
464                    panic!("Unknown extension: {e}");
465                }
466            }
467        }
468
469        params
470    }
471
472    /// Write `data` to a file corresponding to this input file in
473    /// `subdirectory`, with `extension`.
474    pub fn write_output_file(
475        &self,
476        subdirectory: &str,
477        extension: &str,
478        data: impl AsRef<[u8]>,
479        dir_out: &str,
480    ) {
481        let output_path = self.output_path(subdirectory, extension, dir_out);
482        fs::create_dir_all(output_path.parent().unwrap()).unwrap();
483        if let Err(err) = fs::write(&output_path, data) {
484            panic!("Error writing {}: {}", output_path.display(), err);
485        }
486    }
487}