1#![allow(
2 dead_code,
3 unused_imports,
4 reason = "A lot of the code can be unused based on configuration flags; \
5 the corresponding warnings aren't helpful."
6)]
7
8use core::fmt::Write;
9
10use std::{
11 fs,
12 path::{Path, PathBuf},
13};
14
15use naga::compact::KeepUnused;
16use ron::de;
17
18bitflags::bitflags! {
19 #[derive(Clone, Copy, serde::Deserialize)]
20 #[serde(transparent)]
21 #[derive(Debug, Eq, PartialEq)]
22 pub struct Targets: u32 {
23 const IR = 1;
25
26 const ANALYSIS = 1 << 1;
28
29 const SPIRV = 1 << 2;
30 const METAL = 1 << 3;
31 const GLSL = 1 << 4;
32 const DOT = 1 << 5;
33 const HLSL = 1 << 6;
34 const WGSL = 1 << 7;
35 const NO_VALIDATION = 1 << 8;
36 }
37}
38
39impl Targets {
40 pub fn non_wgsl_default() -> Self {
42 Targets::WGSL
43 }
44
45 pub fn wgsl_default() -> Self {
47 Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL
48 }
49}
50
51#[derive(serde::Deserialize)]
52pub struct SpvOutVersion(pub u8, pub u8);
53impl Default for SpvOutVersion {
54 fn default() -> Self {
55 SpvOutVersion(1, 1)
56 }
57}
58
/// One serialized entry of a SPIR-V binding map: the resource binding used as
/// the key and the SPIR-V binding info it is mapped to.
///
/// A list of these is collected into a `naga::back::spv::BindingMap` by
/// `deserialize_binding_map`.
#[derive(serde::Deserialize)]
pub struct BindingMapSerialization {
    /// Key inserted into the binding map.
    pub resource_binding: naga::ResourceBinding,
    /// Value inserted into the binding map for `resource_binding`.
    pub bind_target: naga::back::spv::BindingInfo,
}
64
65pub fn deserialize_binding_map<'de, D>(
66 deserializer: D,
67) -> Result<naga::back::spv::BindingMap, D::Error>
68where
69 D: serde::Deserializer<'de>,
70{
71 use serde::Deserialize;
72
73 let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
74 let mut map = naga::back::spv::BindingMap::default();
75 for item in vec {
76 map.insert(item.resource_binding, item.bind_target);
77 }
78 Ok(map)
79}
80
/// Writer settings that are not tied to one back end.
///
/// `SpirvOutParameters::to_options` reads all three fields.
// NOTE(review): presumably built from the same-named `Parameters` fields —
// confirm against the code that constructs it.
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct WriterSharedOptions {
    /// Forwarded to the SPIR-V writer as `mesh_shader_primitive_indices_clamp`.
    pub mesh_output_validation: bool,
    /// Forwarded to the SPIR-V writer as `task_dispatch_limits`.
    pub task_limits: Option<naga::back::TaskDispatchLimits>,
    /// Forwarded to the SPIR-V writer as `bounds_check_policies`.
    pub bounds_checks_policies: naga::proc::BoundsCheckPolicies,
}
88
/// WGSL front-end settings for a test (the `wgsl-in` parameter key).
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct WgslInParameters {
    /// Forwarded to `naga::front::wgsl::Options::parse_doc_comments`.
    pub parse_doc_comments: bool,
}
94impl From<&WgslInParameters> for naga::front::wgsl::Options {
95 fn from(value: &WgslInParameters) -> Self {
96 Self {
97 parse_doc_comments: value.parse_doc_comments,
98 capabilities: naga::valid::Capabilities::all(),
99 }
100 }
101}
102
/// SPIR-V front-end settings for a test (the `spv-in` parameter key).
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct SpirvInParameters {
    /// Forwarded to `naga::front::spv::Options::adjust_coordinate_space`.
    pub adjust_coordinate_space: bool,
}
108impl From<&SpirvInParameters> for naga::front::spv::Options {
109 fn from(value: &SpirvInParameters) -> Self {
110 Self {
111 adjust_coordinate_space: value.adjust_coordinate_space,
112 ..Default::default()
113 }
114 }
115}
116
/// SPIR-V back-end settings for a test (the `spv` parameter key).
///
/// Missing keys fall back to `SpirvOutParameters::default()`.
#[derive(serde::Deserialize)]
#[serde(default)]
pub struct SpirvOutParameters {
    /// SPIR-V version passed to the writer as `lang_version`.
    pub version: SpvOutVersion,
    /// Capabilities handed to the writer; an empty set is passed as `None`.
    pub capabilities: naga::FastHashSet<spirv::Capability>,
    /// Sets `WriterFlags::DEBUG`.
    pub debug: bool,
    /// Sets `WriterFlags::ADJUST_COORDINATE_SPACE`.
    pub adjust_coordinate_space: bool,
    /// Sets `WriterFlags::FORCE_POINT_SIZE`.
    pub force_point_size: bool,
    /// Sets `WriterFlags::CLAMP_FRAG_DEPTH`.
    pub clamp_frag_depth: bool,
    // NOTE(review): not read by `to_options`; presumably used by the output
    // driver to write one module per entry point — confirm.
    pub separate_entry_points: bool,
    /// Resource binding overrides, deserialized from a list of
    /// `BindingMapSerialization` entries.
    #[serde(deserialize_with = "deserialize_binding_map")]
    pub binding_map: naga::back::spv::BindingMap,
    /// Ray query initialization tracking setting; defaults to `true`.
    pub ray_query_initialization_tracking: bool,
    /// Forwarded to the writer; defaults to `true`.
    pub use_storage_input_output_16: bool,
    /// Forwarded to the writer; defaults to `true`.
    pub emit_int_div_checks: bool,
}
133impl Default for SpirvOutParameters {
134 fn default() -> Self {
135 Self {
136 version: SpvOutVersion::default(),
137 capabilities: naga::FastHashSet::default(),
138 debug: false,
139 adjust_coordinate_space: false,
140 force_point_size: false,
141 clamp_frag_depth: false,
142 separate_entry_points: false,
143 ray_query_initialization_tracking: true,
144 use_storage_input_output_16: true,
145 emit_int_div_checks: true,
146 binding_map: naga::back::spv::BindingMap::default(),
147 }
148 }
149}
150impl SpirvOutParameters {
151 pub fn to_options<'a>(
152 &'a self,
153 shared_info: &WriterSharedOptions,
154 debug_info: Option<naga::back::spv::DebugInfo<'a>>,
155 ) -> naga::back::spv::Options<'a> {
156 use naga::back::spv;
157 let mut flags = spv::WriterFlags::LABEL_VARYINGS;
158 flags.set(spv::WriterFlags::DEBUG, self.debug);
159 flags.set(
160 spv::WriterFlags::ADJUST_COORDINATE_SPACE,
161 self.adjust_coordinate_space,
162 );
163 flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size);
164 flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth);
165 naga::back::spv::Options {
166 lang_version: (self.version.0, self.version.1),
167 flags,
168 capabilities: if self.capabilities.is_empty() {
169 None
170 } else {
171 Some(self.capabilities.clone())
172 },
173 bounds_check_policies: shared_info.bounds_checks_policies,
174 fake_missing_bindings: true,
175 binding_map: self.binding_map.clone(),
176 zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
177 force_loop_bounding: true,
178 ray_query_initialization_tracking: true,
179 debug_info,
180 use_storage_input_output_16: self.use_storage_input_output_16,
181 task_dispatch_limits: shared_info.task_limits,
182 mesh_shader_primitive_indices_clamp: shared_info.mesh_output_validation,
183 trace_ray_argument_validation: true,
184 emit_int_div_checks: self.emit_int_div_checks,
185 }
186 }
187}
188
/// WGSL back-end settings for a test (the `wgsl` parameter key).
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct WgslOutParameters {
    /// Sets `naga::back::wgsl::WriterFlags::EXPLICIT_TYPES`.
    pub explicit_types: bool,
}
194impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags {
195 fn from(value: &WgslOutParameters) -> Self {
196 let mut flags = Self::empty();
197 flags.set(Self::EXPLICIT_TYPES, value.explicit_types);
198 flags
199 }
200}
201
/// A module referenced by path, plus the name of an entry point in it.
///
/// Unlike the other parameter structs there is no `#[serde(default)]`, so
/// both keys must be present when this table is given.
// NOTE(review): presumably a separate fragment-stage module paired with the
// main input (see `Parameters::fragment_module`) — confirm with its consumer.
#[derive(Default, serde::Deserialize)]
pub struct FragmentModule {
    /// Path of the module.
    pub path: String,
    /// Name of the entry point to use from that module.
    pub entry_point: String,
}
207
/// Per-test configuration, deserialized from the optional `.toml` file that
/// sits next to a snapshot input (see `Input::read_parameters`).
///
/// Because of `#[serde(default)]`, any key may be omitted from the file.
#[derive(Default, serde::Deserialize)]
#[serde(default)]
pub struct Parameters {
    /// Validator capabilities; `None` when the file does not set them.
    pub capabilities: Option<naga::valid::Capabilities>,

