1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod overloads;
11mod terminator;
12mod type_methods;
13mod typifier;
14
15pub use constant_evaluator::{
16 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
17};
18pub use emitter::Emitter;
19pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
20pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
21pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
22pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
23pub use terminator::ensure_block_returns;
24use thiserror::Error;
25pub use type_methods::min_max_float_representable_by;
26pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
27
28impl From<super::StorageFormat> for super::Scalar {
29 fn from(format: super::StorageFormat) -> Self {
30 use super::{ScalarKind as Sk, StorageFormat as Sf};
31 let kind = match format {
32 Sf::R8Unorm => Sk::Float,
33 Sf::R8Snorm => Sk::Float,
34 Sf::R8Uint => Sk::Uint,
35 Sf::R8Sint => Sk::Sint,
36 Sf::R16Uint => Sk::Uint,
37 Sf::R16Sint => Sk::Sint,
38 Sf::R16Float => Sk::Float,
39 Sf::Rg8Unorm => Sk::Float,
40 Sf::Rg8Snorm => Sk::Float,
41 Sf::Rg8Uint => Sk::Uint,
42 Sf::Rg8Sint => Sk::Sint,
43 Sf::R32Uint => Sk::Uint,
44 Sf::R32Sint => Sk::Sint,
45 Sf::R32Float => Sk::Float,
46 Sf::Rg16Uint => Sk::Uint,
47 Sf::Rg16Sint => Sk::Sint,
48 Sf::Rg16Float => Sk::Float,
49 Sf::Rgba8Unorm => Sk::Float,
50 Sf::Rgba8Snorm => Sk::Float,
51 Sf::Rgba8Uint => Sk::Uint,
52 Sf::Rgba8Sint => Sk::Sint,
53 Sf::Bgra8Unorm => Sk::Float,
54 Sf::Rgb10a2Uint => Sk::Uint,
55 Sf::Rgb10a2Unorm => Sk::Float,
56 Sf::Rg11b10Ufloat => Sk::Float,
57 Sf::R64Uint => Sk::Uint,
58 Sf::Rg32Uint => Sk::Uint,
59 Sf::Rg32Sint => Sk::Sint,
60 Sf::Rg32Float => Sk::Float,
61 Sf::Rgba16Uint => Sk::Uint,
62 Sf::Rgba16Sint => Sk::Sint,
63 Sf::Rgba16Float => Sk::Float,
64 Sf::Rgba32Uint => Sk::Uint,
65 Sf::Rgba32Sint => Sk::Sint,
66 Sf::Rgba32Float => Sk::Float,
67 Sf::R16Unorm => Sk::Float,
68 Sf::R16Snorm => Sk::Float,
69 Sf::Rg16Unorm => Sk::Float,
70 Sf::Rg16Snorm => Sk::Float,
71 Sf::Rgba16Unorm => Sk::Float,
72 Sf::Rgba16Snorm => Sk::Float,
73 };
74 let width = match format {
75 Sf::R64Uint => 8,
76 _ => 4,
77 };
78 super::Scalar { kind, width }
79 }
80}
81
/// A mirror of [`crate::Literal`] that can derive `Eq` and `Hash`.
///
/// Floating-point payloads (`F64`, `F32`, `F16`, `AbstractFloat`) hold the
/// raw bit patterns produced by `to_bits()` in the `From<crate::Literal>`
/// conversion. Equality and hashing therefore compare bits, not numeric
/// values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    F64(u64),
    F32(u32),
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    AbstractFloat(u64),
}
95
96impl From<crate::Literal> for HashableLiteral {
97 fn from(l: crate::Literal) -> Self {
98 match l {
99 crate::Literal::F64(v) => Self::F64(v.to_bits()),
100 crate::Literal::F32(v) => Self::F32(v.to_bits()),
101 crate::Literal::F16(v) => Self::F16(v.to_bits()),
102 crate::Literal::U32(v) => Self::U32(v),
103 crate::Literal::I32(v) => Self::I32(v),
104 crate::Literal::U64(v) => Self::U64(v),
105 crate::Literal::I64(v) => Self::I64(v),
106 crate::Literal::Bool(v) => Self::Bool(v),
107 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
108 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
109 }
110 }
111}
112
113impl crate::Literal {
114 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
115 match (value, scalar.kind, scalar.width) {
116 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
117 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
118 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
119 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
120 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
121 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
122 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
123 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
124 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
125 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
126 _ => None,
127 }
128 }
129
130 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
131 Self::new(0, scalar)
132 }
133
134 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
135 Self::new(1, scalar)
136 }
137
138 pub const fn width(&self) -> crate::Bytes {
139 match *self {
140 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
141 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
142 Self::F16(_) => 2,
143 Self::Bool(_) => crate::BOOL_WIDTH,
144 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
145 }
146 }
147 pub const fn scalar(&self) -> crate::Scalar {
148 match *self {
149 Self::F64(_) => crate::Scalar::F64,
150 Self::F32(_) => crate::Scalar::F32,
151 Self::F16(_) => crate::Scalar::F16,
152 Self::U32(_) => crate::Scalar::U32,
153 Self::I32(_) => crate::Scalar::I32,
154 Self::U64(_) => crate::Scalar::U64,
155 Self::I64(_) => crate::Scalar::I64,
156 Self::Bool(_) => crate::Scalar::BOOL,
157 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
158 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
159 }
160 }
161 pub const fn scalar_kind(&self) -> crate::ScalarKind {
162 self.scalar().kind
163 }
164 pub const fn ty_inner(&self) -> crate::TypeInner {
165 crate::TypeInner::Scalar(self.scalar())
166 }
167}
168
169impl super::AddressSpace {
170 pub fn access(self) -> crate::StorageAccess {
171 use crate::StorageAccess as Sa;
172 match self {
173 crate::AddressSpace::Function
174 | crate::AddressSpace::Private
175 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
176 crate::AddressSpace::Uniform => Sa::LOAD,
177 crate::AddressSpace::Storage { access } => access,
178 crate::AddressSpace::Handle => Sa::LOAD,
179 crate::AddressSpace::PushConstant => Sa::LOAD,
180 }
181 }
182}
183
184impl super::MathFunction {
185 pub const fn argument_count(&self) -> usize {
186 match *self {
187 Self::Abs => 1,
189 Self::Min => 2,
190 Self::Max => 2,
191 Self::Clamp => 3,
192 Self::Saturate => 1,
193 Self::Cos => 1,
195 Self::Cosh => 1,
196 Self::Sin => 1,
197 Self::Sinh => 1,
198 Self::Tan => 1,
199 Self::Tanh => 1,
200 Self::Acos => 1,
201 Self::Asin => 1,
202 Self::Atan => 1,
203 Self::Atan2 => 2,
204 Self::Asinh => 1,
205 Self::Acosh => 1,
206 Self::Atanh => 1,
207 Self::Radians => 1,
208 Self::Degrees => 1,
209 Self::Ceil => 1,
211 Self::Floor => 1,
212 Self::Round => 1,
213 Self::Fract => 1,
214 Self::Trunc => 1,
215 Self::Modf => 1,
216 Self::Frexp => 1,
217 Self::Ldexp => 2,
218 Self::Exp => 1,
220 Self::Exp2 => 1,
221 Self::Log => 1,
222 Self::Log2 => 1,
223 Self::Pow => 2,
224 Self::Dot => 2,
226 Self::Dot4I8Packed => 2,
227 Self::Dot4U8Packed => 2,
228 Self::Outer => 2,
229 Self::Cross => 2,
230 Self::Distance => 2,
231 Self::Length => 1,
232 Self::Normalize => 1,
233 Self::FaceForward => 3,
234 Self::Reflect => 2,
235 Self::Refract => 3,
236 Self::Sign => 1,
238 Self::Fma => 3,
239 Self::Mix => 3,
240 Self::Step => 2,
241 Self::SmoothStep => 3,
242 Self::Sqrt => 1,
243 Self::InverseSqrt => 1,
244 Self::Inverse => 1,
245 Self::Transpose => 1,
246 Self::Determinant => 1,
247 Self::QuantizeToF16 => 1,
248 Self::CountTrailingZeros => 1,
250 Self::CountLeadingZeros => 1,
251 Self::CountOneBits => 1,
252 Self::ReverseBits => 1,
253 Self::ExtractBits => 3,
254 Self::InsertBits => 4,
255 Self::FirstTrailingBit => 1,
256 Self::FirstLeadingBit => 1,
257 Self::Pack4x8snorm => 1,
259 Self::Pack4x8unorm => 1,
260 Self::Pack2x16snorm => 1,
261 Self::Pack2x16unorm => 1,
262 Self::Pack2x16float => 1,
263 Self::Pack4xI8 => 1,
264 Self::Pack4xU8 => 1,
265 Self::Pack4xI8Clamp => 1,
266 Self::Pack4xU8Clamp => 1,
267 Self::Unpack4x8snorm => 1,
269 Self::Unpack4x8unorm => 1,
270 Self::Unpack2x16snorm => 1,
271 Self::Unpack2x16unorm => 1,
272 Self::Unpack2x16float => 1,
273 Self::Unpack4xI8 => 1,
274 Self::Unpack4xU8 => 1,
275 }
276 }
277}
278
279impl crate::Expression {
280 pub const fn needs_pre_emit(&self) -> bool {
282 match *self {
283 Self::Literal(_)
284 | Self::Constant(_)
285 | Self::Override(_)
286 | Self::ZeroValue(_)
287 | Self::FunctionArgument(_)
288 | Self::GlobalVariable(_)
289 | Self::LocalVariable(_) => true,
290 _ => false,
291 }
292 }
293
294 pub const fn is_dynamic_index(&self) -> bool {
308 match *self {
309 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
310 _ => true,
311 }
312 }
313}
314
315impl crate::Function {
316 pub fn originating_global(
325 &self,
326 mut pointer: crate::Handle<crate::Expression>,
327 ) -> Option<crate::Handle<crate::GlobalVariable>> {
328 loop {
329 pointer = match self.expressions[pointer] {
330 crate::Expression::Access { base, .. } => base,
331 crate::Expression::AccessIndex { base, .. } => base,
332 crate::Expression::GlobalVariable(handle) => return Some(handle),
333 crate::Expression::LocalVariable(_) => return None,
334 crate::Expression::FunctionArgument(_) => return None,
335 _ => unreachable!(),
337 }
338 }
339 }
340}
341
342impl crate::SampleLevel {
343 pub const fn implicit_derivatives(&self) -> bool {
344 match *self {
345 Self::Auto | Self::Bias(_) => true,
346 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
347 }
348 }
349}
350
351impl crate::Binding {
352 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
353 match *self {
354 crate::Binding::BuiltIn(built_in) => Some(built_in),
355 Self::Location { .. } => None,
356 }
357 }
358}
359
360impl super::SwizzleComponent {
361 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
362
363 pub const fn index(&self) -> u32 {
364 match *self {
365 Self::X => 0,
366 Self::Y => 1,
367 Self::Z => 2,
368 Self::W => 3,
369 }
370 }
371 pub const fn from_index(idx: u32) -> Self {
372 match idx {
373 0 => Self::X,
374 1 => Self::Y,
375 2 => Self::Z,
376 _ => Self::W,
377 }
378 }
379}
380
381impl super::ImageClass {
382 pub const fn is_multisampled(self) -> bool {
383 match self {
384 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
385 crate::ImageClass::Storage { .. } => false,
386 crate::ImageClass::External => false,
387 }
388 }
389
390 pub const fn is_mipmapped(self) -> bool {
391 match self {
392 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
393 crate::ImageClass::Storage { .. } => false,
394 crate::ImageClass::External => false,
395 }
396 }
397
398 pub const fn is_depth(self) -> bool {
399 matches!(self, crate::ImageClass::Depth { .. })
400 }
401}
402
403impl crate::Module {
404 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
405 GlobalCtx {
406 types: &self.types,
407 constants: &self.constants,
408 overrides: &self.overrides,
409 global_expressions: &self.global_expressions,
410 }
411 }
412
413 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
414 compare_types(lhs, rhs, &self.types)
415 }
416}
417
/// Why an expression could not be evaluated to a `u32`.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not reduce to a `U32` or `I32` literal.
    NonConst,
    /// The expression reduced to a negative `I32` literal.
    Negative,
}
423
/// Borrowed view of a module's global arenas.
///
/// Built by `Module::to_ctx`. It provides what is needed to look up types
/// and evaluate constant and override initializers without borrowing the
/// whole module.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's pipeline-overridable constants.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The arena holding module-scope expressions (e.g. constant initializers).
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
431
432impl GlobalCtx<'_> {
433 #[allow(dead_code)]
435 pub(super) fn eval_expr_to_u32(
436 &self,
437 handle: crate::Handle<crate::Expression>,
438 ) -> Result<u32, U32EvalError> {
439 self.eval_expr_to_u32_from(handle, self.global_expressions)
440 }
441
442 pub(super) fn eval_expr_to_u32_from(
444 &self,
445 handle: crate::Handle<crate::Expression>,
446 arena: &crate::Arena<crate::Expression>,
447 ) -> Result<u32, U32EvalError> {
448 match self.eval_expr_to_literal_from(handle, arena) {
449 Some(crate::Literal::U32(value)) => Ok(value),
450 Some(crate::Literal::I32(value)) => {
451 value.try_into().map_err(|_| U32EvalError::Negative)
452 }
453 _ => Err(U32EvalError::NonConst),
454 }
455 }
456
457 #[allow(dead_code)]
459 pub(super) fn eval_expr_to_bool_from(
460 &self,
461 handle: crate::Handle<crate::Expression>,
462 arena: &crate::Arena<crate::Expression>,
463 ) -> Option<bool> {
464 match self.eval_expr_to_literal_from(handle, arena) {
465 Some(crate::Literal::Bool(value)) => Some(value),
466 _ => None,
467 }
468 }
469
470 #[allow(dead_code)]
471 pub(crate) fn eval_expr_to_literal(
472 &self,
473 handle: crate::Handle<crate::Expression>,
474 ) -> Option<crate::Literal> {
475 self.eval_expr_to_literal_from(handle, self.global_expressions)
476 }
477
478 pub(super) fn eval_expr_to_literal_from(
479 &self,
480 handle: crate::Handle<crate::Expression>,
481 arena: &crate::Arena<crate::Expression>,
482 ) -> Option<crate::Literal> {
483 fn get(
484 gctx: GlobalCtx,
485 handle: crate::Handle<crate::Expression>,
486 arena: &crate::Arena<crate::Expression>,
487 ) -> Option<crate::Literal> {
488 match arena[handle] {
489 crate::Expression::Literal(literal) => Some(literal),
490 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
491 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
492 _ => None,
493 },
494 _ => None,
495 }
496 }
497 match arena[handle] {
498 crate::Expression::Constant(c) => {
499 get(*self, self.constants[c].init, self.global_expressions)
500 }
501 _ => get(*self, handle, arena),
502 }
503 }
504
505 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
506 compare_types(lhs, rhs, self.types)
507 }
508}
509
/// Errors from [`crate::ArraySize::resolve`].
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The override-based length evaluated to zero or to a negative `I32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override has no initializer, or its initializer did not reduce to
    /// a `U32`/`I32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
