1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::min_max_float_representable_by;
28pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
29
30impl From<super::StorageFormat> for super::Scalar {
31 fn from(format: super::StorageFormat) -> Self {
32 use super::{ScalarKind as Sk, StorageFormat as Sf};
33 let kind = match format {
34 Sf::R8Unorm => Sk::Float,
35 Sf::R8Snorm => Sk::Float,
36 Sf::R8Uint => Sk::Uint,
37 Sf::R8Sint => Sk::Sint,
38 Sf::R16Uint => Sk::Uint,
39 Sf::R16Sint => Sk::Sint,
40 Sf::R16Float => Sk::Float,
41 Sf::Rg8Unorm => Sk::Float,
42 Sf::Rg8Snorm => Sk::Float,
43 Sf::Rg8Uint => Sk::Uint,
44 Sf::Rg8Sint => Sk::Sint,
45 Sf::R32Uint => Sk::Uint,
46 Sf::R32Sint => Sk::Sint,
47 Sf::R32Float => Sk::Float,
48 Sf::Rg16Uint => Sk::Uint,
49 Sf::Rg16Sint => Sk::Sint,
50 Sf::Rg16Float => Sk::Float,
51 Sf::Rgba8Unorm => Sk::Float,
52 Sf::Rgba8Snorm => Sk::Float,
53 Sf::Rgba8Uint => Sk::Uint,
54 Sf::Rgba8Sint => Sk::Sint,
55 Sf::Bgra8Unorm => Sk::Float,
56 Sf::Rgb10a2Uint => Sk::Uint,
57 Sf::Rgb10a2Unorm => Sk::Float,
58 Sf::Rg11b10Ufloat => Sk::Float,
59 Sf::R64Uint => Sk::Uint,
60 Sf::Rg32Uint => Sk::Uint,
61 Sf::Rg32Sint => Sk::Sint,
62 Sf::Rg32Float => Sk::Float,
63 Sf::Rgba16Uint => Sk::Uint,
64 Sf::Rgba16Sint => Sk::Sint,
65 Sf::Rgba16Float => Sk::Float,
66 Sf::Rgba32Uint => Sk::Uint,
67 Sf::Rgba32Sint => Sk::Sint,
68 Sf::Rgba32Float => Sk::Float,
69 Sf::R16Unorm => Sk::Float,
70 Sf::R16Snorm => Sk::Float,
71 Sf::Rg16Unorm => Sk::Float,
72 Sf::Rg16Snorm => Sk::Float,
73 Sf::Rgba16Unorm => Sk::Float,
74 Sf::Rgba16Snorm => Sk::Float,
75 };
76 let width = match format {
77 Sf::R64Uint => 8,
78 _ => 4,
79 };
80 super::Scalar { kind, width }
81 }
82}
83
/// A literal value in a form that can derive `Eq` and `Hash`.
///
/// The variants mirror those of `crate::Literal`. Floating-point payloads
/// (`F64`, `F32`, `F16`, `AbstractFloat`) are held as unsigned integers; the
/// `From<crate::Literal>` conversion fills them with the value's `to_bits()`
/// result, so equality and hashing operate on the bit pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of a 64-bit float.
    F64(u64),
    /// Bit pattern of a 32-bit float.
    F32(u32),
    /// Bit pattern of a 16-bit float.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an abstract float.
    AbstractFloat(u64),
}
97
98impl From<crate::Literal> for HashableLiteral {
99 fn from(l: crate::Literal) -> Self {
100 match l {
101 crate::Literal::F64(v) => Self::F64(v.to_bits()),
102 crate::Literal::F32(v) => Self::F32(v.to_bits()),
103 crate::Literal::F16(v) => Self::F16(v.to_bits()),
104 crate::Literal::U32(v) => Self::U32(v),
105 crate::Literal::I32(v) => Self::I32(v),
106 crate::Literal::U64(v) => Self::U64(v),
107 crate::Literal::I64(v) => Self::I64(v),
108 crate::Literal::Bool(v) => Self::Bool(v),
109 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
110 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
111 }
112 }
113}
114
115impl crate::Literal {
116 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
117 match (value, scalar.kind, scalar.width) {
118 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
119 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
120 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
121 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
122 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
123 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
124 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
125 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
126 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
127 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
128 _ => None,
129 }
130 }
131
132 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
133 Self::new(0, scalar)
134 }
135
136 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
137 Self::new(1, scalar)
138 }
139
140 pub const fn width(&self) -> crate::Bytes {
141 match *self {
142 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
143 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
144 Self::F16(_) => 2,
145 Self::Bool(_) => crate::BOOL_WIDTH,
146 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
147 }
148 }
149 pub const fn scalar(&self) -> crate::Scalar {
150 match *self {
151 Self::F64(_) => crate::Scalar::F64,
152 Self::F32(_) => crate::Scalar::F32,
153 Self::F16(_) => crate::Scalar::F16,
154 Self::U32(_) => crate::Scalar::U32,
155 Self::I32(_) => crate::Scalar::I32,
156 Self::U64(_) => crate::Scalar::U64,
157 Self::I64(_) => crate::Scalar::I64,
158 Self::Bool(_) => crate::Scalar::BOOL,
159 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
160 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
161 }
162 }
163 pub const fn scalar_kind(&self) -> crate::ScalarKind {
164 self.scalar().kind
165 }
166 pub const fn ty_inner(&self) -> crate::TypeInner {
167 crate::TypeInner::Scalar(self.scalar())
168 }
169}
170
171impl super::AddressSpace {
172 pub fn access(self) -> crate::StorageAccess {
173 use crate::StorageAccess as Sa;
174 match self {
175 crate::AddressSpace::Function
176 | crate::AddressSpace::Private
177 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
178 crate::AddressSpace::Uniform => Sa::LOAD,
179 crate::AddressSpace::Storage { access } => access,
180 crate::AddressSpace::Handle => Sa::LOAD,
181 crate::AddressSpace::PushConstant => Sa::LOAD,
182 }
183 }
184}
185
186impl super::MathFunction {
187 pub const fn argument_count(&self) -> usize {
188 match *self {
189 Self::Abs => 1,
191 Self::Min => 2,
192 Self::Max => 2,
193 Self::Clamp => 3,
194 Self::Saturate => 1,
195 Self::Cos => 1,
197 Self::Cosh => 1,
198 Self::Sin => 1,
199 Self::Sinh => 1,
200 Self::Tan => 1,
201 Self::Tanh => 1,
202 Self::Acos => 1,
203 Self::Asin => 1,
204 Self::Atan => 1,
205 Self::Atan2 => 2,
206 Self::Asinh => 1,
207 Self::Acosh => 1,
208 Self::Atanh => 1,
209 Self::Radians => 1,
210 Self::Degrees => 1,
211 Self::Ceil => 1,
213 Self::Floor => 1,
214 Self::Round => 1,
215 Self::Fract => 1,
216 Self::Trunc => 1,
217 Self::Modf => 1,
218 Self::Frexp => 1,
219 Self::Ldexp => 2,
220 Self::Exp => 1,
222 Self::Exp2 => 1,
223 Self::Log => 1,
224 Self::Log2 => 1,
225 Self::Pow => 2,
226 Self::Dot => 2,
228 Self::Dot4I8Packed => 2,
229 Self::Dot4U8Packed => 2,
230 Self::Outer => 2,
231 Self::Cross => 2,
232 Self::Distance => 2,
233 Self::Length => 1,
234 Self::Normalize => 1,
235 Self::FaceForward => 3,
236 Self::Reflect => 2,
237 Self::Refract => 3,
238 Self::Sign => 1,
240 Self::Fma => 3,
241 Self::Mix => 3,
242 Self::Step => 2,
243 Self::SmoothStep => 3,
244 Self::Sqrt => 1,
245 Self::InverseSqrt => 1,
246 Self::Inverse => 1,
247 Self::Transpose => 1,
248 Self::Determinant => 1,
249 Self::QuantizeToF16 => 1,
250 Self::CountTrailingZeros => 1,
252 Self::CountLeadingZeros => 1,
253 Self::CountOneBits => 1,
254 Self::ReverseBits => 1,
255 Self::ExtractBits => 3,
256 Self::InsertBits => 4,
257 Self::FirstTrailingBit => 1,
258 Self::FirstLeadingBit => 1,
259 Self::Pack4x8snorm => 1,
261 Self::Pack4x8unorm => 1,
262 Self::Pack2x16snorm => 1,
263 Self::Pack2x16unorm => 1,
264 Self::Pack2x16float => 1,
265 Self::Pack4xI8 => 1,
266 Self::Pack4xU8 => 1,
267 Self::Pack4xI8Clamp => 1,
268 Self::Pack4xU8Clamp => 1,
269 Self::Unpack4x8snorm => 1,
271 Self::Unpack4x8unorm => 1,
272 Self::Unpack2x16snorm => 1,
273 Self::Unpack2x16unorm => 1,
274 Self::Unpack2x16float => 1,
275 Self::Unpack4xI8 => 1,
276 Self::Unpack4xU8 => 1,
277 }
278 }
279}
280
281impl crate::Expression {
282 pub const fn needs_pre_emit(&self) -> bool {
284 match *self {
285 Self::Literal(_)
286 | Self::Constant(_)
287 | Self::Override(_)
288 | Self::ZeroValue(_)
289 | Self::FunctionArgument(_)
290 | Self::GlobalVariable(_)
291 | Self::LocalVariable(_) => true,
292 _ => false,
293 }
294 }
295
296 pub const fn is_dynamic_index(&self) -> bool {
310 match *self {
311 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
312 _ => true,
313 }
314 }
315}
316
317impl crate::Function {
318 pub fn originating_global(
327 &self,
328 mut pointer: crate::Handle<crate::Expression>,
329 ) -> Option<crate::Handle<crate::GlobalVariable>> {
330 loop {
331 pointer = match self.expressions[pointer] {
332 crate::Expression::Access { base, .. } => base,
333 crate::Expression::AccessIndex { base, .. } => base,
334 crate::Expression::GlobalVariable(handle) => return Some(handle),
335 crate::Expression::LocalVariable(_) => return None,
336 crate::Expression::FunctionArgument(_) => return None,
337 _ => unreachable!(),
339 }
340 }
341 }
342}
343
344impl crate::SampleLevel {
345 pub const fn implicit_derivatives(&self) -> bool {
346 match *self {
347 Self::Auto | Self::Bias(_) => true,
348 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
349 }
350 }
351}
352
353impl crate::Binding {
354 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
355 match *self {
356 crate::Binding::BuiltIn(built_in) => Some(built_in),
357 Self::Location { .. } => None,
358 }
359 }
360}
361
362impl super::SwizzleComponent {
363 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
364
365 pub const fn index(&self) -> u32 {
366 match *self {
367 Self::X => 0,
368 Self::Y => 1,
369 Self::Z => 2,
370 Self::W => 3,
371 }
372 }
373 pub const fn from_index(idx: u32) -> Self {
374 match idx {
375 0 => Self::X,
376 1 => Self::Y,
377 2 => Self::Z,
378 _ => Self::W,
379 }
380 }
381}
382
383impl super::ImageClass {
384 pub const fn is_multisampled(self) -> bool {
385 match self {
386 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
387 crate::ImageClass::Storage { .. } => false,
388 crate::ImageClass::External => false,
389 }
390 }
391
392 pub const fn is_mipmapped(self) -> bool {
393 match self {
394 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
395 crate::ImageClass::Storage { .. } => false,
396 crate::ImageClass::External => false,
397 }
398 }
399
400 pub const fn is_depth(self) -> bool {
401 matches!(self, crate::ImageClass::Depth { .. })
402 }
403}
404
405impl crate::Module {
406 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
407 GlobalCtx {
408 types: &self.types,
409 constants: &self.constants,
410 overrides: &self.overrides,
411 global_expressions: &self.global_expressions,
412 }
413 }
414
415 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
416 compare_types(lhs, rhs, &self.types)
417 }
418}
419
/// Reasons why evaluating an expression to a `u32` can fail.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not evaluate to a `U32` or `I32` literal.
    NonConst,
    /// The expression evaluated to an `I32` literal that does not fit in a
    /// `u32`.
    Negative,
}
425
/// A cheap, copyable bundle of shared references to a module's global
/// arenas, as produced by `Module::to_ctx`.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The constant arena.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The override arena.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The arena of global expressions; constant initializers are looked up
    /// here when evaluating expressions to literals.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