    /// WGSL front-end settings, read from the `wgsl-in` key.
    #[serde(rename = "wgsl-in")]
    pub wgsl_in: WgslInParameters,

    /// SPIR-V front-end settings, read from the `spv-in` key.
    #[serde(rename = "spv-in")]
    pub spv_in: SpirvInParameters,

    /// SPIR-V back-end settings.
    pub spv: SpirvOutParameters,

    /// Outputs to generate. If still `None` after loading,
    /// `Input::read_parameters` fills in a default chosen by the input
    /// file's extension.
    pub targets: Option<Targets>,

    /// Metal back-end options.
    pub msl: naga::back::msl::Options,
    /// Metal pipeline options.
    #[serde(default)]
    pub msl_pipeline: naga::back::msl::PipelineOptions,

    /// GLSL back-end options.
    pub glsl: naga::back::glsl::Options,
    // NOTE(review): presumably names (entry points?) to skip when writing
    // GLSL — confirm against the GLSL output driver.
    pub glsl_exclude_list: naga::FastHashSet<String>,
    /// GLSL multiview setting; `None` when not set.
    pub glsl_multiview: Option<core::num::NonZeroU32>,

    /// HLSL back-end options.
    pub hlsl: naga::back::hlsl::Options,

    /// WGSL back-end settings.
    pub wgsl: WgslOutParameters,