517
518impl crate::ArraySize {
519 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
529 match *self {
530 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
531 crate::ArraySize::Pending(handle) => {
532 let Some(expr) = gctx.overrides[handle].init else {
533 return Err(ResolveArraySizeError::NonConstArrayLength);
534 };
535 let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
536 U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
537 U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
538 })?;
539
540 if length == 0 {
541 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
542 }
543
544 Ok(IndexableLength::Known(length))
545 }
546 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
547 }
548 }
549}
550
551pub fn flatten_compose<'arenas>(
564 ty: crate::Handle<crate::Type>,
565 components: &'arenas [crate::Handle<crate::Expression>],
566 expressions: &'arenas crate::Arena<crate::Expression>,
567 types: &'arenas crate::UniqueArena<crate::Type>,
568) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
569 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
575 (size as usize, true)
576 } else {
577 (components.len(), false)
578 };
579
580 fn flatten_compose<'c>(
582 component: &'c crate::Handle<crate::Expression>,
583 is_vector: bool,
584 expressions: &'c crate::Arena<crate::Expression>,
585 ) -> &'c [crate::Handle<crate::Expression>] {
586 if is_vector {
587 if let crate::Expression::Compose {
588 ty: _,
589 components: ref subcomponents,
590 } = expressions[*component]
591 {
592 return subcomponents;
593 }
594 }
595 core::slice::from_ref(component)
596 }
597
598 fn flatten_splat<'c>(
600 component: &'c crate::Handle<crate::Expression>,
601 is_vector: bool,
602 expressions: &'c crate::Arena<crate::Expression>,
603 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
604 let mut expr = *component;
605 let mut count = 1;
606 if is_vector {
607 if let crate::Expression::Splat { size, value } = expressions[expr] {
608 expr = value;
609 count = size as usize;
610 }
611 }
612 core::iter::repeat_n(expr, count)
613 }
614
615 components
622 .iter()
623 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
624 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
625 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
626 .take(size)
627}
628
#[test]
fn test_matrix_size() {
    // A 3x3 `f32` matrix is expected to occupy 48 bytes, i.e. more than the
    // 36 bytes of its nine scalars, because of column padding.
    let module = crate::Module::default();
    let mat3x3f = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let size = mat3x3f.size(module.to_ctx());
    assert_eq!(size, 48);
}