1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod overloads;
11mod terminator;
12mod type_methods;
13mod typifier;
14
15pub use constant_evaluator::{
16 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
17};
18pub use emitter::Emitter;
19pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
20pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
21pub use namer::{EntryPointIndex, NameKey, Namer};
22pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
23pub use terminator::ensure_block_returns;
24use thiserror::Error;
25pub use type_methods::min_max_float_representable_by;
26pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
27
28impl From<super::StorageFormat> for super::Scalar {
29 fn from(format: super::StorageFormat) -> Self {
30 use super::{ScalarKind as Sk, StorageFormat as Sf};
31 let kind = match format {
32 Sf::R8Unorm => Sk::Float,
33 Sf::R8Snorm => Sk::Float,
34 Sf::R8Uint => Sk::Uint,
35 Sf::R8Sint => Sk::Sint,
36 Sf::R16Uint => Sk::Uint,
37 Sf::R16Sint => Sk::Sint,
38 Sf::R16Float => Sk::Float,
39 Sf::Rg8Unorm => Sk::Float,
40 Sf::Rg8Snorm => Sk::Float,
41 Sf::Rg8Uint => Sk::Uint,
42 Sf::Rg8Sint => Sk::Sint,
43 Sf::R32Uint => Sk::Uint,
44 Sf::R32Sint => Sk::Sint,
45 Sf::R32Float => Sk::Float,
46 Sf::Rg16Uint => Sk::Uint,
47 Sf::Rg16Sint => Sk::Sint,
48 Sf::Rg16Float => Sk::Float,
49 Sf::Rgba8Unorm => Sk::Float,
50 Sf::Rgba8Snorm => Sk::Float,
51 Sf::Rgba8Uint => Sk::Uint,
52 Sf::Rgba8Sint => Sk::Sint,
53 Sf::Bgra8Unorm => Sk::Float,
54 Sf::Rgb10a2Uint => Sk::Uint,
55 Sf::Rgb10a2Unorm => Sk::Float,
56 Sf::Rg11b10Ufloat => Sk::Float,
57 Sf::R64Uint => Sk::Uint,
58 Sf::Rg32Uint => Sk::Uint,
59 Sf::Rg32Sint => Sk::Sint,
60 Sf::Rg32Float => Sk::Float,
61 Sf::Rgba16Uint => Sk::Uint,
62 Sf::Rgba16Sint => Sk::Sint,
63 Sf::Rgba16Float => Sk::Float,
64 Sf::Rgba32Uint => Sk::Uint,
65 Sf::Rgba32Sint => Sk::Sint,
66 Sf::Rgba32Float => Sk::Float,
67 Sf::R16Unorm => Sk::Float,
68 Sf::R16Snorm => Sk::Float,
69 Sf::Rg16Unorm => Sk::Float,
70 Sf::Rg16Snorm => Sk::Float,
71 Sf::Rgba16Unorm => Sk::Float,
72 Sf::Rgba16Snorm => Sk::Float,
73 };
74 let width = match format {
75 Sf::R64Uint => 8,
76 _ => 4,
77 };
78 super::Scalar { kind, width }
79 }
80}
81
/// A hashable, `Eq`-comparable counterpart of [`crate::Literal`].
///
/// It has one variant per `Literal` variant. Floating-point payloads are
/// stored as unsigned integers of the same size (the `From<crate::Literal>`
/// impl in this file fills them with `to_bits()`), which is what allows
/// `Eq` and `Hash` to be derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of an `F64` literal.
    F64(u64),
    /// Bit pattern of an `F32` literal.
    F32(u32),
    /// Bit pattern of an `F16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an `AbstractFloat` literal.
    AbstractFloat(u64),
}
95
96impl From<crate::Literal> for HashableLiteral {
97 fn from(l: crate::Literal) -> Self {
98 match l {
99 crate::Literal::F64(v) => Self::F64(v.to_bits()),
100 crate::Literal::F32(v) => Self::F32(v.to_bits()),
101 crate::Literal::F16(v) => Self::F16(v.to_bits()),
102 crate::Literal::U32(v) => Self::U32(v),
103 crate::Literal::I32(v) => Self::I32(v),
104 crate::Literal::U64(v) => Self::U64(v),
105 crate::Literal::I64(v) => Self::I64(v),
106 crate::Literal::Bool(v) => Self::Bool(v),
107 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
108 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
109 }
110 }
111}
112
113impl crate::Literal {
114 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
115 match (value, scalar.kind, scalar.width) {
116 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
117 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
118 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
119 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
120 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
121 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
122 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
123 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
124 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
125 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
126 _ => None,
127 }
128 }
129
130 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
131 Self::new(0, scalar)
132 }
133
134 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
135 Self::new(1, scalar)
136 }
137
138 pub const fn width(&self) -> crate::Bytes {
139 match *self {
140 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
141 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
142 Self::F16(_) => 2,
143 Self::Bool(_) => crate::BOOL_WIDTH,
144 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
145 }
146 }
147 pub const fn scalar(&self) -> crate::Scalar {
148 match *self {
149 Self::F64(_) => crate::Scalar::F64,
150 Self::F32(_) => crate::Scalar::F32,
151 Self::F16(_) => crate::Scalar::F16,
152 Self::U32(_) => crate::Scalar::U32,
153 Self::I32(_) => crate::Scalar::I32,
154 Self::U64(_) => crate::Scalar::U64,
155 Self::I64(_) => crate::Scalar::I64,
156 Self::Bool(_) => crate::Scalar::BOOL,
157 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
158 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
159 }
160 }
161 pub const fn scalar_kind(&self) -> crate::ScalarKind {
162 self.scalar().kind
163 }
164 pub const fn ty_inner(&self) -> crate::TypeInner {
165 crate::TypeInner::Scalar(self.scalar())
166 }
167}
168
169impl super::AddressSpace {
170 pub fn access(self) -> crate::StorageAccess {
171 use crate::StorageAccess as Sa;
172 match self {
173 crate::AddressSpace::Function
174 | crate::AddressSpace::Private
175 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
176 crate::AddressSpace::Uniform => Sa::LOAD,
177 crate::AddressSpace::Storage { access } => access,
178 crate::AddressSpace::Handle => Sa::LOAD,
179 crate::AddressSpace::PushConstant => Sa::LOAD,
180 }
181 }
182}
183
184impl super::MathFunction {
185 pub const fn argument_count(&self) -> usize {
186 match *self {
187 Self::Abs => 1,
189 Self::Min => 2,
190 Self::Max => 2,
191 Self::Clamp => 3,
192 Self::Saturate => 1,
193 Self::Cos => 1,
195 Self::Cosh => 1,
196 Self::Sin => 1,
197 Self::Sinh => 1,
198 Self::Tan => 1,
199 Self::Tanh => 1,
200 Self::Acos => 1,
201 Self::Asin => 1,
202 Self::Atan => 1,
203 Self::Atan2 => 2,
204 Self::Asinh => 1,
205 Self::Acosh => 1,
206 Self::Atanh => 1,
207 Self::Radians => 1,
208 Self::Degrees => 1,
209 Self::Ceil => 1,
211 Self::Floor => 1,
212 Self::Round => 1,
213 Self::Fract => 1,
214 Self::Trunc => 1,
215 Self::Modf => 1,
216 Self::Frexp => 1,
217 Self::Ldexp => 2,
218 Self::Exp => 1,
220 Self::Exp2 => 1,
221 Self::Log => 1,
222 Self::Log2 => 1,
223 Self::Pow => 2,
224 Self::Dot => 2,
226 Self::Dot4I8Packed => 2,
227 Self::Dot4U8Packed => 2,
228 Self::Outer => 2,
229 Self::Cross => 2,
230 Self::Distance => 2,
231 Self::Length => 1,
232 Self::Normalize => 1,
233 Self::FaceForward => 3,
234 Self::Reflect => 2,
235 Self::Refract => 3,
236 Self::Sign => 1,
238 Self::Fma => 3,
239 Self::Mix => 3,
240 Self::Step => 2,
241 Self::SmoothStep => 3,
242 Self::Sqrt => 1,
243 Self::InverseSqrt => 1,
244 Self::Inverse => 1,
245 Self::Transpose => 1,
246 Self::Determinant => 1,
247 Self::QuantizeToF16 => 1,
248 Self::CountTrailingZeros => 1,
250 Self::CountLeadingZeros => 1,
251 Self::CountOneBits => 1,
252 Self::ReverseBits => 1,
253 Self::ExtractBits => 3,
254 Self::InsertBits => 4,
255 Self::FirstTrailingBit => 1,
256 Self::FirstLeadingBit => 1,
257 Self::Pack4x8snorm => 1,
259 Self::Pack4x8unorm => 1,
260 Self::Pack2x16snorm => 1,
261 Self::Pack2x16unorm => 1,
262 Self::Pack2x16float => 1,
263 Self::Pack4xI8 => 1,
264 Self::Pack4xU8 => 1,
265 Self::Pack4xI8Clamp => 1,
266 Self::Pack4xU8Clamp => 1,
267 Self::Unpack4x8snorm => 1,
269 Self::Unpack4x8unorm => 1,
270 Self::Unpack2x16snorm => 1,
271 Self::Unpack2x16unorm => 1,
272 Self::Unpack2x16float => 1,
273 Self::Unpack4xI8 => 1,
274 Self::Unpack4xU8 => 1,
275 }
276 }
277}
278
279impl crate::Expression {
280 pub const fn needs_pre_emit(&self) -> bool {
282 match *self {
283 Self::Literal(_)
284 | Self::Constant(_)
285 | Self::Override(_)
286 | Self::ZeroValue(_)
287 | Self::FunctionArgument(_)
288 | Self::GlobalVariable(_)
289 | Self::LocalVariable(_) => true,
290 _ => false,
291 }
292 }
293
294 pub const fn is_dynamic_index(&self) -> bool {
308 match *self {
309 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
310 _ => true,
311 }
312 }
313}
314
315impl crate::Function {
316 pub fn originating_global(
325 &self,
326 mut pointer: crate::Handle<crate::Expression>,
327 ) -> Option<crate::Handle<crate::GlobalVariable>> {
328 loop {
329 pointer = match self.expressions[pointer] {
330 crate::Expression::Access { base, .. } => base,
331 crate::Expression::AccessIndex { base, .. } => base,
332 crate::Expression::GlobalVariable(handle) => return Some(handle),
333 crate::Expression::LocalVariable(_) => return None,
334 crate::Expression::FunctionArgument(_) => return None,
335 _ => unreachable!(),
337 }
338 }
339 }
340}
341
342impl crate::SampleLevel {
343 pub const fn implicit_derivatives(&self) -> bool {
344 match *self {
345 Self::Auto | Self::Bias(_) => true,
346 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
347 }
348 }
349}
350
351impl crate::Binding {
352 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
353 match *self {
354 crate::Binding::BuiltIn(built_in) => Some(built_in),
355 Self::Location { .. } => None,
356 }
357 }
358}
359
360impl super::SwizzleComponent {
361 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
362
363 pub const fn index(&self) -> u32 {
364 match *self {
365 Self::X => 0,
366 Self::Y => 1,
367 Self::Z => 2,
368 Self::W => 3,
369 }
370 }
371 pub const fn from_index(idx: u32) -> Self {
372 match idx {
373 0 => Self::X,
374 1 => Self::Y,
375 2 => Self::Z,
376 _ => Self::W,
377 }
378 }
379}
380
381impl super::ImageClass {
382 pub const fn is_multisampled(self) -> bool {
383 match self {
384 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
385 crate::ImageClass::Storage { .. } => false,
386 }
387 }
388
389 pub const fn is_mipmapped(self) -> bool {
390 match self {
391 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
392 crate::ImageClass::Storage { .. } => false,
393 }
394 }
395
396 pub const fn is_depth(self) -> bool {
397 matches!(self, crate::ImageClass::Depth { .. })
398 }
399}
400
401impl crate::Module {
402 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
403 GlobalCtx {
404 types: &self.types,
405 constants: &self.constants,
406 overrides: &self.overrides,
407 global_expressions: &self.global_expressions,
408 }
409 }
410
411 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
412 compare_types(lhs, rhs, &self.types)
413 }
414}
415
/// Reasons why [`GlobalCtx::eval_expr_to_u32_from`] can fail.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not evaluate to a `U32` or `I32` literal.
    NonConst,
    /// The expression evaluated to an `I32` literal that could not be
    /// converted to `u32`.
    Negative,
}
421
/// A cheap, copyable bundle of references to the module-level arenas.
///
/// `crate::Module::to_ctx` in this file builds one from the fields of the
/// same names on the module.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The constant arena.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The override arena.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The arena that constant initializers (`constants[c].init`) and
    /// override initializers are evaluated against in this file.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