433
434impl GlobalCtx<'_> {
435 #[allow(dead_code)]
437 pub(super) fn eval_expr_to_u32(
438 &self,
439 handle: crate::Handle<crate::Expression>,
440 ) -> Result<u32, U32EvalError> {
441 self.eval_expr_to_u32_from(handle, self.global_expressions)
442 }
443
444 pub(super) fn eval_expr_to_u32_from(
446 &self,
447 handle: crate::Handle<crate::Expression>,
448 arena: &crate::Arena<crate::Expression>,
449 ) -> Result<u32, U32EvalError> {
450 match self.eval_expr_to_literal_from(handle, arena) {
451 Some(crate::Literal::U32(value)) => Ok(value),
452 Some(crate::Literal::I32(value)) => {
453 value.try_into().map_err(|_| U32EvalError::Negative)
454 }
455 _ => Err(U32EvalError::NonConst),
456 }
457 }
458
459 #[allow(dead_code)]
461 pub(super) fn eval_expr_to_bool_from(
462 &self,
463 handle: crate::Handle<crate::Expression>,
464 arena: &crate::Arena<crate::Expression>,
465 ) -> Option<bool> {
466 match self.eval_expr_to_literal_from(handle, arena) {
467 Some(crate::Literal::Bool(value)) => Some(value),
468 _ => None,
469 }
470 }
471
472 #[allow(dead_code)]
473 pub(crate) fn eval_expr_to_literal(
474 &self,
475 handle: crate::Handle<crate::Expression>,
476 ) -> Option<crate::Literal> {
477 self.eval_expr_to_literal_from(handle, self.global_expressions)
478 }
479
480 pub(super) fn eval_expr_to_literal_from(
481 &self,
482 handle: crate::Handle<crate::Expression>,
483 arena: &crate::Arena<crate::Expression>,
484 ) -> Option<crate::Literal> {
485 fn get(
486 gctx: GlobalCtx,
487 handle: crate::Handle<crate::Expression>,
488 arena: &crate::Arena<crate::Expression>,
489 ) -> Option<crate::Literal> {
490 match arena[handle] {
491 crate::Expression::Literal(literal) => Some(literal),
492 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
493 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
494 _ => None,
495 },
496 _ => None,
497 }
498 }
499 match arena[handle] {
500 crate::Expression::Constant(c) => {
501 get(*self, self.constants[c].init, self.global_expressions)
502 }
503 _ => get(*self, handle, arena),
504 }
505 }
506
507 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
508 compare_types(lhs, rhs, self.types)
509 }
510}
511
/// Errors returned by `ArraySize::resolve`.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The array length evaluated to zero, or to a signed value that does
    /// not fit in a `u32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override giving the array length has no initializer, or its
    /// initializer could not be evaluated to an integer literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
