1mod conv;
116mod help;
117mod keywords;
118mod ray;
119mod storage;
120mod writer;
121
122use alloc::{string::String, vec::Vec};
123use core::fmt::Error as FmtError;
124
125use thiserror::Error;
126
127use crate::{back, ir, proc};
128
/// Location of a resource in the generated HLSL: a register space and a register index,
/// plus optional per-binding details.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BindTarget {
    /// Register space.
    pub space: u8,
    /// Register index within `space`.
    pub register: u32,
    /// Size of the binding array, if this binding is one.
    ///
    /// NOTE(review): presumably `None` for non-array bindings — confirm in the writer.
    pub binding_array_size: Option<u32>,
    /// Optional `u32` index associated with dynamic storage buffer offsets.
    ///
    /// NOTE(review): presumably a key into `Options::dynamic_storage_buffer_offsets_targets`
    /// (that map is keyed by `u32`) — confirm in the writer.
    pub dynamic_storage_buffer_offsets_index: Option<u32>,
    /// Per-binding request for restricted indexing.
    ///
    /// Falls back to `false` when the field is absent from serialized input.
    /// NOTE(review): how this combines with `Options::restrict_indexing` is decided by the
    /// writer — confirm there.
    #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
    pub restrict_indexing: bool,
}
148
/// Bind target of a dynamic storage buffer offsets buffer: register space, register
/// index and size.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct OffsetsBindTarget {
    /// Register space.
    pub space: u8,
    /// Register index within `space`.
    pub register: u32,
    /// Size of the offsets buffer.
    ///
    /// NOTE(review): the unit (number of offsets vs. bytes) is not visible here — confirm.
    pub size: u32,
}
158
159#[cfg(any(feature = "serialize", feature = "deserialize"))]
160#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
161#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
162struct BindingMapSerialization {
163 resource_binding: crate::ResourceBinding,
164 bind_target: BindTarget,
165}
166
167#[cfg(feature = "deserialize")]
168fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error>
169where
170 D: serde::Deserializer<'de>,
171{
172 use serde::Deserialize;
173
174 let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
175 let mut map = BindingMap::default();
176 for item in vec {
177 map.insert(item.resource_binding, item.bind_target);
178 }
179 Ok(map)
180}
181
182pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
184
/// An HLSL shader model version.
///
/// Variants are declared from oldest to newest, so the derived `PartialOrd`/`Ord`
/// compare versions chronologically (e.g. `ShaderModel::V5_1 < ShaderModel::V6_0`).
/// `Ord` is derived as well so callers can use `max`/`min`/`clamp`, sort versions,
/// or use them as ordered-map keys.
// Variant names mirror HLSL's `major_minor` version notation, hence the allowed lints.
#[allow(non_snake_case, non_camel_case_types)]
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum ShaderModel {
    V5_0,
    V5_1,
    V6_0,
    V6_1,
    V6_2,
    V6_3,
    V6_4,
    V6_5,
    V6_6,
    V6_7,
}

