naga/proc/
mod.rs

/*!
[`Module`](super::Module) processing functionality.
*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
/// A literal value in a form that implements [`Eq`] and [`Hash`].
///
/// Floating-point payloads hold the raw IEEE 754 bit pattern of the value
/// (see the `From<crate::Literal>` conversion). As a result, equality and hashing
/// compare bits: `0.0` and `-0.0` are distinct, and a NaN only equals a NaN with
/// identical bits.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HashableLiteral {
    /// Bit pattern of an `f64` literal.
    F64(u64),
    /// Bit pattern of an `f32` literal.
    F32(u32),
    /// Bit pattern of an `f16` literal.
    F16(u16),
    /// A `u32` literal.
    U32(u32),
    /// An `i32` literal.
    I32(i32),
    /// A `u64` literal.
    U64(u64),
    /// An `i64` literal.
    I64(i64),
    /// A boolean literal.
    Bool(bool),
    /// An abstract-int literal.
    AbstractInt(i64),
    /// Bit pattern of an abstract-float literal.
    AbstractFloat(u64),
}
101
102impl From<crate::Literal> for HashableLiteral {
103    fn from(l: crate::Literal) -> Self {
104        match l {
105            crate::Literal::F64(v) => Self::F64(v.to_bits()),
106            crate::Literal::F32(v) => Self::F32(v.to_bits()),
107            crate::Literal::F16(v) => Self::F16(v.to_bits()),
108            crate::Literal::U32(v) => Self::U32(v),
109            crate::Literal::I32(v) => Self::I32(v),
110            crate::Literal::U64(v) => Self::U64(v),
111            crate::Literal::I64(v) => Self::I64(v),
112            crate::Literal::Bool(v) => Self::Bool(v),
113            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115        }
116    }
117}
118
119impl crate::Literal {
120    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121        match (value, scalar.kind, scalar.width) {
122            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124            (value, crate::ScalarKind::Float, 2) => {
125                Some(Self::F16(half::f16::from_f32_const(value as _)))
126            }
127            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
128            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
129            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
130            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
131            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
132            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
133            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
134            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
135            _ => None,
136        }
137    }
138
139    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
140        Self::new(0, scalar)
141    }
142
143    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
144        Self::new(1, scalar)
145    }
146
147    pub const fn width(&self) -> crate::Bytes {
148        match *self {
149            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
150            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
151            Self::F16(_) => 2,
152            Self::Bool(_) => crate::BOOL_WIDTH,
153            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
154        }
155    }
156    pub const fn scalar(&self) -> crate::Scalar {
157        match *self {
158            Self::F64(_) => crate::Scalar::F64,
159            Self::F32(_) => crate::Scalar::F32,
160            Self::F16(_) => crate::Scalar::F16,
161            Self::U32(_) => crate::Scalar::U32,
162            Self::I32(_) => crate::Scalar::I32,
163            Self::U64(_) => crate::Scalar::U64,
164            Self::I64(_) => crate::Scalar::I64,
165            Self::Bool(_) => crate::Scalar::BOOL,
166            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
167            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
168        }
169    }
170    pub const fn scalar_kind(&self) -> crate::ScalarKind {
171        self.scalar().kind
172    }
173    pub const fn ty_inner(&self) -> crate::TypeInner {
174        crate::TypeInner::Scalar(self.scalar())
175    }
176}
177
178impl TryFrom<crate::Literal> for u32 {
179    type Error = ConstValueError;
180
181    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
182        match value {
183            crate::Literal::U32(value) => Ok(value),
184            crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
185            _ => Err(ConstValueError::InvalidType),
186        }
187    }
188}
189
190impl TryFrom<crate::Literal> for bool {
191    type Error = ConstValueError;
192
193    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
194        match value {
195            crate::Literal::Bool(value) => Ok(value),
196            _ => Err(ConstValueError::InvalidType),
197        }
198    }
199}
200
201impl super::AddressSpace {
202    pub fn access(self) -> crate::StorageAccess {
203        use crate::StorageAccess as Sa;
204        match self {
205            crate::AddressSpace::Function
206            | crate::AddressSpace::Private
207            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
208            crate::AddressSpace::Uniform => Sa::LOAD,
209            crate::AddressSpace::Storage { access } => access,
210            crate::AddressSpace::Handle => Sa::LOAD,
211            crate::AddressSpace::Immediate => Sa::LOAD,
212            // TaskPayload isn't always writable, but this is checked for elsewhere,
213            // when not using multiple payloads and matching the entry payload is checked.
214            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
215        }
216    }
217}
218
219impl super::MathFunction {
220    pub const fn argument_count(&self) -> usize {
221        match *self {
222            // comparison
223            Self::Abs => 1,
224            Self::Min => 2,
225            Self::Max => 2,
226            Self::Clamp => 3,
227            Self::Saturate => 1,
228            // trigonometry
229            Self::Cos => 1,
230            Self::Cosh => 1,
231            Self::Sin => 1,
232            Self::Sinh => 1,
233            Self::Tan => 1,
234            Self::Tanh => 1,
235            Self::Acos => 1,
236            Self::Asin => 1,
237            Self::Atan => 1,
238            Self::Atan2 => 2,
239            Self::Asinh => 1,
240            Self::Acosh => 1,
241            Self::Atanh => 1,
242            Self::Radians => 1,
243            Self::Degrees => 1,
244            // decomposition
245            Self::Ceil => 1,
246            Self::Floor => 1,
247            Self::Round => 1,
248            Self::Fract => 1,
249            Self::Trunc => 1,
250            Self::Modf => 1,
251            Self::Frexp => 1,
252            Self::Ldexp => 2,
253            // exponent
254            Self::Exp => 1,
255            Self::Exp2 => 1,
256            Self::Log => 1,
257            Self::Log2 => 1,
258            Self::Pow => 2,
259            // geometry
260            Self::Dot => 2,
261            Self::Dot4I8Packed => 2,
262            Self::Dot4U8Packed => 2,
263            Self::Outer => 2,
264            Self::Cross => 2,
265            Self::Distance => 2,
266            Self::Length => 1,
267            Self::Normalize => 1,
268            Self::FaceForward => 3,
269            Self::Reflect => 2,
270            Self::Refract => 3,
271            // computational
272            Self::Sign => 1,
273            Self::Fma => 3,
274            Self::Mix => 3,
275            Self::Step => 2,
276            Self::SmoothStep => 3,
277            Self::Sqrt => 1,
278            Self::InverseSqrt => 1,
279            Self::Inverse => 1,
280            Self::Transpose => 1,
281            Self::Determinant => 1,
282            Self::QuantizeToF16 => 1,
283            // bits
284            Self::CountTrailingZeros => 1,
285            Self::CountLeadingZeros => 1,
286            Self::CountOneBits => 1,
287            Self::ReverseBits => 1,
288            Self::ExtractBits => 3,
289            Self::InsertBits => 4,
290            Self::FirstTrailingBit => 1,
291            Self::FirstLeadingBit => 1,
292            // data packing
293            Self::Pack4x8snorm => 1,
294            Self::Pack4x8unorm => 1,
295            Self::Pack2x16snorm => 1,
296            Self::Pack2x16unorm => 1,
297            Self::Pack2x16float => 1,
298            Self::Pack4xI8 => 1,
299            Self::Pack4xU8 => 1,
300            Self::Pack4xI8Clamp => 1,
301            Self::Pack4xU8Clamp => 1,
302            // data unpacking
303            Self::Unpack4x8snorm => 1,
304            Self::Unpack4x8unorm => 1,
305            Self::Unpack2x16snorm => 1,
306            Self::Unpack2x16unorm => 1,
307            Self::Unpack2x16float => 1,
308            Self::Unpack4xI8 => 1,
309            Self::Unpack4xU8 => 1,
310        }
311    }
312}
313
314impl crate::Expression {
315    /// Returns true if the expression is considered emitted at the start of a function.
316    pub const fn needs_pre_emit(&self) -> bool {
317        match *self {
318            Self::Literal(_)
319            | Self::Constant(_)
320            | Self::Override(_)
321            | Self::ZeroValue(_)
322            | Self::FunctionArgument(_)
323            | Self::GlobalVariable(_)
324            | Self::LocalVariable(_) => true,
325            _ => false,
326        }
327    }
328
329    /// Return true if this expression is a dynamic array/vector/matrix index,
330    /// for [`Access`].
331    ///
332    /// This method returns true if this expression is a dynamically computed
333    /// index, and as such can only be used to index matrices when they appear
334    /// behind a pointer. See the documentation for [`Access`] for details.
335    ///
336    /// Note, this does not check the _type_ of the given expression. It's up to
337    /// the caller to establish that the `Access` expression is well-typed
338    /// through other means, like [`ResolveContext`].
339    ///
340    /// [`Access`]: crate::Expression::Access
341    /// [`ResolveContext`]: crate::proc::ResolveContext
342    pub const fn is_dynamic_index(&self) -> bool {
343        match *self {
344            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
345            _ => true,
346        }
347    }
348}
349
350impl crate::Function {
351    /// Return the global variable being accessed by the expression `pointer`.
352    ///
353    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
354    /// expressions that ultimately access some part of a `GlobalVariable`,
355    /// return a handle for that global.
356    ///
357    /// If the expression does not ultimately access a global variable, return
358    /// `None`.
359    pub fn originating_global(
360        &self,
361        mut pointer: crate::Handle<crate::Expression>,
362    ) -> Option<crate::Handle<crate::GlobalVariable>> {
363        loop {
364            pointer = match self.expressions[pointer] {
365                crate::Expression::Access { base, .. } => base,
366                crate::Expression::AccessIndex { base, .. } => base,
367                crate::Expression::GlobalVariable(handle) => return Some(handle),
368                crate::Expression::LocalVariable(_) => return None,
369                crate::Expression::FunctionArgument(_) => return None,
370                // There are no other expressions that produce pointer values.
371                _ => unreachable!(),
372            }
373        }
374    }
375}
376
377impl crate::SampleLevel {
378    pub const fn implicit_derivatives(&self) -> bool {
379        match *self {
380            Self::Auto | Self::Bias(_) => true,
381            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
382        }
383    }
384}
385
386impl crate::Binding {
387    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
388        match *self {
389            crate::Binding::BuiltIn(built_in) => Some(built_in),
390            Self::Location { .. } => None,
391        }
392    }
393}
394
395impl super::SwizzleComponent {
396    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
397
398    pub const fn index(&self) -> u32 {
399        match *self {
400            Self::X => 0,
401            Self::Y => 1,
402            Self::Z => 2,
403            Self::W => 3,
404        }
405    }
406    pub const fn from_index(idx: u32) -> Self {
407        match idx {
408            0 => Self::X,
409            1 => Self::Y,
410            2 => Self::Z,
411            _ => Self::W,
412        }
413    }
414}
415
416impl super::ImageClass {
417    pub const fn is_multisampled(self) -> bool {
418        match self {
419            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
420            crate::ImageClass::Storage { .. } => false,
421            crate::ImageClass::External => false,
422        }
423    }
424
425    pub const fn is_mipmapped(self) -> bool {
426        match self {
427            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
428            crate::ImageClass::Storage { .. } => false,
429            crate::ImageClass::External => false,
430        }
431    }
432
433    pub const fn is_depth(self) -> bool {
434        matches!(self, crate::ImageClass::Depth { .. })
435    }
436}
437
438impl crate::Module {
439    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
440        GlobalCtx {
441            types: &self.types,
442            constants: &self.constants,
443            overrides: &self.overrides,
444            global_expressions: &self.global_expressions,
445        }
446    }
447
448    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
449        compare_types(lhs, rhs, &self.types)
450    }
451}
452
/// Reasons a constant value could not be extracted from an expression.
///
/// Produced by `GlobalCtx::get_const_val` and by the `TryFrom<crate::Literal>`
/// conversions in this module.
///
/// This is a plain fieldless enum, so it is cheap to copy and compare, matching
/// the derives on `ResolveArraySizeError`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConstValueError {
    /// No literal could be obtained from the expression.
    NonConst,
    /// A negative signed literal was given where an unsigned value was required.
    Negative,
    /// The literal's kind cannot be converted to the requested type.
    InvalidType,
}
459
460impl From<core::convert::Infallible> for ConstValueError {
461    fn from(_: core::convert::Infallible) -> Self {
462        unreachable!()
463    }
464}
465
466#[derive(Clone, Copy)]
467pub struct GlobalCtx<'a> {
468    pub types: &'a crate::UniqueArena<crate::Type>,
469    pub constants: &'a crate::Arena<crate::Constant>,
470    pub overrides: &'a crate::Arena<crate::Override>,
471    pub global_expressions: &'a crate::Arena<crate::Expression>,
472}
473
474impl GlobalCtx<'_> {
475    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
476    #[cfg_attr(
477        not(any(
478            feature = "glsl-in",
479            feature = "spv-in",
480            feature = "wgsl-in",
481            glsl_out,
482            hlsl_out,
483            msl_out,
484            wgsl_out
485        )),
486        allow(dead_code)
487    )]
488    pub(super) fn get_const_val<T, E>(
489        &self,
490        handle: crate::Handle<crate::Expression>,
491    ) -> Result<T, ConstValueError>
492    where
493        T: TryFrom<crate::Literal, Error = E>,
494        E: Into<ConstValueError>,
495    {
496        self.get_const_val_from(handle, self.global_expressions)
497    }
498
499    pub(super) fn get_const_val_from<T, E>(
500        &self,
501        handle: crate::Handle<crate::Expression>,
502        arena: &crate::Arena<crate::Expression>,
503    ) -> Result<T, ConstValueError>
504    where
505        T: TryFrom<crate::Literal, Error = E>,
506        E: Into<ConstValueError>,
507    {
508        fn get(
509            gctx: GlobalCtx,
510            handle: crate::Handle<crate::Expression>,
511            arena: &crate::Arena<crate::Expression>,
512        ) -> Option<crate::Literal> {
513            match arena[handle] {
514                crate::Expression::Literal(literal) => Some(literal),
515                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
516                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
517                    _ => None,
518                },
519                _ => None,
520            }
521        }
522        let value = match arena[handle] {
523            crate::Expression::Constant(c) => {
524                get(*self, self.constants[c].init, self.global_expressions)
525            }
526            _ => get(*self, handle, arena),
527        };
528        match value {
529            Some(v) => v.try_into().map_err(Into::into),
530            None => Err(ConstValueError::NonConst),
531        }
532    }
533
534    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
535        compare_types(lhs, rhs, self.types)
536    }
537}
538
539#[derive(Error, Debug, Clone, Copy, PartialEq)]
540pub enum ResolveArraySizeError {
541    #[error("array element count must be positive (> 0)")]
542    ExpectedPositiveArrayLength,
543    #[error("internal: array size override has not been resolved")]
544    NonConstArrayLength,
545}
546
547impl crate::ArraySize {
548    /// Return the number of elements that `size` represents, if known at code generation time.
549    ///
550    /// If `size` is override-based, return an error unless the override's
551    /// initializer is a fully evaluated constant expression. You can call
552    /// [`pipeline_constants::process_overrides`] to supply values for a
553    /// module's overrides and ensure their initializers are fully evaluated, as
554    /// this function expects.
555    ///
556    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
557    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
558        match *self {
559            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
560            crate::ArraySize::Pending(handle) => {
561                let Some(expr) = gctx.overrides[handle].init else {
562                    return Err(ResolveArraySizeError::NonConstArrayLength);
563                };
564                let length = gctx.get_const_val(expr).map_err(|err| match err {
565                    ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
566                    ConstValueError::Negative | ConstValueError::InvalidType => {
567                        ResolveArraySizeError::ExpectedPositiveArrayLength
568                    }
569                })?;
570
571                if length == 0 {
572                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
573                }
574
575                Ok(IndexableLength::Known(length))
576            }
577            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
578        }
579    }
580}
581
582/// Return an iterator over the individual components assembled by a
583/// `Compose` expression.
584///
585/// Given `ty` and `components` from an `Expression::Compose`, return an
586/// iterator over the components of the resulting value.
587///
588/// Normally, this would just be an iterator over `components`. However,
589/// `Compose` expressions can concatenate vectors, in which case the i'th
590/// value being composed is not generally the i'th element of `components`.
591/// This function consults `ty` to decide if this concatenation is occurring,
592/// and returns an iterator that produces the components of the result of
593/// the `Compose` expression in either case.
594pub fn flatten_compose<'arenas>(
595    ty: crate::Handle<crate::Type>,
596    components: &'arenas [crate::Handle<crate::Expression>],
597    expressions: &'arenas crate::Arena<crate::Expression>,
598    types: &'arenas crate::UniqueArena<crate::Type>,
599) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
600    // Returning `impl Iterator` is a bit tricky. We may or may not
601    // want to flatten the components, but we have to settle on a
602    // single concrete type to return. This function returns a single
603    // iterator chain that handles both the flattening and
604    // non-flattening cases.
605    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
606        (size as usize, true)
607    } else {
608        (components.len(), false)
609    };
610
611    /// Flatten `Compose` expressions if `is_vector` is true.
612    fn flatten_compose<'c>(
613        component: &'c crate::Handle<crate::Expression>,
614        is_vector: bool,
615        expressions: &'c crate::Arena<crate::Expression>,
616    ) -> &'c [crate::Handle<crate::Expression>] {
617        if is_vector {
618            if let crate::Expression::Compose {
619                ty: _,
620                components: ref subcomponents,
621            } = expressions[*component]
622            {
623                return subcomponents;
624            }
625        }
626        core::slice::from_ref(component)
627    }
628
629    /// Flatten `Splat` expressions if `is_vector` is true.
630    fn flatten_splat<'c>(
631        component: &'c crate::Handle<crate::Expression>,
632        is_vector: bool,
633        expressions: &'c crate::Arena<crate::Expression>,
634    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
635        let mut expr = *component;
636        let mut count = 1;
637        if is_vector {
638            if let crate::Expression::Splat { size, value } = expressions[expr] {
639                expr = value;
640                count = size as usize;
641            }
642        }
643        core::iter::repeat_n(expr, count)
644    }
645
646    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
647    // flatten up to two levels of `Compose` expressions.
648    //
649    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
650    // `Splat` expressions. Fortunately, the operand of a `Splat` must
651    // be a scalar, so we can stop there.
652    components
653        .iter()
654        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
655        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
656        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
657        .take(size)
658}
659
660impl super::ShaderStage {
661    pub const fn compute_like(self) -> bool {
662        match self {
663            Self::Vertex | Self::Fragment => false,
664            Self::Compute | Self::Task | Self::Mesh => true,
665        }
666    }
667}
668
/// A `mat3x3<f32>` is expected to occupy 48 bytes. That is 3 columns of 16 bytes
/// each, which is more than the 12 bytes of three `f32`s.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let expected_bytes = 48;
    assert_eq!(mat3x3_f32.size(module.to_ctx()), expected_bytes);
}
682
683impl crate::Module {
684    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
685    /// and by validators. This only validates the output variable itself, and not the
686    /// vertex and primitive output types.
687    ///
688    /// The output contains the extracted mesh stage info, with overrides unset,
689    /// and then the overrides separately. This is because the overrides should be
690    /// treated as expressions elsewhere, but that requires mutably modifying the
691    /// module and the expressions should only be created at parse time, not validation
692    /// time.
693    #[allow(clippy::type_complexity)]
694    pub fn analyze_mesh_shader_info(
695        &self,
696        gv: crate::Handle<crate::GlobalVariable>,
697    ) -> (
698        crate::MeshStageInfo,
699        [Option<crate::Handle<crate::Override>>; 2],
700        Option<crate::WithSpan<crate::valid::EntryPointError>>,
701    ) {
702        use crate::span::AddSpan;
703        use crate::valid::EntryPointError;
704        #[derive(Default)]
705        struct OutError {
706            pub inner: Option<EntryPointError>,
707        }
708        impl OutError {
709            pub fn set(&mut self, err: EntryPointError) {
710                if self.inner.is_none() {
711                    self.inner = Some(err);
712                }
713            }
714        }
715
716        // Used to temporarily initialize stuff
717        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
718        let mut output = crate::MeshStageInfo {
719            topology: crate::MeshOutputTopology::Triangles,
720            max_vertices: 0,
721            max_vertices_override: None,
722            max_primitives: 0,
723            max_primitives_override: None,
724            vertex_output_type: null_type,
725            primitive_output_type: null_type,
726            output_variable: gv,
727        };
728        // Stores the error to output, if any.
729        let mut error = OutError::default();
730        let r#type = &self.types[self.global_variables[gv].ty].inner;
731
732        let mut topology = output.topology;
733        // Max, max override, type
734        let mut vertex_info = (0, None, null_type);
735        let mut primitive_info = (0, None, null_type);
736
737        match r#type {
738            &crate::TypeInner::Struct { ref members, .. } => {
739                let mut builtins = crate::FastHashSet::default();
740                for member in members {
741                    match member.binding {
742                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
743                            // Must have type u32
744                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
745                                error.set(EntryPointError::BadMeshOutputVariableField);
746                            }
747                            // Each builtin should only occur once
748                            if builtins.contains(&crate::BuiltIn::VertexCount) {
749                                error.set(EntryPointError::BadMeshOutputVariableType);
750                            }
751                            builtins.insert(crate::BuiltIn::VertexCount);
752                        }
753                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
754                            // Must have type u32
755                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
756                                error.set(EntryPointError::BadMeshOutputVariableField);
757                            }
758                            // Each builtin should only occur once
759                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
760                                error.set(EntryPointError::BadMeshOutputVariableType);
761                            }
762                            builtins.insert(crate::BuiltIn::PrimitiveCount);
763                        }
764                        Some(crate::Binding::BuiltIn(
765                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
766                        )) => {
767                            let ty = &self.types[member.ty].inner;
768                            // Analyze the array type to determine size and vertex/primitive type
769                            let (a, b, c) = match ty {
770                                &crate::TypeInner::Array { base, size, .. } => {
771                                    let ty = base;
772                                    let (max, max_override) = match size {
773                                        crate::ArraySize::Constant(a) => (a.get(), None),
774                                        crate::ArraySize::Pending(o) => (0, Some(o)),
775                                        crate::ArraySize::Dynamic => {
776                                            error.set(EntryPointError::BadMeshOutputVariableField);
777                                            (0, None)
778                                        }
779                                    };
780                                    (max, max_override, ty)
781                                }
782                                _ => {
783                                    error.set(EntryPointError::BadMeshOutputVariableField);
784                                    (0, None, null_type)
785                                }
786                            };
787                            if matches!(
788                                member.binding,
789                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
790                            ) {
791                                // Primitives require special analysis to determine topology
792                                primitive_info = (a, b, c);
793                                match self.types[c].inner {
794                                    crate::TypeInner::Struct { ref members, .. } => {
795                                        for member in members {
796                                            match member.binding {
797                                                Some(crate::Binding::BuiltIn(
798                                                    crate::BuiltIn::PointIndex,
799                                                )) => {
800                                                    topology = crate::MeshOutputTopology::Points;
801                                                }
802                                                Some(crate::Binding::BuiltIn(
803                                                    crate::BuiltIn::LineIndices,
804                                                )) => {
805                                                    topology = crate::MeshOutputTopology::Lines;
806                                                }
807                                                Some(crate::Binding::BuiltIn(
808                                                    crate::BuiltIn::TriangleIndices,
809                                                )) => {
810                                                    topology = crate::MeshOutputTopology::Triangles;
811                                                }
812                                                _ => (),
813                                            }
814                                        }
815                                    }
816                                    _ => (),
817                                }
818                                // Each builtin should only occur once
819                                if builtins.contains(&crate::BuiltIn::Primitives) {
820                                    error.set(EntryPointError::BadMeshOutputVariableType);
821                                }
822                                builtins.insert(crate::BuiltIn::Primitives);
823                            } else {
824                                vertex_info = (a, b, c);
825                                // Each builtin should only occur once
826                                if builtins.contains(&crate::BuiltIn::Vertices) {
827                                    error.set(EntryPointError::BadMeshOutputVariableType);
828                                }
829                                builtins.insert(crate::BuiltIn::Vertices);
830                            }
831                        }
832                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
833                    }
834                }
835                output = crate::MeshStageInfo {
836                    topology,
837                    max_vertices: vertex_info.0,
838                    max_vertices_override: None,
839                    vertex_output_type: vertex_info.2,
840                    max_primitives: primitive_info.0,
841                    max_primitives_override: None,
842                    primitive_output_type: primitive_info.2,
843                    ..output
844                }
845            }
846            _ => error.set(EntryPointError::BadMeshOutputVariableType),
847        }
848        (
849            output,
850            [vertex_info.1, primitive_info.1],
851            error
852                .inner
853                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
854        )
855    }
856}