1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::min_max_float_representable_by;
28pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
29
30impl From<super::StorageFormat> for super::Scalar {
31 fn from(format: super::StorageFormat) -> Self {
32 use super::{ScalarKind as Sk, StorageFormat as Sf};
33 let kind = match format {
34 Sf::R8Unorm => Sk::Float,
35 Sf::R8Snorm => Sk::Float,
36 Sf::R8Uint => Sk::Uint,
37 Sf::R8Sint => Sk::Sint,
38 Sf::R16Uint => Sk::Uint,
39 Sf::R16Sint => Sk::Sint,
40 Sf::R16Float => Sk::Float,
41 Sf::Rg8Unorm => Sk::Float,
42 Sf::Rg8Snorm => Sk::Float,
43 Sf::Rg8Uint => Sk::Uint,
44 Sf::Rg8Sint => Sk::Sint,
45 Sf::R32Uint => Sk::Uint,
46 Sf::R32Sint => Sk::Sint,
47 Sf::R32Float => Sk::Float,
48 Sf::Rg16Uint => Sk::Uint,
49 Sf::Rg16Sint => Sk::Sint,
50 Sf::Rg16Float => Sk::Float,
51 Sf::Rgba8Unorm => Sk::Float,
52 Sf::Rgba8Snorm => Sk::Float,
53 Sf::Rgba8Uint => Sk::Uint,
54 Sf::Rgba8Sint => Sk::Sint,
55 Sf::Bgra8Unorm => Sk::Float,
56 Sf::Rgb10a2Uint => Sk::Uint,
57 Sf::Rgb10a2Unorm => Sk::Float,
58 Sf::Rg11b10Ufloat => Sk::Float,
59 Sf::R64Uint => Sk::Uint,
60 Sf::Rg32Uint => Sk::Uint,
61 Sf::Rg32Sint => Sk::Sint,
62 Sf::Rg32Float => Sk::Float,
63 Sf::Rgba16Uint => Sk::Uint,
64 Sf::Rgba16Sint => Sk::Sint,
65 Sf::Rgba16Float => Sk::Float,
66 Sf::Rgba32Uint => Sk::Uint,
67 Sf::Rgba32Sint => Sk::Sint,
68 Sf::Rgba32Float => Sk::Float,
69 Sf::R16Unorm => Sk::Float,
70 Sf::R16Snorm => Sk::Float,
71 Sf::Rg16Unorm => Sk::Float,
72 Sf::Rg16Snorm => Sk::Float,
73 Sf::Rgba16Unorm => Sk::Float,
74 Sf::Rgba16Snorm => Sk::Float,
75 };
76 let width = match format {
77 Sf::R64Uint => 8,
78 _ => 4,
79 };
80 super::Scalar { kind, width }
81 }
82}
83
/// A hashable mirror of `crate::Literal`.
///
/// Floating-point payloads are stored as their raw bit patterns (the
/// `From<crate::Literal>` conversion uses `to_bits`), which lets this type
/// derive `Eq` and `Hash`. As a consequence, equality is bitwise. For example,
/// `0.0` and `-0.0` compare unequal, while two NaNs with identical bits compare
/// equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of an `F64` literal.
    F64(u64),
    /// Bit pattern of an `F32` literal.
    F32(u32),
    /// Bit pattern of an `F16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an `AbstractFloat` literal.
    AbstractFloat(u64),
}
97
98impl From<crate::Literal> for HashableLiteral {
99 fn from(l: crate::Literal) -> Self {
100 match l {
101 crate::Literal::F64(v) => Self::F64(v.to_bits()),
102 crate::Literal::F32(v) => Self::F32(v.to_bits()),
103 crate::Literal::F16(v) => Self::F16(v.to_bits()),
104 crate::Literal::U32(v) => Self::U32(v),
105 crate::Literal::I32(v) => Self::I32(v),
106 crate::Literal::U64(v) => Self::U64(v),
107 crate::Literal::I64(v) => Self::I64(v),
108 crate::Literal::Bool(v) => Self::Bool(v),
109 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
110 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
111 }
112 }
113}
114
115impl crate::Literal {
116 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
117 match (value, scalar.kind, scalar.width) {
118 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
119 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
120 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
121 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
122 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
123 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
124 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
125 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
126 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
127 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
128 _ => None,
129 }
130 }
131
132 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
133 Self::new(0, scalar)
134 }
135
136 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
137 Self::new(1, scalar)
138 }
139
140 pub const fn width(&self) -> crate::Bytes {
141 match *self {
142 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
143 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
144 Self::F16(_) => 2,
145 Self::Bool(_) => crate::BOOL_WIDTH,
146 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
147 }
148 }
149 pub const fn scalar(&self) -> crate::Scalar {
150 match *self {
151 Self::F64(_) => crate::Scalar::F64,
152 Self::F32(_) => crate::Scalar::F32,
153 Self::F16(_) => crate::Scalar::F16,
154 Self::U32(_) => crate::Scalar::U32,
155 Self::I32(_) => crate::Scalar::I32,
156 Self::U64(_) => crate::Scalar::U64,
157 Self::I64(_) => crate::Scalar::I64,
158 Self::Bool(_) => crate::Scalar::BOOL,
159 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
160 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
161 }
162 }
163 pub const fn scalar_kind(&self) -> crate::ScalarKind {
164 self.scalar().kind
165 }
166 pub const fn ty_inner(&self) -> crate::TypeInner {
167 crate::TypeInner::Scalar(self.scalar())
168 }
169}
170
171impl super::AddressSpace {
172 pub fn access(self) -> crate::StorageAccess {
173 use crate::StorageAccess as Sa;
174 match self {
175 crate::AddressSpace::Function
176 | crate::AddressSpace::Private
177 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
178 crate::AddressSpace::Uniform => Sa::LOAD,
179 crate::AddressSpace::Storage { access } => access,
180 crate::AddressSpace::Handle => Sa::LOAD,
181 crate::AddressSpace::PushConstant => Sa::LOAD,
182 crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
185 }
186 }
187}
188
189impl super::MathFunction {
190 pub const fn argument_count(&self) -> usize {
191 match *self {
192 Self::Abs => 1,
194 Self::Min => 2,
195 Self::Max => 2,
196 Self::Clamp => 3,
197 Self::Saturate => 1,
198 Self::Cos => 1,
200 Self::Cosh => 1,
201 Self::Sin => 1,
202 Self::Sinh => 1,
203 Self::Tan => 1,
204 Self::Tanh => 1,
205 Self::Acos => 1,
206 Self::Asin => 1,
207 Self::Atan => 1,
208 Self::Atan2 => 2,
209 Self::Asinh => 1,
210 Self::Acosh => 1,
211 Self::Atanh => 1,
212 Self::Radians => 1,
213 Self::Degrees => 1,
214 Self::Ceil => 1,
216 Self::Floor => 1,
217 Self::Round => 1,
218 Self::Fract => 1,
219 Self::Trunc => 1,
220 Self::Modf => 1,
221 Self::Frexp => 1,
222 Self::Ldexp => 2,
223 Self::Exp => 1,
225 Self::Exp2 => 1,
226 Self::Log => 1,
227 Self::Log2 => 1,
228 Self::Pow => 2,
229 Self::Dot => 2,
231 Self::Dot4I8Packed => 2,
232 Self::Dot4U8Packed => 2,
233 Self::Outer => 2,
234 Self::Cross => 2,
235 Self::Distance => 2,
236 Self::Length => 1,
237 Self::Normalize => 1,
238 Self::FaceForward => 3,
239 Self::Reflect => 2,
240 Self::Refract => 3,
241 Self::Sign => 1,
243 Self::Fma => 3,
244 Self::Mix => 3,
245 Self::Step => 2,
246 Self::SmoothStep => 3,
247 Self::Sqrt => 1,
248 Self::InverseSqrt => 1,
249 Self::Inverse => 1,
250 Self::Transpose => 1,
251 Self::Determinant => 1,
252 Self::QuantizeToF16 => 1,
253 Self::CountTrailingZeros => 1,
255 Self::CountLeadingZeros => 1,
256 Self::CountOneBits => 1,
257 Self::ReverseBits => 1,
258 Self::ExtractBits => 3,
259 Self::InsertBits => 4,
260 Self::FirstTrailingBit => 1,
261 Self::FirstLeadingBit => 1,
262 Self::Pack4x8snorm => 1,
264 Self::Pack4x8unorm => 1,
265 Self::Pack2x16snorm => 1,
266 Self::Pack2x16unorm => 1,
267 Self::Pack2x16float => 1,
268 Self::Pack4xI8 => 1,
269 Self::Pack4xU8 => 1,
270 Self::Pack4xI8Clamp => 1,
271 Self::Pack4xU8Clamp => 1,
272 Self::Unpack4x8snorm => 1,
274 Self::Unpack4x8unorm => 1,
275 Self::Unpack2x16snorm => 1,
276 Self::Unpack2x16unorm => 1,
277 Self::Unpack2x16float => 1,
278 Self::Unpack4xI8 => 1,
279 Self::Unpack4xU8 => 1,
280 }
281 }
282}
283
284impl crate::Expression {
285 pub const fn needs_pre_emit(&self) -> bool {
287 match *self {
288 Self::Literal(_)
289 | Self::Constant(_)
290 | Self::Override(_)
291 | Self::ZeroValue(_)
292 | Self::FunctionArgument(_)
293 | Self::GlobalVariable(_)
294 | Self::LocalVariable(_) => true,
295 _ => false,
296 }
297 }
298
299 pub const fn is_dynamic_index(&self) -> bool {
313 match *self {
314 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
315 _ => true,
316 }
317 }
318}
319
320impl crate::Function {
321 pub fn originating_global(
330 &self,
331 mut pointer: crate::Handle<crate::Expression>,
332 ) -> Option<crate::Handle<crate::GlobalVariable>> {
333 loop {
334 pointer = match self.expressions[pointer] {
335 crate::Expression::Access { base, .. } => base,
336 crate::Expression::AccessIndex { base, .. } => base,
337 crate::Expression::GlobalVariable(handle) => return Some(handle),
338 crate::Expression::LocalVariable(_) => return None,
339 crate::Expression::FunctionArgument(_) => return None,
340 _ => unreachable!(),
342 }
343 }
344 }
345}
346
347impl crate::SampleLevel {
348 pub const fn implicit_derivatives(&self) -> bool {
349 match *self {
350 Self::Auto | Self::Bias(_) => true,
351 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
352 }
353 }
354}
355
356impl crate::Binding {
357 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
358 match *self {
359 crate::Binding::BuiltIn(built_in) => Some(built_in),
360 Self::Location { .. } => None,
361 }
362 }
363}
364
365impl super::SwizzleComponent {
366 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
367
368 pub const fn index(&self) -> u32 {
369 match *self {
370 Self::X => 0,
371 Self::Y => 1,
372 Self::Z => 2,
373 Self::W => 3,
374 }
375 }
376 pub const fn from_index(idx: u32) -> Self {
377 match idx {
378 0 => Self::X,
379 1 => Self::Y,
380 2 => Self::Z,
381 _ => Self::W,
382 }
383 }
384}
385
386impl super::ImageClass {
387 pub const fn is_multisampled(self) -> bool {
388 match self {
389 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
390 crate::ImageClass::Storage { .. } => false,
391 crate::ImageClass::External => false,
392 }
393 }
394
395 pub const fn is_mipmapped(self) -> bool {
396 match self {
397 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
398 crate::ImageClass::Storage { .. } => false,
399 crate::ImageClass::External => false,
400 }
401 }
402
403 pub const fn is_depth(self) -> bool {
404 matches!(self, crate::ImageClass::Depth { .. })
405 }
406}
407
408impl crate::Module {
409 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
410 GlobalCtx {
411 types: &self.types,
412 constants: &self.constants,
413 overrides: &self.overrides,
414 global_expressions: &self.global_expressions,
415 }
416 }
417
418 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
419 compare_types(lhs, rhs, &self.types)
420 }
421}
422
/// Reason why `GlobalCtx::eval_expr_to_u32` / `eval_expr_to_u32_from` failed.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not evaluate to a `U32` or `I32` literal.
    NonConst,
    /// The expression evaluated to a negative `I32` literal.
    Negative,
}
428
/// Borrowed views of the module-level arenas used to evaluate global
/// expressions and compare types.
///
/// This is a `Copy` bundle of shared references. `crate::Module::to_ctx`
/// builds one from a module.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's overrides.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The module's global expression arena. Constant and override
    /// initializers are evaluated in this arena.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