429
430impl GlobalCtx<'_> {
431 #[allow(dead_code)]
433 pub(super) fn eval_expr_to_u32(
434 &self,
435 handle: crate::Handle<crate::Expression>,
436 ) -> Result<u32, U32EvalError> {
437 self.eval_expr_to_u32_from(handle, self.global_expressions)
438 }
439
440 pub(super) fn eval_expr_to_u32_from(
442 &self,
443 handle: crate::Handle<crate::Expression>,
444 arena: &crate::Arena<crate::Expression>,
445 ) -> Result<u32, U32EvalError> {
446 match self.eval_expr_to_literal_from(handle, arena) {
447 Some(crate::Literal::U32(value)) => Ok(value),
448 Some(crate::Literal::I32(value)) => {
449 value.try_into().map_err(|_| U32EvalError::Negative)
450 }
451 _ => Err(U32EvalError::NonConst),
452 }
453 }
454
455 #[allow(dead_code)]
457 pub(super) fn eval_expr_to_bool_from(
458 &self,
459 handle: crate::Handle<crate::Expression>,
460 arena: &crate::Arena<crate::Expression>,
461 ) -> Option<bool> {
462 match self.eval_expr_to_literal_from(handle, arena) {
463 Some(crate::Literal::Bool(value)) => Some(value),
464 _ => None,
465 }
466 }
467
468 #[allow(dead_code)]
469 pub(crate) fn eval_expr_to_literal(
470 &self,
471 handle: crate::Handle<crate::Expression>,
472 ) -> Option<crate::Literal> {
473 self.eval_expr_to_literal_from(handle, self.global_expressions)
474 }
475
476 pub(super) fn eval_expr_to_literal_from(
477 &self,
478 handle: crate::Handle<crate::Expression>,
479 arena: &crate::Arena<crate::Expression>,
480 ) -> Option<crate::Literal> {
481 fn get(
482 gctx: GlobalCtx,
483 handle: crate::Handle<crate::Expression>,
484 arena: &crate::Arena<crate::Expression>,
485 ) -> Option<crate::Literal> {
486 match arena[handle] {
487 crate::Expression::Literal(literal) => Some(literal),
488 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
489 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
490 _ => None,
491 },
492 _ => None,
493 }
494 }
495 match arena[handle] {
496 crate::Expression::Constant(c) => {
497 get(*self, self.constants[c].init, self.global_expressions)
498 }
499 _ => get(*self, handle, arena),
500 }
501 }
502
503 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
504 compare_types(lhs, rhs, self.types)
505 }
506}
507
/// Errors returned by `crate::ArraySize::resolve`.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The pending array length evaluated to zero, or to a value that could
    /// not be converted to `u32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override behind a pending array length has no initializer, or the
    /// initializer did not evaluate to a `U32`/`I32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