519
520impl crate::ArraySize {
521 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
531 match *self {
532 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
533 crate::ArraySize::Pending(handle) => {
534 let Some(expr) = gctx.overrides[handle].init else {
535 return Err(ResolveArraySizeError::NonConstArrayLength);
536 };
537 let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
538 U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
539 U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
540 })?;
541
542 if length == 0 {
543 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
544 }
545
546 Ok(IndexableLength::Known(length))
547 }
548 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
549 }
550 }
551}
552
553pub fn flatten_compose<'arenas>(
566 ty: crate::Handle<crate::Type>,
567 components: &'arenas [crate::Handle<crate::Expression>],
568 expressions: &'arenas crate::Arena<crate::Expression>,
569 types: &'arenas crate::UniqueArena<crate::Type>,
570) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
571 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
577 (size as usize, true)
578 } else {
579 (components.len(), false)
580 };
581
582 fn flatten_compose<'c>(
584 component: &'c crate::Handle<crate::Expression>,
585 is_vector: bool,
586 expressions: &'c crate::Arena<crate::Expression>,
587 ) -> &'c [crate::Handle<crate::Expression>] {
588 if is_vector {
589 if let crate::Expression::Compose {
590 ty: _,
591 components: ref subcomponents,
592 } = expressions[*component]
593 {
594 return subcomponents;
595 }
596 }
597 core::slice::from_ref(component)
598 }
599
600 fn flatten_splat<'c>(
602 component: &'c crate::Handle<crate::Expression>,
603 is_vector: bool,
604 expressions: &'c crate::Arena<crate::Expression>,
605 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
606 let mut expr = *component;
607 let mut count = 1;
608 if is_vector {
609 if let crate::Expression::Splat { size, value } = expressions[expr] {
610 expr = value;
611 count = size as usize;
612 }
613 }
614 core::iter::repeat_n(expr, count)
615 }
616
617 components
624 .iter()
625 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
626 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
627 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
628 .take(size)
629}
630
/// Checks that `TypeInner::size` reports 48 for a 3x3 matrix of `F32`
/// scalars, evaluated against a default (empty) module.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };

    assert_eq!(mat3x3.size(module.to_ctx()), 48,);
}