436
437impl GlobalCtx<'_> {
438 #[allow(dead_code)]
440 pub(super) fn eval_expr_to_u32(
441 &self,
442 handle: crate::Handle<crate::Expression>,
443 ) -> Result<u32, U32EvalError> {
444 self.eval_expr_to_u32_from(handle, self.global_expressions)
445 }
446
447 pub(super) fn eval_expr_to_u32_from(
449 &self,
450 handle: crate::Handle<crate::Expression>,
451 arena: &crate::Arena<crate::Expression>,
452 ) -> Result<u32, U32EvalError> {
453 match self.eval_expr_to_literal_from(handle, arena) {
454 Some(crate::Literal::U32(value)) => Ok(value),
455 Some(crate::Literal::I32(value)) => {
456 value.try_into().map_err(|_| U32EvalError::Negative)
457 }
458 _ => Err(U32EvalError::NonConst),
459 }
460 }
461
462 #[allow(dead_code)]
464 pub(super) fn eval_expr_to_bool_from(
465 &self,
466 handle: crate::Handle<crate::Expression>,
467 arena: &crate::Arena<crate::Expression>,
468 ) -> Option<bool> {
469 match self.eval_expr_to_literal_from(handle, arena) {
470 Some(crate::Literal::Bool(value)) => Some(value),
471 _ => None,
472 }
473 }
474
475 #[allow(dead_code)]
476 pub(crate) fn eval_expr_to_literal(
477 &self,
478 handle: crate::Handle<crate::Expression>,
479 ) -> Option<crate::Literal> {
480 self.eval_expr_to_literal_from(handle, self.global_expressions)
481 }
482
483 pub(super) fn eval_expr_to_literal_from(
484 &self,
485 handle: crate::Handle<crate::Expression>,
486 arena: &crate::Arena<crate::Expression>,
487 ) -> Option<crate::Literal> {
488 fn get(
489 gctx: GlobalCtx,
490 handle: crate::Handle<crate::Expression>,
491 arena: &crate::Arena<crate::Expression>,
492 ) -> Option<crate::Literal> {
493 match arena[handle] {
494 crate::Expression::Literal(literal) => Some(literal),
495 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
496 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
497 _ => None,
498 },
499 _ => None,
500 }
501 }
502 match arena[handle] {
503 crate::Expression::Constant(c) => {
504 get(*self, self.constants[c].init, self.global_expressions)
505 }
506 _ => get(*self, handle, arena),
507 }
508 }
509
510 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
511 compare_types(lhs, rhs, self.types)
512 }
513}
514
/// Errors returned by `crate::ArraySize::resolve`.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The override's initializer evaluated to zero or to a negative `I32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override has no initializer, or its initializer does not evaluate
    /// to a `U32`/`I32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