    /// Optional secondary module (path + entry point) for this test.
    pub fragment_module: Option<FragmentModule>,

    /// Bounds-check policies.
    pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
    /// Pipeline constant values.
    pub pipeline_constants: naga::back::PipelineConstants,

    /// Mesh shader output validation toggle.
    pub mesh_output_validation: bool,
    /// Task dispatch limits. A parameter file that omits this key gets
    /// `default_task_limits()`.
    // NOTE(review): `Parameters::default()` (derived, and used when there is
    // no parameter file at all) yields `None` here instead — confirm the
    // asymmetry is intended.
    #[serde(default = "default_task_limits")]
    pub task_limits: Option<naga::back::TaskDispatchLimits>,
}
260
261fn default_task_limits() -> Option<naga::back::TaskDispatchLimits> {
262 Some(naga::back::TaskDispatchLimits {
263 max_mesh_workgroups_per_dim: 256,
264 max_mesh_workgroups_total: 1024,
265 })
266}
267
/// A single snapshot test input file.
#[derive(Debug)]
pub struct Input {
    /// Directory of the input, relative to the input root (`dir_in`). Also
    /// used as the prefix of output file names.
    pub subdirectory: PathBuf,

    /// File name of the input, including its extension.
    pub file_name: PathBuf,

    /// When `true`, output names keep the input's extension and the output
    /// extension is appended. Otherwise the output extension replaces it.
    pub keep_input_extension: bool,
}
288
289impl Input {
290 pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input {
299 Input {
300 subdirectory: PathBuf::from(subdirectory),
301 file_name: PathBuf::from(format!("{name}.{extension}")),
304 keep_input_extension: false,
305 }
306 }
307
308 pub fn files_in_dir<'a>(
310 subdirectory: &'a str,
311 file_extensions: &'a [&'a str],
312 dir_in: &str,
313 ) -> impl Iterator<Item = Input> + 'a {
314 let input_directory = Path::new(dir_in).join(subdirectory);
315
316 let entries = match std::fs::read_dir(&input_directory) {
317 Ok(entries) => entries,
318 Err(err) => panic!(
319 "Error opening directory '{}': {}",
320 input_directory.display(),
321 err
322 ),
323 };
324
325 entries.filter_map(move |result| {
326 let entry = result.expect("error reading directory");
327 if !entry.file_type().unwrap().is_file() {
328 return None;
329 }
330
331 let file_name = PathBuf::from(entry.file_name());
332 let extension = file_name
333 .extension()
334 .expect("all files in snapshot input directory should have extensions");
335
336 if !file_extensions.contains(&extension.to_str().unwrap()) {
337 return None;
338 }
339
340 if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") {
341 if !file_name.to_string_lossy().contains(&pat) {
342 return None;
343 }
344 }
345
346 let input = Input::new(
347 subdirectory,
348 file_name.file_stem().unwrap().to_str().unwrap(),
349 extension.to_str().unwrap(),
350 );
351 Some(input)
352 })
353 }
354
355 pub fn input_directory(&self, dir_in: &str) -> PathBuf {
357 Path::new(dir_in).join(&self.subdirectory)
358 }
359
360 pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf {
362 Path::new(dir_out).join(subdirectory)
363 }
364
365 pub fn input_path(&self, dir_in: &str) -> PathBuf {
367 let mut input = self.input_directory(dir_in);
368 input.push(&self.file_name);
369 input
370 }
371
372 pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf {
373 let mut output = Self::output_directory(subdirectory, dir_out);
374 if self.keep_input_extension {
375 let file_name = format!(
376 "{}-{}.{}",
377 self.subdirectory.display(),
378 self.file_name.display(),
379 extension
380 );
381
382 output.push(&file_name);
383 } else {
384 let file_name = format!(
385 "{}-{}",
386 self.subdirectory.display(),
387 self.file_name.display()
388 );
389
390 output.push(&file_name);
391 output.set_extension(extension);
392 }
393 output
394 }
395
396 pub fn read_source(&self, dir_in: &str, print: bool) -> String {
398 if print {
399 println!("Processing '{}'", self.file_name.display());
400 }
401 let input_path = self.input_path(dir_in);
402 match fs::read_to_string(&input_path) {
403 Ok(source) => source,
404 Err(err) => {
405 panic!(
406 "Couldn't read shader input file `{}`: {}",
407 input_path.display(),
408 err
409 );
410 }
411 }
412 }
413
414 pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec<u8> {
416 if print {
417 println!("Processing '{}'", self.file_name.display());
418 }
419 let input_path = self.input_path(dir_in);
420 match fs::read(&input_path) {
421 Ok(bytes) => bytes,
422 Err(err) => {
423 panic!(
424 "Couldn't read shader input file `{}`: {}",
425 input_path.display(),
426 err
427 );
428 }
429 }
430 }
431
432 pub fn bytes(&self, dir_in: &str) -> u64 {
433 let input_path = self.input_path(dir_in);
434 std::fs::metadata(input_path).unwrap().len()
435 }
436
437 pub fn read_parameters(&self, dir_in: &str) -> Parameters {
439 let mut param_path = self.input_path(dir_in);
440 param_path.set_extension("toml");
441 let mut params = match fs::read_to_string(¶m_path) {
442 Ok(string) => match toml::de::from_str(&string) {
443 Ok(params) => params,
444 Err(e) => panic!(
445 "Couldn't parse param file: {} due to: {e}",
446 param_path.display()
447 ),
448 },
449 Err(_) => Parameters::default(),
450 };
451
452 if params.targets.is_none() {
453 match self
454 .input_path(dir_in)
455 .extension()
456 .unwrap()
457 .to_str()
458 .unwrap()
459 {
460 "wgsl" => params.targets = Some(Targets::wgsl_default()),
461 "spvasm" => params.targets = Some(Targets::non_wgsl_default()),
462 "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()),
463 e => {
464 panic!("Unknown extension: {e}");
465 }
466 }
467 }
468
469 params
470 }
471
472 pub fn write_output_file(
475 &self,
476 subdirectory: &str,
477 extension: &str,
478 data: impl AsRef<[u8]>,
479 dir_out: &str,
480 ) {
481 let output_path = self.output_path(subdirectory, extension, dir_out);
482 fs::create_dir_all(output_path.parent().unwrap()).unwrap();
483 if let Err(err) = fs::write(&output_path, data) {
484 panic!("Error writing {}: {}", output_path.display(), err);
485 }
486 }
487}