1#![allow(dead_code, unused_imports)]
4
5use core::fmt::Write;
6
7use std::{
8 fs,
9 path::{Path, PathBuf},
10};
11
12use naga::compact::KeepUnused;
13use ron::de;
14
15bitflags::bitflags! {
16 #[derive(Clone, Copy, serde::Deserialize)]
17 #[serde(transparent)]
18 #[derive(Debug, Eq, PartialEq)]
19 pub struct Targets: u32 {
20 const IR = 1;
22
23 const ANALYSIS = 1 << 1;
25
26 const SPIRV = 1 << 2;
27 const METAL = 1 << 3;
28 const GLSL = 1 << 4;
29 const DOT = 1 << 5;
30 const HLSL = 1 << 6;
31 const WGSL = 1 << 7;
32 const NO_VALIDATION = 1 << 8;
33 }
34}
35
36impl Targets {
37 pub fn non_wgsl_default() -> Self {
39 Targets::WGSL
40 }
41
42 pub fn wgsl_default() -> Self {
44 Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
45 }
46}
47
48#[derive(serde::Deserialize)]
49pub struct SpvOutVersion(pub u8, pub u8);
50impl Default for SpvOutVersion {
51 fn default() -> Self {
52 SpvOutVersion(1, 1)
53 }
54}
55
56#[derive(serde::Deserialize)]
57pub struct BindingMapSerialization {
58 pub resource_binding: naga::ResourceBinding,
59 pub bind_target: naga::back::spv::BindingInfo,
60}
61
62pub fn deserialize_binding_map<'de, D>(
63 deserializer: D,
64) -> Result<naga::back::spv::BindingMap, D::Error>
65where
66 D: serde::Deserializer<'de>,
67{
68 use serde::Deserialize;
69
70 let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
71 let mut map = naga::back::spv::BindingMap::default();
72 for item in vec {
73 map.insert(item.resource_binding, item.bind_target);
74 }
75 Ok(map)
76}
77
78#[derive(Default, serde::Deserialize)]
79#[serde(default)]
80pub struct WgslInParameters {
81 pub parse_doc_comments: bool,
82}
83impl From<&WgslInParameters> for naga::front::wgsl::Options {
84 fn from(value: &WgslInParameters) -> Self {
85 Self {
86 parse_doc_comments: value.parse_doc_comments,
87 }
88 }
89}
90
91#[derive(Default, serde::Deserialize)]
92#[serde(default)]
93pub struct SpirvInParameters {
94 pub adjust_coordinate_space: bool,
95}
96impl From<&SpirvInParameters> for naga::front::spv::Options {
97 fn from(value: &SpirvInParameters) -> Self {
98 Self {
99 adjust_coordinate_space: value.adjust_coordinate_space,
100 ..Default::default()
101 }
102 }
103}
104
/// SPIR-V back-end parameters, read from the `spv` table of a test's
/// parameter file. Fields missing from the file take their values from this
/// type's `Default` impl (`#[serde(default)]`).
#[derive(serde::Deserialize)]
#[serde(default)]
pub struct SpirvOutParameters {
    /// SPIR-V language version; becomes `Options::lang_version` in `to_options`.
    pub version: SpvOutVersion,
    /// Capabilities handed to the writer; `to_options` passes an empty set as `None`.
    pub capabilities: naga::FastHashSet<spirv::Capability>,
    /// Controls `WriterFlags::DEBUG` in `to_options`.
    pub debug: bool,
    /// Controls `WriterFlags::ADJUST_COORDINATE_SPACE` in `to_options`.
    pub adjust_coordinate_space: bool,
    /// Controls `WriterFlags::FORCE_POINT_SIZE` in `to_options`.
    pub force_point_size: bool,
    /// Controls `WriterFlags::CLAMP_FRAG_DEPTH` in `to_options`.
    pub clamp_frag_depth: bool,
    // Not read by `to_options`. NOTE(review): presumably consumed by the
    // snapshot driver to emit one module per entry point — confirm.
    pub separate_entry_points: bool,
    /// Binding remapping, deserialized from a list of `BindingMapSerialization`
    /// entries; cloned into the writer options by `to_options`.
    #[serde(deserialize_with = "deserialize_binding_map")]
    pub binding_map: naga::back::spv::BindingMap,
    /// Forwarded unchanged to `Options::use_storage_input_output_16`.
    pub use_storage_input_output_16: bool,
}
119impl Default for SpirvOutParameters {
120 fn default() -> Self {
121 Self {
122 version: SpvOutVersion::default(),
123 capabilities: naga::FastHashSet::default(),
124 debug: false,
125 adjust_coordinate_space: false,
126 force_point_size: false,
127 clamp_frag_depth: false,
128 separate_entry_points: false,
129 use_storage_input_output_16: true,
130 binding_map: naga::back::spv::BindingMap::default(),
131 }
132 }
133}
134impl SpirvOutParameters {
135 pub fn to_options<'a>(
136 &'a self,
137 bounds_check_policies: naga::proc::BoundsCheckPolicies,
138 debug_info: Option<naga::back::spv::DebugInfo<'a>>,
139 ) -> naga::back::spv::Options<'a> {
140 use naga::back::spv;
141 let mut flags = spv::WriterFlags::LABEL_VARYINGS;
142 flags.set(spv::WriterFlags::DEBUG, self.debug);
143 flags.set(
144 spv::WriterFlags::ADJUST_COORDINATE_SPACE,
145 self.adjust_coordinate_space,
146 );
147 flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
148 flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
149 naga::back::spv::Options {
150 lang_version: (self.version.0, self.version.1),
151 flags,
152 capabilities: if self.capabilities.is_empty() {
153 None
154 } else {
155 Some(self.capabilities.clone())
156 },
157 bounds_check_policies,
158 fake_missing_bindings: true,
159 binding_map: self.binding_map.clone(),
160 zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
161 force_loop_bounding: true,
162 debug_info,
163 use_storage_input_output_16: self.use_storage_input_output_16,
164 }
165 }
166}
167
168#[derive(Default, serde::Deserialize)]
169#[serde(default)]
170pub struct WgslOutParameters {
171 pub explicit_types: bool,
172}
173impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
174 fn from(value: &WgslOutParameters) -> Self {
175 let mut flags = Self::empty();
176 flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
177 flags
178 }
179}
180
/// A separate shader module referenced by `Parameters::fragment_module`.
///
/// Unlike the other parameter tables, this struct has no `#[serde(default)]`.
/// When the table is present, both fields must be given.
// NOTE(review): presumably supplies a fragment stage to pair with the main
// input — confirm against the snapshot driver.
#[derive(Default, serde::Deserialize)]
pub struct FragmentModule {
    // NOTE(review): the directory this path is relative to is not visible here — confirm.
    pub path: String,
    /// Name of the entry point to use within that module.
    pub entry_point: String,
}
186
/// Per-test configuration, deserialized from the optional TOML file next to a
/// shader input (same stem, `.toml` extension; see `Input::read_parameters`).
/// Every field is optional: missing fields take their `Default` values via the
/// struct-level `#[serde(default)]`.
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct Parameters {
    // NOTE(review): the effect of `god_mode` is not visible in this file;
    // presumably it relaxes validation — confirm with the snapshot driver.
    pub god_mode: bool,

    /// WGSL front-end options (TOML key `wgsl-in`).
    #[serde(rename = "wgsl-in")]
    pub wgsl_in: WgslInParameters,

    /// SPIR-V front-end options (TOML key `spv-in`).
    #[serde(rename = "spv-in")]
    pub spv_in: SpirvInParameters,

    /// SPIR-V back-end options.
    pub spv: SpirvOutParameters,

    /// Outputs to generate. When omitted, `Input::read_parameters` fills in a
    /// default chosen from the input file's extension.
    pub targets: Option<Targets>,

    /// Metal back-end options.
    pub msl: naga::back::msl::Options,
    // This attribute is redundant with the struct-level `#[serde(default)]`.
    #[serde(default)]
    pub msl_pipeline: naga::back::msl::PipelineOptions,

    /// GLSL back-end options.
    pub glsl: naga::back::glsl::Options,
    // NOTE(review): presumably names excluded from GLSL output — confirm.
    pub glsl_exclude_list: naga::FastHashSet<String>,
    // NOTE(review): presumably the multiview view count for GLSL — confirm.
    pub glsl_multiview: Option<core::num::NonZeroU32>,

    /// HLSL back-end options.
    pub hlsl: naga::back::hlsl::Options,

    /// WGSL back-end options.
    pub wgsl: WgslOutParameters,

    /// Optional separate module; see `FragmentModule`.
    pub fragment_module: Option<FragmentModule>,

    /// Bounds-check policies used when generating output.
    pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
    /// Values for pipeline-overridable constants.
    pub pipeline_constants: naga::back::PipelineConstants,
}
233
/// One shader input file of the snapshot test suite.
#[derive(Debug)]
pub struct Input {
    /// Subdirectory of the input root that contains the file. It is also used
    /// as the prefix of output file names (`<subdirectory>-<file_name>`).
    pub subdirectory: PathBuf,

