naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::min_max_float_representable_by;
28pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
29
30impl From<super::StorageFormat> for super::Scalar {
31    fn from(format: super::StorageFormat) -> Self {
32        use super::{ScalarKind as Sk, StorageFormat as Sf};
33        let kind = match format {
34            Sf::R8Unorm => Sk::Float,
35            Sf::R8Snorm => Sk::Float,
36            Sf::R8Uint => Sk::Uint,
37            Sf::R8Sint => Sk::Sint,
38            Sf::R16Uint => Sk::Uint,
39            Sf::R16Sint => Sk::Sint,
40            Sf::R16Float => Sk::Float,
41            Sf::Rg8Unorm => Sk::Float,
42            Sf::Rg8Snorm => Sk::Float,
43            Sf::Rg8Uint => Sk::Uint,
44            Sf::Rg8Sint => Sk::Sint,
45            Sf::R32Uint => Sk::Uint,
46            Sf::R32Sint => Sk::Sint,
47            Sf::R32Float => Sk::Float,
48            Sf::Rg16Uint => Sk::Uint,
49            Sf::Rg16Sint => Sk::Sint,
50            Sf::Rg16Float => Sk::Float,
51            Sf::Rgba8Unorm => Sk::Float,
52            Sf::Rgba8Snorm => Sk::Float,
53            Sf::Rgba8Uint => Sk::Uint,
54            Sf::Rgba8Sint => Sk::Sint,
55            Sf::Bgra8Unorm => Sk::Float,
56            Sf::Rgb10a2Uint => Sk::Uint,
57            Sf::Rgb10a2Unorm => Sk::Float,
58            Sf::Rg11b10Ufloat => Sk::Float,
59            Sf::R64Uint => Sk::Uint,
60            Sf::Rg32Uint => Sk::Uint,
61            Sf::Rg32Sint => Sk::Sint,
62            Sf::Rg32Float => Sk::Float,
63            Sf::Rgba16Uint => Sk::Uint,
64            Sf::Rgba16Sint => Sk::Sint,
65            Sf::Rgba16Float => Sk::Float,
66            Sf::Rgba32Uint => Sk::Uint,
67            Sf::Rgba32Sint => Sk::Sint,
68            Sf::Rgba32Float => Sk::Float,
69            Sf::R16Unorm => Sk::Float,
70            Sf::R16Snorm => Sk::Float,
71            Sf::Rg16Unorm => Sk::Float,
72            Sf::Rg16Snorm => Sk::Float,
73            Sf::Rgba16Unorm => Sk::Float,
74            Sf::Rgba16Snorm => Sk::Float,
75        };
76        let width = match format {
77            Sf::R64Uint => 8,
78            _ => 4,
79        };
80        super::Scalar { kind, width }
81    }
82}
83
/// A hashable, bit-exact counterpart of [`crate::Literal`].
///
/// Floating-point payloads (`F64`, `F32`, `F16`, `AbstractFloat`) are stored as
/// their raw bit patterns, as produced by the `From<crate::Literal>` impl. As a
/// result, the derived `Eq`/`Hash` compare bits rather than numeric values. For
/// example, `0.0` and `-0.0` are distinct, while two NaNs with identical bits
/// are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of an `f64` literal.
    F64(u64),
    /// Bit pattern of an `f32` literal.
    F32(u32),
    /// Bit pattern of an `f16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an abstract-float literal.
    AbstractFloat(u64),
}
97
98impl From<crate::Literal> for HashableLiteral {
99    fn from(l: crate::Literal) -> Self {
100        match l {
101            crate::Literal::F64(v) => Self::F64(v.to_bits()),
102            crate::Literal::F32(v) => Self::F32(v.to_bits()),
103            crate::Literal::F16(v) => Self::F16(v.to_bits()),
104            crate::Literal::U32(v) => Self::U32(v),
105            crate::Literal::I32(v) => Self::I32(v),
106            crate::Literal::U64(v) => Self::U64(v),
107            crate::Literal::I64(v) => Self::I64(v),
108            crate::Literal::Bool(v) => Self::Bool(v),
109            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
110            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
111        }
112    }
113}
114
115impl crate::Literal {
116    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
117        match (value, scalar.kind, scalar.width) {
118            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
119            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
120            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
121            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
122            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
123            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
124            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
125            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
126            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
127            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
128            _ => None,
129        }
130    }
131
132    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
133        Self::new(0, scalar)
134    }
135
136    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
137        Self::new(1, scalar)
138    }
139
140    pub const fn width(&self) -> crate::Bytes {
141        match *self {
142            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
143            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
144            Self::F16(_) => 2,
145            Self::Bool(_) => crate::BOOL_WIDTH,
146            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
147        }
148    }
149    pub const fn scalar(&self) -> crate::Scalar {
150        match *self {
151            Self::F64(_) => crate::Scalar::F64,
152            Self::F32(_) => crate::Scalar::F32,
153            Self::F16(_) => crate::Scalar::F16,
154            Self::U32(_) => crate::Scalar::U32,
155            Self::I32(_) => crate::Scalar::I32,
156            Self::U64(_) => crate::Scalar::U64,
157            Self::I64(_) => crate::Scalar::I64,
158            Self::Bool(_) => crate::Scalar::BOOL,
159            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
160            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
161        }
162    }
163    pub const fn scalar_kind(&self) -> crate::ScalarKind {
164        self.scalar().kind
165    }
166    pub const fn ty_inner(&self) -> crate::TypeInner {
167        crate::TypeInner::Scalar(self.scalar())
168    }
169}
170
171impl super::AddressSpace {
172    pub fn access(self) -> crate::StorageAccess {
173        use crate::StorageAccess as Sa;
174        match self {
175            crate::AddressSpace::Function
176            | crate::AddressSpace::Private
177            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
178            crate::AddressSpace::Uniform => Sa::LOAD,
179            crate::AddressSpace::Storage { access } => access,
180            crate::AddressSpace::Handle => Sa::LOAD,
181            crate::AddressSpace::PushConstant => Sa::LOAD,
182            // TaskPayload isn't always writable, but this is checked for elsewhere,
183            // when not using multiple payloads and matching the entry payload is checked.
184            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
185        }
186    }
187}
188
189impl super::MathFunction {
190    pub const fn argument_count(&self) -> usize {
191        match *self {
192            // comparison
193            Self::Abs => 1,
194            Self::Min => 2,
195            Self::Max => 2,
196            Self::Clamp => 3,
197            Self::Saturate => 1,
198            // trigonometry
199            Self::Cos => 1,
200            Self::Cosh => 1,
201            Self::Sin => 1,
202            Self::Sinh => 1,
203            Self::Tan => 1,
204            Self::Tanh => 1,
205            Self::Acos => 1,
206            Self::Asin => 1,
207            Self::Atan => 1,
208            Self::Atan2 => 2,
209            Self::Asinh => 1,
210            Self::Acosh => 1,
211            Self::Atanh => 1,
212            Self::Radians => 1,
213            Self::Degrees => 1,
214            // decomposition
215            Self::Ceil => 1,
216            Self::Floor => 1,
217            Self::Round => 1,
218            Self::Fract => 1,
219            Self::Trunc => 1,
220            Self::Modf => 1,
221            Self::Frexp => 1,
222            Self::Ldexp => 2,
223            // exponent
224            Self::Exp => 1,
225            Self::Exp2 => 1,
226            Self::Log => 1,
227            Self::Log2 => 1,
228            Self::Pow => 2,
229            // geometry
230            Self::Dot => 2,
231            Self::Dot4I8Packed => 2,
232            Self::Dot4U8Packed => 2,
233            Self::Outer => 2,
234            Self::Cross => 2,
235            Self::Distance => 2,
236            Self::Length => 1,
237            Self::Normalize => 1,
238            Self::FaceForward => 3,
239            Self::Reflect => 2,
240            Self::Refract => 3,
241            // computational
242            Self::Sign => 1,
243            Self::Fma => 3,
244            Self::Mix => 3,
245            Self::Step => 2,
246            Self::SmoothStep => 3,
247            Self::Sqrt => 1,
248            Self::InverseSqrt => 1,
249            Self::Inverse => 1,
250            Self::Transpose => 1,
251            Self::Determinant => 1,
252            Self::QuantizeToF16 => 1,
253            // bits
254            Self::CountTrailingZeros => 1,
255            Self::CountLeadingZeros => 1,
256            Self::CountOneBits => 1,
257            Self::ReverseBits => 1,
258            Self::ExtractBits => 3,
259            Self::InsertBits => 4,
260            Self::FirstTrailingBit => 1,
261            Self::FirstLeadingBit => 1,
262            // data packing
263            Self::Pack4x8snorm => 1,
264            Self::Pack4x8unorm => 1,
265            Self::Pack2x16snorm => 1,
266            Self::Pack2x16unorm => 1,
267            Self::Pack2x16float => 1,
268            Self::Pack4xI8 => 1,
269            Self::Pack4xU8 => 1,
270            Self::Pack4xI8Clamp => 1,
271            Self::Pack4xU8Clamp => 1,
272            // data unpacking
273            Self::Unpack4x8snorm => 1,
274            Self::Unpack4x8unorm => 1,
275            Self::Unpack2x16snorm => 1,
276            Self::Unpack2x16unorm => 1,
277            Self::Unpack2x16float => 1,
278            Self::Unpack4xI8 => 1,
279            Self::Unpack4xU8 => 1,
280        }
281    }
282}
283
284impl crate::Expression {
285    /// Returns true if the expression is considered emitted at the start of a function.
286    pub const fn needs_pre_emit(&self) -> bool {
287        match *self {
288            Self::Literal(_)
289            | Self::Constant(_)
290            | Self::Override(_)
291            | Self::ZeroValue(_)
292            | Self::FunctionArgument(_)
293            | Self::GlobalVariable(_)
294            | Self::LocalVariable(_) => true,
295            _ => false,
296        }
297    }
298
299    /// Return true if this expression is a dynamic array/vector/matrix index,
300    /// for [`Access`].
301    ///
302    /// This method returns true if this expression is a dynamically computed
303    /// index, and as such can only be used to index matrices when they appear
304    /// behind a pointer. See the documentation for [`Access`] for details.
305    ///
306    /// Note, this does not check the _type_ of the given expression. It's up to
307    /// the caller to establish that the `Access` expression is well-typed
308    /// through other means, like [`ResolveContext`].
309    ///
310    /// [`Access`]: crate::Expression::Access
311    /// [`ResolveContext`]: crate::proc::ResolveContext
312    pub const fn is_dynamic_index(&self) -> bool {
313        match *self {
314            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
315            _ => true,
316        }
317    }
318}
319
320impl crate::Function {
321    /// Return the global variable being accessed by the expression `pointer`.
322    ///
323    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
324    /// expressions that ultimately access some part of a `GlobalVariable`,
325    /// return a handle for that global.
326    ///
327    /// If the expression does not ultimately access a global variable, return
328    /// `None`.
329    pub fn originating_global(
330        &self,
331        mut pointer: crate::Handle<crate::Expression>,
332    ) -> Option<crate::Handle<crate::GlobalVariable>> {
333        loop {
334            pointer = match self.expressions[pointer] {
335                crate::Expression::Access { base, .. } => base,
336                crate::Expression::AccessIndex { base, .. } => base,
337                crate::Expression::GlobalVariable(handle) => return Some(handle),
338                crate::Expression::LocalVariable(_) => return None,
339                crate::Expression::FunctionArgument(_) => return None,
340                // There are no other expressions that produce pointer values.
341                _ => unreachable!(),
342            }
343        }
344    }
345}
346
347impl crate::SampleLevel {
348    pub const fn implicit_derivatives(&self) -> bool {
349        match *self {
350            Self::Auto | Self::Bias(_) => true,
351            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
352        }
353    }
354}
355
356impl crate::Binding {
357    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
358        match *self {
359            crate::Binding::BuiltIn(built_in) => Some(built_in),
360            Self::Location { .. } => None,
361        }
362    }
363}
364
365impl super::SwizzleComponent {
366    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
367
368    pub const fn index(&self) -> u32 {
369        match *self {
370            Self::X => 0,
371            Self::Y => 1,
372            Self::Z => 2,
373            Self::W => 3,
374        }
375    }
376    pub const fn from_index(idx: u32) -> Self {
377        match idx {
378            0 => Self::X,
379            1 => Self::Y,
380            2 => Self::Z,
381            _ => Self::W,
382        }
383    }
384}
385
386impl super::ImageClass {
387    pub const fn is_multisampled(self) -> bool {
388        match self {
389            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
390            crate::ImageClass::Storage { .. } => false,
391            crate::ImageClass::External => false,
392        }
393    }
394
395    pub const fn is_mipmapped(self) -> bool {
396        match self {
397            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
398            crate::ImageClass::Storage { .. } => false,
399            crate::ImageClass::External => false,
400        }
401    }
402
403    pub const fn is_depth(self) -> bool {
404        matches!(self, crate::ImageClass::Depth { .. })
405    }
406}
407
408impl crate::Module {
409    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
410        GlobalCtx {
411            types: &self.types,
412            constants: &self.constants,
413            overrides: &self.overrides,
414            global_expressions: &self.global_expressions,
415        }
416    }
417
418    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
419        compare_types(lhs, rhs, &self.types)
420    }
421}
422
/// Why [`GlobalCtx::eval_expr_to_u32`] or `GlobalCtx::eval_expr_to_u32_from`
/// could not produce a `u32`.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not evaluate to a `u32` or `i32` literal.
    NonConst,
    /// The expression evaluated to a negative `i32` literal.
    Negative,
}
428
/// Borrowed views of a module's module-wide arenas, as produced by
/// [`crate::Module::to_ctx`].
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants. Each constant's `init` is a handle into
    /// `global_expressions`.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's overrides. An override's `init`, when present, is a
    /// handle into `global_expressions`.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The module's expressions that live outside any function body.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