515
516impl crate::ArraySize {
517 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
527 match *self {
528 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
529 crate::ArraySize::Pending(handle) => {
530 let Some(expr) = gctx.overrides[handle].init else {
531 return Err(ResolveArraySizeError::NonConstArrayLength);
532 };
533 let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
534 U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
535 U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
536 })?;
537
538 if length == 0 {
539 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
540 }
541
542 Ok(IndexableLength::Known(length))
543 }
544 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
545 }
546 }
547}
548
549pub fn flatten_compose<'arenas>(
562 ty: crate::Handle<crate::Type>,
563 components: &'arenas [crate::Handle<crate::Expression>],
564 expressions: &'arenas crate::Arena<crate::Expression>,
565 types: &'arenas crate::UniqueArena<crate::Type>,
566) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
567 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
573 (size as usize, true)
574 } else {
575 (components.len(), false)
576 };
577
578 fn flatten_compose<'c>(
580 component: &'c crate::Handle<crate::Expression>,
581 is_vector: bool,
582 expressions: &'c crate::Arena<crate::Expression>,
583 ) -> &'c [crate::Handle<crate::Expression>] {
584 if is_vector {
585 if let crate::Expression::Compose {
586 ty: _,
587 components: ref subcomponents,
588 } = expressions[*component]
589 {
590 return subcomponents;
591 }
592 }
593 core::slice::from_ref(component)
594 }
595
596 fn flatten_splat<'c>(
598 component: &'c crate::Handle<crate::Expression>,
599 is_vector: bool,
600 expressions: &'c crate::Arena<crate::Expression>,
601 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
602 let mut expr = *component;
603 let mut count = 1;
604 if is_vector {
605 if let crate::Expression::Splat { size, value } = expressions[expr] {
606 expr = value;
607 count = size as usize;
608 }
609 }
610 core::iter::repeat_n(expr, count)
611 }
612
613 components
620 .iter()
621 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
622 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
623 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
624 .take(size)
625}
626
/// Checks that a 3x3 `f32` matrix type reports a size of 48.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };

    let size = mat3x3.size(module.to_ctx());

    assert_eq!(size, 48);
}