1#![allow(dead_code, unused_imports)]
4
5use core::fmt::Write;
6
7use std::{
8 fs,
9 path::{Path, PathBuf},
10};
11
12use naga::compact::KeepUnused;
13use ron::de;
14
15bitflags::bitflags! {
16 #[derive(Clone, Copy, serde::Deserialize)]
17 #[serde(transparent)]
18 #[derive(Debug, Eq, PartialEq)]
19 pub struct Targets: u32 {
20 const IR = 1;
22
23 const ANALYSIS = 1 << 1;
25
26 const SPIRV = 1 << 2;
27 const METAL = 1 << 3;
28 const GLSL = 1 << 4;
29 const DOT = 1 << 5;
30 const HLSL = 1 << 6;
31 const WGSL = 1 << 7;
32 const NO_VALIDATION = 1 << 8;
33 }
34}
35
36impl Targets {
37 pub fn non_wgsl_default() -> Self {
39 Targets::WGSL
40 }
41
42 pub fn wgsl_default() -> Self {
44 Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
45 }
46}
47
48#[derive(serde::Deserialize)]
49pub struct SpvOutVersion(pub u8, pub u8);
50impl Default for SpvOutVersion {
51 fn default() -> Self {
52 SpvOutVersion(1, 1)
53 }
54}
55
56#[derive(serde::Deserialize)]
57pub struct BindingMapSerialization {
58 pub resource_binding: naga::ResourceBinding,
59 pub bind_target: naga::back::spv::BindingInfo,
60}
61
62pub fn deserialize_binding_map<'de, D>(
63 deserializer: D,
64) -> Result<naga::back::spv::BindingMap, D::Error>
65where
66 D: serde::Deserializer<'de>,
67{
68 use serde::Deserialize;
69
70 let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
71 let mut map = naga::back::spv::BindingMap::default();
72 for item in vec {
73 map.insert(item.resource_binding, item.bind_target);
74 }
75 Ok(map)
76}
77
78#[derive(Default, serde::Deserialize)]
79#[serde(default)]
80pub struct WgslInParameters {
81 pub parse_doc_comments: bool,
82}
83impl From<&WgslInParameters> for naga::front::wgsl::Options {
84 fn from(value: &WgslInParameters) -> Self {
85 Self {
86 parse_doc_comments: value.parse_doc_comments,
87 }
88 }
89}
90
91#[derive(Default, serde::Deserialize)]
92#[serde(default)]
93pub struct SpirvInParameters {
94 pub adjust_coordinate_space: bool,
95}
96impl From<&SpirvInParameters> for naga::front::spv::Options {
97 fn from(value: &SpirvInParameters) -> Self {
98 Self {
99 adjust_coordinate_space: value.adjust_coordinate_space,
100 ..Default::default()
101 }
102 }
103}
104
105#[derive(serde::Deserialize)]
106#[serde(default)]
107pub struct SpirvOutParameters {
108 pub version: SpvOutVersion,
109 pub capabilities: naga::FastHashSet<spirv::Capability>,
110 pub debug: bool,
111 pub adjust_coordinate_space: bool,
112 pub force_point_size: bool,
113 pub clamp_frag_depth: bool,
114 pub separate_entry_points: bool,
115 #[serde(deserialize_with = "deserialize_binding_map")]
116 pub binding_map: naga::back::spv::BindingMap,
117 pub use_storage_input_output_16: bool,
118}
119impl Default for SpirvOutParameters {
120 fn default() -> Self {
121 Self {
122 version: SpvOutVersion::default(),
123 capabilities: naga::FastHashSet::default(),
124 debug: false,
125 adjust_coordinate_space: false,
126 force_point_size: false,
127 clamp_frag_depth: false,
128 separate_entry_points: false,
129 use_storage_input_output_16: true,
130 binding_map: naga::back::spv::BindingMap::default(),
131 }
132 }
133}
134impl SpirvOutParameters {
135 pub fn to_options<'a>(
136 &'a self,
137 bounds_check_policies: naga::proc::BoundsCheckPolicies,
138 debug_info: Option<naga::back::spv::DebugInfo<'a>>,
139 ) -> naga::back::spv::Options<'a> {
140 use naga::back::spv;
141 let mut flags = spv::WriterFlags::LABEL_VARYINGS;
142 flags.set(spv::WriterFlags::DEBUG, self.debug);
143 flags.set(
144 spv::WriterFlags::ADJUST_COORDINATE_SPACE,
145 self.adjust_coordinate_space,
146 );
147 flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
148 flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
149 naga::back::spv::Options {
150 lang_version: (self.version.0, self.version.1),
151 flags,
152 capabilities: if self.capabilities.is_empty() {
153 None
154 } else {
155 Some(self.capabilities.clone())
156 },
157 bounds_check_policies,
158 binding_map: self.binding_map.clone(),
159 zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
160 force_loop_bounding: true,
161 debug_info,
162 use_storage_input_output_16: self.use_storage_input_output_16,
163 }
164 }
165}
166
167#[derive(Default, serde::Deserialize)]
168#[serde(default)]
169pub struct WgslOutParameters {
170 pub explicit_types: bool,
171}
172impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
173 fn from(value: &WgslOutParameters) -> Self {
174 let mut flags = Self::empty();
175 flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
176 flags
177 }
178}
179
180#[derive(Default, serde::Deserialize)]
181pub struct FragmentModule {
182 pub path: String,
183 pub entry_point: String,
184}
185
186#[derive(Default, serde::Deserialize)]
187#[serde(default)]
188pub struct Parameters {
189 pub god_mode: bool,
191
192 #[serde(rename = "wgsl-in")]
194 pub wgsl_in: WgslInParameters,
195
196 #[serde(rename = "spv-in")]
198 pub spv_in: SpirvInParameters,
199
200 pub spv: SpirvOutParameters,
202
203 pub targets: Option<Targets>,
206
207 pub msl: naga::back::msl::Options,
209 #[serde(default)]
210 pub msl_pipeline: naga::back::msl::PipelineOptions,
211
212 pub glsl: naga::back::glsl::Options,
214 pub glsl_exclude_list: naga::FastHashSet<String>,
215 pub glsl_multiview: Option<core::num::NonZeroU32>,
216
217 pub hlsl: naga::back::hlsl::Options,
219
220 pub wgsl: WgslOutParameters,
222
223 pub fragment_module: Option<FragmentModule>,
228
229 pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
230 pub pipeline_constants: naga::back::PipelineConstants,
231}
232
/// One snapshot input file, identified by a subdirectory and a file name.
#[derive(Debug)]
pub struct Input {
    /// Directory the file lives in, relative to the input root. It is also used as the
    /// prefix of generated output file names (`<subdirectory>-<file_name>`).
    pub subdirectory: PathBuf,