436
437impl GlobalCtx<'_> {
438    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
439    #[allow(dead_code)]
440    pub(super) fn eval_expr_to_u32(
441        &self,
442        handle: crate::Handle<crate::Expression>,
443    ) -> Result<u32, U32EvalError> {
444        self.eval_expr_to_u32_from(handle, self.global_expressions)
445    }
446
447    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
448    pub(super) fn eval_expr_to_u32_from(
449        &self,
450        handle: crate::Handle<crate::Expression>,
451        arena: &crate::Arena<crate::Expression>,
452    ) -> Result<u32, U32EvalError> {
453        match self.eval_expr_to_literal_from(handle, arena) {
454            Some(crate::Literal::U32(value)) => Ok(value),
455            Some(crate::Literal::I32(value)) => {
456                value.try_into().map_err(|_| U32EvalError::Negative)
457            }
458            _ => Err(U32EvalError::NonConst),
459        }
460    }
461
462    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
463    #[allow(dead_code)]
464    pub(super) fn eval_expr_to_bool_from(
465        &self,
466        handle: crate::Handle<crate::Expression>,
467        arena: &crate::Arena<crate::Expression>,
468    ) -> Option<bool> {
469        match self.eval_expr_to_literal_from(handle, arena) {
470            Some(crate::Literal::Bool(value)) => Some(value),
471            _ => None,
472        }
473    }
474
475    #[allow(dead_code)]
476    pub(crate) fn eval_expr_to_literal(
477        &self,
478        handle: crate::Handle<crate::Expression>,
479    ) -> Option<crate::Literal> {
480        self.eval_expr_to_literal_from(handle, self.global_expressions)
481    }
482
483    pub(super) fn eval_expr_to_literal_from(
484        &self,
485        handle: crate::Handle<crate::Expression>,
486        arena: &crate::Arena<crate::Expression>,
487    ) -> Option<crate::Literal> {
488        fn get(
489            gctx: GlobalCtx,
490            handle: crate::Handle<crate::Expression>,
491            arena: &crate::Arena<crate::Expression>,
492        ) -> Option<crate::Literal> {
493            match arena[handle] {
494                crate::Expression::Literal(literal) => Some(literal),
495                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
496                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
497                    _ => None,
498                },
499                _ => None,
500            }
501        }
502        match arena[handle] {
503            crate::Expression::Constant(c) => {
504                get(*self, self.constants[c].init, self.global_expressions)
505            }
506            _ => get(*self, handle, arena),
507        }
508    }
509
510    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
511        compare_types(lhs, rhs, self.types)
512    }
513}
514
/// Errors returned by [`crate::ArraySize::resolve`] for override-sized arrays.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The override's initializer evaluated to zero or to a negative `i32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override has no initializer, or its initializer did not evaluate
    /// to a `u32`/`i32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