    /// File name including its extension, e.g. `foo.wgsl`.
    pub file_name: PathBuf,

    /// When `true`, output names keep the input extension and append the output
    /// extension (`sub-foo.wgsl.spvasm`). When `false`, the input extension is
    /// replaced (`sub-foo.spvasm`).
    pub keep_input_extension: bool,
}
254
255impl Input {
256 pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
265 Input {
266 subdirectory: PathBuf::from(subdirectory),
267 file_name: PathBuf::from(format!("{name}.{extension}")),
270 keep_input_extension: false,
271 }
272 }
273
274 pub fn files_in_dir<'a>(
276 subdirectory: &'a str,
277 file_extensions: &'a [&'a str],
278 dir_in: &str,
279 ) -> impl Iterator<Item = Input> + 'a {
280 let input_directory = Path::new(dir_in).join(subdirectory);
281
282 let entries = match std::fs::read_dir(&input_directory) {
283 Ok(entries) => entries,
284 Err(err) => panic!(
285 "Error opening directory '{}': {}",
286 input_directory.display(),
287 err
288 ),
289 };
290
291 entries.filter_map(move |result| {
292 let entry = result.expect("error reading directory");
293 if !entry.file_type().unwrap().is_file() {
294 return None;
295 }
296
297 let file_name = PathBuf::from(entry.file_name());
298 let extension = file_name
299 .extension()
300 .expect("all files in snapshot input directory should have extensions");
301
302 if !file_extensions.contains(&extension.to_str().unwrap()) {
303 return None;
304 }
305
306 if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
307 if !file_name.to_string_lossy().contains(&pat) {
308 return None;
309 }
310 }
311
312 let input = Input::new(
313 subdirectory,
314 file_name.file_stem().unwrap().to_str().unwrap(),
315 extension.to_str().unwrap(),
316 );
317 Some(input)
318 })
319 }
320
321 pub fn input_directory(&self, dir_in: &str) -> PathBuf {
323 Path::new(dir_in).join(&self.subdirectory)
324 }
325
326 pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
328 Path::new(dir_out).join(subdirectory)
329 }
330
331 pub fn input_path(&self, dir_in: &str) -> PathBuf {
333 let mut input = self.input_directory(dir_in);
334 input.push(&self.file_name);
335 input
336 }
337
338 pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
339 let mut output = Self::output_directory(subdirectory, dir_out);
340 if self.keep_input_extension {
341 let file_name = format!(
342 "{}-{}.{}",
343 self.subdirectory.display(),
344 self.file_name.display(),
345 extension
346 );
347
348 output.push(&file_name);
349 } else {
350 let file_name = format!(
351 "{}-{}",
352 self.subdirectory.display(),
353 self.file_name.display()
354 );
355
356 output.push(&file_name);
357 output.set_extension(extension);
358 }
359 output
360 }
361
362 pub fn read_source(&self, dir_in: &str, print: bool) -> String {
364 if print {
365 println!("Processing '{}'", self.file_name.display());
366 }
367 let input_path = self.input_path(dir_in);
368 match fs::read_to_string(&input_path) {
369 Ok(source) => source,
370 Err(err) => {
371 panic!(
372 "Couldn't read shader input file `{}`: {}",
373 input_path.display(),
374 err
375 );
376 }
377 }
378 }
379
380 pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
382 if print {
383 println!("Processing '{}'", self.file_name.display());
384 }
385 let input_path = self.input_path(dir_in);
386 match fs::read(&input_path) {
387 Ok(bytes) => bytes,
388 Err(err) => {
389 panic!(
390 "Couldn't read shader input file `{}`: {}",
391 input_path.display(),
392 err
393 );
394 }
395 }
396 }
397
398 pub fn bytes(&self, dir_in: &str) -> u64 {
399 let input_path = self.input_path(dir_in);
400 std::fs::metadata(input_path).unwrap().len()
401 }
402
403 pub fn read_parameters(&self, dir_in: &str) -> Parameters {
405 let mut param_path = self.input_path(dir_in);
406 param_path.set_extension("toml");
407 let mut params = match fs::read_to_string(¶m_path) {
408 Ok(string) => match toml::de::from_str(&string) {
409 Ok(params) => params,
410 Err(e) => panic!(
411 "Couldn't parse param file: {} due to: {e}",
412 param_path.display()
413 ),
414 },
415 Err(_) => Parameters::default(),
416 };
417
418 if params.targets.is_none() {
419 match self
420 .input_path(dir_in)
421 .extension()
422 .unwrap()
423 .to_str()
424 .unwrap()
425 {
426 "wgsl" => params.targets = Some(Targets::wgsl_default()),
427 "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
428 "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
429 e => {
430 panic!("Unknown extension: {e}");
431 }
432 }
433 }
434
435 params
436 }
437
438 pub fn write_output_file(
441 &self,
442 subdirectory: &str,
443 extension: &str,
444 data: impl AsRef<[u8]>,
445 dir_out: &str,
446 ) {
447 let output_path = self.output_path(subdirectory, extension, dir_out);
448 fs::create_dir_all(output_path.parent().unwrap()).unwrap();
449 if let Err(err) = fs::write(&output_path, data) {
450 panic!("Error writing {}: {}", output_path.display(), err);
451 }
452 }
453}