    /// Name of the input file, extension included.
    pub file_name: PathBuf,

    /// When `true`, output names keep the input's extension and append the output
    /// extension after it. When `false`, the output extension replaces it.
    pub keep_input_extension: bool,
}
253
254impl Input {
255 pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
264 Input {
265 subdirectory: PathBuf::from(subdirectory),
266 file_name: PathBuf::from(format!("{name}.{extension}")),
269 keep_input_extension: false,
270 }
271 }
272
273 pub fn files_in_dir<'a>(
275 subdirectory: &'a str,
276 file_extensions: &'a [&'a str],
277 dir_in: &str,
278 ) -> impl Iterator<Item = Input> + 'a {
279 let input_directory = Path::new(dir_in).join(subdirectory);
280
281 let entries = match std::fs::read_dir(&input_directory) {
282 Ok(entries) => entries,
283 Err(err) => panic!(
284 "Error opening directory '{}': {}",
285 input_directory.display(),
286 err
287 ),
288 };
289
290 entries.filter_map(move |result| {
291 let entry = result.expect("error reading directory");
292 if !entry.file_type().unwrap().is_file() {
293 return None;
294 }
295
296 let file_name = PathBuf::from(entry.file_name());
297 let extension = file_name
298 .extension()
299 .expect("all files in snapshot input directory should have extensions");
300
301 if !file_extensions.contains(&extension.to_str().unwrap()) {
302 return None;
303 }
304
305 if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
306 if !file_name.to_string_lossy().contains(&pat) {
307 return None;
308 }
309 }
310
311 let input = Input::new(
312 subdirectory,
313 file_name.file_stem().unwrap().to_str().unwrap(),
314 extension.to_str().unwrap(),
315 );
316 Some(input)
317 })
318 }
319
320 pub fn input_directory(&self, dir_in: &str) -> PathBuf {
322 Path::new(dir_in).join(&self.subdirectory)
323 }
324
325 pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
327 Path::new(dir_out).join(subdirectory)
328 }
329
330 pub fn input_path(&self, dir_in: &str) -> PathBuf {
332 let mut input = self.input_directory(dir_in);
333 input.push(&self.file_name);
334 input
335 }
336
337 pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
338 let mut output = Self::output_directory(subdirectory, dir_out);
339 if self.keep_input_extension {
340 let file_name = format!(
341 "{}-{}.{}",
342 self.subdirectory.display(),
343 self.file_name.display(),
344 extension
345 );
346
347 output.push(&file_name);
348 } else {
349 let file_name = format!(
350 "{}-{}",
351 self.subdirectory.display(),
352 self.file_name.display()
353 );
354
355 output.push(&file_name);
356 output.set_extension(extension);
357 }
358 output
359 }
360
361 pub fn read_source(&self, dir_in: &str, print: bool) -> String {
363 if print {
364 println!("Processing '{}'", self.file_name.display());
365 }
366 let input_path = self.input_path(dir_in);
367 match fs::read_to_string(&input_path) {
368 Ok(source) => source,
369 Err(err) => {
370 panic!(
371 "Couldn't read shader input file `{}`: {}",
372 input_path.display(),
373 err
374 );
375 }
376 }
377 }
378
379 pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
381 if print {
382 println!("Processing '{}'", self.file_name.display());
383 }
384 let input_path = self.input_path(dir_in);
385 match fs::read(&input_path) {
386 Ok(bytes) => bytes,
387 Err(err) => {
388 panic!(
389 "Couldn't read shader input file `{}`: {}",
390 input_path.display(),
391 err
392 );
393 }
394 }
395 }
396
397 pub fn bytes(&self, dir_in: &str) -> u64 {
398 let input_path = self.input_path(dir_in);
399 std::fs::metadata(input_path).unwrap().len()
400 }
401
402 pub fn read_parameters(&self, dir_in: &str) -> Parameters {
404 let mut param_path = self.input_path(dir_in);
405 param_path.set_extension("toml");
406 let mut params = match fs::read_to_string(¶m_path) {
407 Ok(string) => match toml::de::from_str(&string) {
408 Ok(params) => params,
409 Err(e) => panic!(
410 "Couldn't parse param file: {} due to: {e}",
411 param_path.display()
412 ),
413 },
414 Err(_) => Parameters::default(),
415 };
416
417 if params.targets.is_none() {
418 match self
419 .input_path(dir_in)
420 .extension()
421 .unwrap()
422 .to_str()
423 .unwrap()
424 {
425 "wgsl" => params.targets = Some(Targets::wgsl_default()),
426 "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
427 "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
428 e => {
429 panic!("Unknown extension: {e}");
430 }
431 }
432 }
433
434 params
435 }
436
437 pub fn write_output_file(
440 &self,
441 subdirectory: &str,
442 extension: &str,
443 data: impl AsRef<[u8]>,
444 dir_out: &str,
445 ) {
446 let output_path = self.output_path(subdirectory, extension, dir_out);
447 fs::create_dir_all(output_path.parent().unwrap()).unwrap();
448 if let Err(err) = fs::write(&output_path, data) {
449 panic!("Error writing {}: {}", output_path.display(), err);
450 }
451 }
452}