522
523impl crate::ArraySize {
524    /// Return the number of elements that `size` represents, if known at code generation time.
525    ///
526    /// If `size` is override-based, return an error unless the override's
527    /// initializer is a fully evaluated constant expression. You can call
528    /// [`pipeline_constants::process_overrides`] to supply values for a
529    /// module's overrides and ensure their initializers are fully evaluated, as
530    /// this function expects.
531    ///
532    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
533    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
534        match *self {
535            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
536            crate::ArraySize::Pending(handle) => {
537                let Some(expr) = gctx.overrides[handle].init else {
538                    return Err(ResolveArraySizeError::NonConstArrayLength);
539                };
540                let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
541                    U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
542                    U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
543                })?;
544
545                if length == 0 {
546                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
547                }
548
549                Ok(IndexableLength::Known(length))
550            }
551            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
552        }
553    }
554}
555
/// Return an iterator over the individual components assembled by a
/// `Compose` expression.
///
/// Given `ty` and `components` from an `Expression::Compose`, return an
/// iterator over the components of the resulting value.
///
/// Normally, this would just be an iterator over `components`. However,
/// `Compose` expressions can concatenate vectors, in which case the i'th
/// value being composed is not generally the i'th element of `components`.
/// This function consults `ty` to decide if this concatenation is occurring,
/// and returns an iterator that produces the components of the result of
/// the `Compose` expression in either case.
///
/// Flattening of nested `Compose` and `Splat` expressions happens only when
/// `ty` is a vector type. For any other type, the handles in `components` are
/// yielded unchanged. The iterator yields at most `size` handles, where `size`
/// is the vector's component count for vectors and `components.len()` otherwise.
pub fn flatten_compose<'arenas>(
    ty: crate::Handle<crate::Type>,
    components: &'arenas [crate::Handle<crate::Expression>],
    expressions: &'arenas crate::Arena<crate::Expression>,
    types: &'arenas crate::UniqueArena<crate::Type>,
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
    // Returning `impl Iterator` is a bit tricky. We may or may not
    // want to flatten the components, but we have to settle on a
    // single concrete type to return. This function returns a single
    // iterator chain that handles both the flattening and
    // non-flattening cases.
    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
        (size as usize, true)
    } else {
        (components.len(), false)
    };

    /// Flatten `Compose` expressions if `is_vector` is true.
    ///
    /// Returns the sub-components of a `Compose`, or a one-element slice
    /// holding `component` itself otherwise.
    fn flatten_compose<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> &'c [crate::Handle<crate::Expression>] {
        if is_vector {
            if let crate::Expression::Compose {
                ty: _,
                components: ref subcomponents,
            } = expressions[*component]
            {
                return subcomponents;
            }
        }
        core::slice::from_ref(component)
    }

    /// Flatten `Splat` expressions if `is_vector` is true.
    ///
    /// Yields the splatted operand `size` times for a `Splat`, or `component`
    /// once otherwise.
    fn flatten_splat<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
        let mut expr = *component;
        let mut count = 1;
        if is_vector {
            if let crate::Expression::Splat { size, value } = expressions[expr] {
                expr = value;
                count = size as usize;
            }
        }
        core::iter::repeat_n(expr, count)
    }

    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
    // flatten up to two levels of `Compose` expressions.
    //
    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
    // `Splat` expressions. Fortunately, the operand of a `Splat` must
    // be a scalar, so we can stop there.
    //
    // `.take(size)` caps the output at the result's component count.
    components
        .iter()
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
        .take(size)
}
633
634impl super::ShaderStage {
635    pub const fn compute_like(self) -> bool {
636        match self {
637            Self::Vertex | Self::Fragment => false,
638            Self::Compute | Self::Task | Self::Mesh => true,
639        }
640    }
641}
642
#[test]
fn test_matrix_size() {
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let module = crate::Module::default();

    // Three `vec3<f32>` columns at a 16-byte column stride: 3 * 16 = 48 bytes.
    assert_eq!(mat3x3_f32.size(module.to_ctx()), 48);
}