impl ShaderModel {
    /// Returns the version in `major_minor` form, e.g. `"6_0"` for [`ShaderModel::V6_0`].
    ///
    /// NOTE(review): presumably combined with a stage prefix from
    /// [`crate::ShaderStage::to_hlsl_str`] to form a profile like `ps_6_0` — confirm at
    /// the call site.
    pub const fn to_str(self) -> &'static str {
        match self {
            Self::V5_0 => "5_0",
            Self::V5_1 => "5_1",
            Self::V6_0 => "6_0",
            Self::V6_1 => "6_1",
            Self::V6_2 => "6_2",
            Self::V6_3 => "6_3",
            Self::V6_4 => "6_4",
            Self::V6_5 => "6_5",
            Self::V6_6 => "6_6",
            Self::V6_7 => "6_7",
        }
    }
}
219
220impl crate::ShaderStage {
221 pub const fn to_hlsl_str(self) -> &'static str {
222 match self {
223 Self::Vertex => "vs",
224 Self::Fragment => "ps",
225 Self::Compute => "cs",
226 Self::Task | Self::Mesh => unreachable!(),
227 }
228 }
229}
230
231impl crate::ImageDimension {
232 const fn to_hlsl_str(self) -> &'static str {
233 match self {
234 Self::D1 => "1D",
235 Self::D2 => "2D",
236 Self::D3 => "3D",
237 Self::Cube => "Cube",
238 }
239 }
240}
241
/// Key of a sampler index buffer.
///
/// The key consists of the bind group alone, so a map keyed by it holds at most one index
/// buffer per bind group.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplerIndexBufferKey {
    /// Bind group the index buffer belongs to.
    pub group: u32,
}
248
249#[derive(Clone, Debug, Hash, PartialEq, Eq)]
250#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
251#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
252#[cfg_attr(feature = "deserialize", serde(default))]
253pub struct SamplerHeapBindTargets {
254 pub standard_samplers: BindTarget,
255 pub comparison_samplers: BindTarget,
256}
257
258impl Default for SamplerHeapBindTargets {
259 fn default() -> Self {
260 Self {
261 standard_samplers: BindTarget {
262 space: 0,
263 register: 0,
264 binding_array_size: None,
265 dynamic_storage_buffer_offsets_index: None,
266 restrict_indexing: false,
267 },
268 comparison_samplers: BindTarget {
269 space: 1,
270 register: 0,
271 binding_array_size: None,
272 dynamic_storage_buffer_offsets_index: None,
273 restrict_indexing: false,
274 },
275 }
276 }
277}
278
279#[cfg(any(feature = "serialize", feature = "deserialize"))]
280#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
281#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
282struct SamplerIndexBufferBindingSerialization {
283 group: u32,
284 bind_target: BindTarget,
285}
286
287#[cfg(feature = "deserialize")]
288fn deserialize_sampler_index_buffer_bindings<'de, D>(
289 deserializer: D,
290) -> Result<SamplerIndexBufferBindingMap, D::Error>
291where
292 D: serde::Deserializer<'de>,
293{
294 use serde::Deserialize;
295
296 let vec = Vec::<SamplerIndexBufferBindingSerialization>::deserialize(deserializer)?;
297 let mut map = SamplerIndexBufferBindingMap::default();
298 for item in vec {
299 map.insert(
300 SamplerIndexBufferKey { group: item.group },
301 item.bind_target,
302 );
303 }
304 Ok(map)
305}
306
307pub type SamplerIndexBufferBindingMap =
309 alloc::collections::BTreeMap<SamplerIndexBufferKey, BindTarget>;
310
311#[cfg(any(feature = "serialize", feature = "deserialize"))]
312#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
313#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
314struct DynamicStorageBufferOffsetTargetSerialization {
315 index: u32,
316 bind_target: OffsetsBindTarget,
317}
318
319#[cfg(feature = "deserialize")]
320fn deserialize_storage_buffer_offsets<'de, D>(
321 deserializer: D,
322) -> Result<DynamicStorageBufferOffsetsTargets, D::Error>
323where
324 D: serde::Deserializer<'de>,
325{
326 use serde::Deserialize;
327
328 let vec = Vec::<DynamicStorageBufferOffsetTargetSerialization>::deserialize(deserializer)?;
329 let mut map = DynamicStorageBufferOffsetsTargets::default();
330 for item in vec {
331 map.insert(item.index, item.bind_target);
332 }
333 Ok(map)
334}
335
336pub type DynamicStorageBufferOffsetsTargets = alloc::collections::BTreeMap<u32, OffsetsBindTarget>;
337
338type BackendResult = Result<(), Error>;
340
341#[derive(Clone, Debug, PartialEq, thiserror::Error)]
342#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
343#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
344pub enum EntryPointError {
345 #[error("mapping of {0:?} is missing")]
346 MissingBinding(crate::ResourceBinding),
347}
348
349#[derive(Clone, Debug, Hash, PartialEq, Eq)]
351#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
352#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
353#[cfg_attr(feature = "deserialize", serde(default))]
354pub struct Options {
355 pub shader_model: ShaderModel,
357 #[cfg_attr(
359 feature = "deserialize",
360 serde(deserialize_with = "deserialize_binding_map")
361 )]
362 pub binding_map: BindingMap,
363 pub fake_missing_bindings: bool,
365 pub special_constants_binding: Option<BindTarget>,
368 pub push_constants_target: Option<BindTarget>,
370 pub sampler_heap_target: SamplerHeapBindTargets,
372 #[cfg_attr(
374 feature = "deserialize",
375 serde(deserialize_with = "deserialize_sampler_index_buffer_bindings")
376 )]
377 pub sampler_buffer_binding_map: SamplerIndexBufferBindingMap,
378 #[cfg_attr(
380 feature = "deserialize",
381 serde(deserialize_with = "deserialize_storage_buffer_offsets")
382 )]
383 pub dynamic_storage_buffer_offsets_targets: DynamicStorageBufferOffsetsTargets,
384 pub zero_initialize_workgroup_memory: bool,
386 pub restrict_indexing: bool,
388 pub force_loop_bounding: bool,
391}
392
393impl Default for Options {
394 fn default() -> Self {
395 Options {
396 shader_model: ShaderModel::V5_1,
397 binding_map: BindingMap::default(),
398 fake_missing_bindings: true,
399 special_constants_binding: None,
400 sampler_heap_target: SamplerHeapBindTargets::default(),
401 sampler_buffer_binding_map: alloc::collections::BTreeMap::default(),
402 push_constants_target: None,
403 dynamic_storage_buffer_offsets_targets: alloc::collections::BTreeMap::new(),
404 zero_initialize_workgroup_memory: true,
405 restrict_indexing: true,
406 force_loop_bounding: true,
407 }
408 }
409}
410
411impl Options {
412 fn resolve_resource_binding(
413 &self,
414 res_binding: &crate::ResourceBinding,
415 ) -> Result<BindTarget, EntryPointError> {
416 match self.binding_map.get(res_binding) {
417 Some(target) => Ok(*target),
418 None if self.fake_missing_bindings => Ok(BindTarget {
419 space: res_binding.group as u8,
420 register: res_binding.binding,
421 binding_array_size: None,
422 dynamic_storage_buffer_offsets_index: None,
423 restrict_indexing: false,
424 }),
425 None => Err(EntryPointError::MissingBinding(*res_binding)),
426 }
427 }
428}
429
430#[derive(Default)]
432pub struct ReflectionInfo {
433 pub entry_point_names: Vec<Result<String, EntryPointError>>,
440}
441
442#[derive(Debug, Default, Clone)]
444#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
445#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
446#[cfg_attr(feature = "deserialize", serde(default))]
447pub struct PipelineOptions {
448 pub entry_point: Option<(ir::ShaderStage, String)>,
456}
457
458#[derive(Error, Debug)]
459pub enum Error {
460 #[error(transparent)]
461 IoError(#[from] FmtError),
462 #[error("A scalar with an unsupported width was requested: {0:?}")]
463 UnsupportedScalar(crate::Scalar),
464 #[error("{0}")]
465 Unimplemented(String), #[error("{0}")]
467 Custom(String),
468 #[error("overrides should not be present at this stage")]
469 Override,
470 #[error(transparent)]
471 ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
472 #[error("entry point with stage {0:?} and name '{1}' not found")]
473 EntryPointNotFound(ir::ShaderStage, String),
474}
475
476#[derive(PartialEq, Eq, Hash)]
477enum WrappedType {
478 ZeroValue(help::WrappedZeroValue),
479 ArrayLength(help::WrappedArrayLength),
480 ImageSample(help::WrappedImageSample),
481 ImageQuery(help::WrappedImageQuery),
482 ImageLoadScalar(crate::Scalar),
483 Constructor(help::WrappedConstructor),
484 StructMatrixAccess(help::WrappedStructMatrixAccess),
485 MatCx2(help::WrappedMatCx2),
486 Math(help::WrappedMath),
487 UnaryOp(help::WrappedUnaryOp),
488 BinaryOp(help::WrappedBinaryOp),
489 Cast(help::WrappedCast),
490}
491
492#[derive(Default)]
493struct Wrapped {
494 types: crate::FastHashSet<WrappedType>,
495 sampler_heaps: bool,
497 sampler_index_buffers: crate::FastHashMap<SamplerIndexBufferKey, String>,
499}
500
501impl Wrapped {
502 fn insert(&mut self, r#type: WrappedType) -> bool {
503 self.types.insert(r#type)
504 }
505
506 fn clear(&mut self) {
507 self.types.clear();
508 }
509}
510
511pub struct FragmentEntryPoint<'a> {
520 module: &'a crate::Module,
521 func: &'a crate::Function,
522}
523
524impl<'a> FragmentEntryPoint<'a> {
525 pub fn new(module: &'a crate::Module, ep_name: &'a str) -> Option<Self> {
528 module
529 .entry_points
530 .iter()
531 .find(|ep| ep.name == ep_name)
532 .filter(|ep| ep.stage == crate::ShaderStage::Fragment)
533 .map(|ep| Self {
534 module,
535 func: &ep.function,
536 })
537 }
538}
539
540pub struct Writer<'a, W> {
541 out: W,
542 names: crate::FastHashMap<proc::NameKey, String>,
543 namer: proc::Namer,
544 options: &'a Options,
546 pipeline_options: &'a PipelineOptions,
548 entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
550 named_expressions: crate::NamedExpressions,
552 wrapped: Wrapped,
553 written_committed_intersection: bool,
554 written_candidate_intersection: bool,
555 continue_ctx: back::continue_forward::ContinueCtx,
556
557 temp_access_chain: Vec<storage::SubAccess>,
575 need_bake_expressions: back::NeedBakeExpressions,
576}