522
523impl crate::ArraySize {
524 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
534 match *self {
535 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
536 crate::ArraySize::Pending(handle) => {
537 let Some(expr) = gctx.overrides[handle].init else {
538 return Err(ResolveArraySizeError::NonConstArrayLength);
539 };
540 let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
541 U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
542 U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
543 })?;
544
545 if length == 0 {
546 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
547 }
548
549 Ok(IndexableLength::Known(length))
550 }
551 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
552 }
553 }
554}
555
556pub fn flatten_compose<'arenas>(
569 ty: crate::Handle<crate::Type>,
570 components: &'arenas [crate::Handle<crate::Expression>],
571 expressions: &'arenas crate::Arena<crate::Expression>,
572 types: &'arenas crate::UniqueArena<crate::Type>,
573) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
574 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
580 (size as usize, true)
581 } else {
582 (components.len(), false)
583 };
584
585 fn flatten_compose<'c>(
587 component: &'c crate::Handle<crate::Expression>,
588 is_vector: bool,
589 expressions: &'c crate::Arena<crate::Expression>,
590 ) -> &'c [crate::Handle<crate::Expression>] {
591 if is_vector {
592 if let crate::Expression::Compose {
593 ty: _,
594 components: ref subcomponents,
595 } = expressions[*component]
596 {
597 return subcomponents;
598 }
599 }
600 core::slice::from_ref(component)
601 }
602
603 fn flatten_splat<'c>(
605 component: &'c crate::Handle<crate::Expression>,
606 is_vector: bool,
607 expressions: &'c crate::Arena<crate::Expression>,
608 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
609 let mut expr = *component;
610 let mut count = 1;
611 if is_vector {
612 if let crate::Expression::Splat { size, value } = expressions[expr] {
613 expr = value;
614 count = size as usize;
615 }
616 }
617 core::iter::repeat_n(expr, count)
618 }
619
620 components
627 .iter()
628 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
629 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
630 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
631 .take(size)
632}
633
634impl super::ShaderStage {
635 pub const fn compute_like(self) -> bool {
636 match self {
637 Self::Vertex | Self::Fragment => false,
638 Self::Compute | Self::Task | Self::Mesh => true,
639 }
640 }
641}
642
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    // 48 = 3 columns x 16 bytes, i.e. each `vec3<f32>` column occupies 16 bytes.
    assert_eq!(mat3x3_f32.size(module.to_ctx()), 48);
}