naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
/// A counterpart of [`crate::Literal`] that implements `Eq` and `Hash`.
///
/// It has one variant per `Literal` variant. Floating-point payloads are
/// held as their raw bit patterns (see the `From<crate::Literal>`
/// conversion), so equality and hashing are bitwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bits of an `F64` literal.
    F64(u64),
    /// Bits of an `F32` literal.
    F32(u32),
    /// Bits of an `F16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bits of an `AbstractFloat` literal.
    AbstractFloat(u64),
}
101
102impl From<crate::Literal> for HashableLiteral {
103    fn from(l: crate::Literal) -> Self {
104        match l {
105            crate::Literal::F64(v) => Self::F64(v.to_bits()),
106            crate::Literal::F32(v) => Self::F32(v.to_bits()),
107            crate::Literal::F16(v) => Self::F16(v.to_bits()),
108            crate::Literal::U32(v) => Self::U32(v),
109            crate::Literal::I32(v) => Self::I32(v),
110            crate::Literal::U64(v) => Self::U64(v),
111            crate::Literal::I64(v) => Self::I64(v),
112            crate::Literal::Bool(v) => Self::Bool(v),
113            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115        }
116    }
117}
118
119impl crate::Literal {
120    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121        match (value, scalar.kind, scalar.width) {
122            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
125            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
126            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
127            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
128            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
129            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
130            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
131            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
132            _ => None,
133        }
134    }
135
136    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
137        Self::new(0, scalar)
138    }
139
140    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
141        Self::new(1, scalar)
142    }
143
144    pub const fn width(&self) -> crate::Bytes {
145        match *self {
146            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
147            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
148            Self::F16(_) => 2,
149            Self::Bool(_) => crate::BOOL_WIDTH,
150            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
151        }
152    }
153    pub const fn scalar(&self) -> crate::Scalar {
154        match *self {
155            Self::F64(_) => crate::Scalar::F64,
156            Self::F32(_) => crate::Scalar::F32,
157            Self::F16(_) => crate::Scalar::F16,
158            Self::U32(_) => crate::Scalar::U32,
159            Self::I32(_) => crate::Scalar::I32,
160            Self::U64(_) => crate::Scalar::U64,
161            Self::I64(_) => crate::Scalar::I64,
162            Self::Bool(_) => crate::Scalar::BOOL,
163            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
164            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
165        }
166    }
167    pub const fn scalar_kind(&self) -> crate::ScalarKind {
168        self.scalar().kind
169    }
170    pub const fn ty_inner(&self) -> crate::TypeInner {
171        crate::TypeInner::Scalar(self.scalar())
172    }
173}
174
175impl super::AddressSpace {
176    pub fn access(self) -> crate::StorageAccess {
177        use crate::StorageAccess as Sa;
178        match self {
179            crate::AddressSpace::Function
180            | crate::AddressSpace::Private
181            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
182            crate::AddressSpace::Uniform => Sa::LOAD,
183            crate::AddressSpace::Storage { access } => access,
184            crate::AddressSpace::Handle => Sa::LOAD,
185            crate::AddressSpace::Immediate => Sa::LOAD,
186            // TaskPayload isn't always writable, but this is checked for elsewhere,
187            // when not using multiple payloads and matching the entry payload is checked.
188            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
189        }
190    }
191}
192
193impl super::MathFunction {
194    pub const fn argument_count(&self) -> usize {
195        match *self {
196            // comparison
197            Self::Abs => 1,
198            Self::Min => 2,
199            Self::Max => 2,
200            Self::Clamp => 3,
201            Self::Saturate => 1,
202            // trigonometry
203            Self::Cos => 1,
204            Self::Cosh => 1,
205            Self::Sin => 1,
206            Self::Sinh => 1,
207            Self::Tan => 1,
208            Self::Tanh => 1,
209            Self::Acos => 1,
210            Self::Asin => 1,
211            Self::Atan => 1,
212            Self::Atan2 => 2,
213            Self::Asinh => 1,
214            Self::Acosh => 1,
215            Self::Atanh => 1,
216            Self::Radians => 1,
217            Self::Degrees => 1,
218            // decomposition
219            Self::Ceil => 1,
220            Self::Floor => 1,
221            Self::Round => 1,
222            Self::Fract => 1,
223            Self::Trunc => 1,
224            Self::Modf => 1,
225            Self::Frexp => 1,
226            Self::Ldexp => 2,
227            // exponent
228            Self::Exp => 1,
229            Self::Exp2 => 1,
230            Self::Log => 1,
231            Self::Log2 => 1,
232            Self::Pow => 2,
233            // geometry
234            Self::Dot => 2,
235            Self::Dot4I8Packed => 2,
236            Self::Dot4U8Packed => 2,
237            Self::Outer => 2,
238            Self::Cross => 2,
239            Self::Distance => 2,
240            Self::Length => 1,
241            Self::Normalize => 1,
242            Self::FaceForward => 3,
243            Self::Reflect => 2,
244            Self::Refract => 3,
245            // computational
246            Self::Sign => 1,
247            Self::Fma => 3,
248            Self::Mix => 3,
249            Self::Step => 2,
250            Self::SmoothStep => 3,
251            Self::Sqrt => 1,
252            Self::InverseSqrt => 1,
253            Self::Inverse => 1,
254            Self::Transpose => 1,
255            Self::Determinant => 1,
256            Self::QuantizeToF16 => 1,
257            // bits
258            Self::CountTrailingZeros => 1,
259            Self::CountLeadingZeros => 1,
260            Self::CountOneBits => 1,
261            Self::ReverseBits => 1,
262            Self::ExtractBits => 3,
263            Self::InsertBits => 4,
264            Self::FirstTrailingBit => 1,
265            Self::FirstLeadingBit => 1,
266            // data packing
267            Self::Pack4x8snorm => 1,
268            Self::Pack4x8unorm => 1,
269            Self::Pack2x16snorm => 1,
270            Self::Pack2x16unorm => 1,
271            Self::Pack2x16float => 1,
272            Self::Pack4xI8 => 1,
273            Self::Pack4xU8 => 1,
274            Self::Pack4xI8Clamp => 1,
275            Self::Pack4xU8Clamp => 1,
276            // data unpacking
277            Self::Unpack4x8snorm => 1,
278            Self::Unpack4x8unorm => 1,
279            Self::Unpack2x16snorm => 1,
280            Self::Unpack2x16unorm => 1,
281            Self::Unpack2x16float => 1,
282            Self::Unpack4xI8 => 1,
283            Self::Unpack4xU8 => 1,
284        }
285    }
286}
287
288impl crate::Expression {
289    /// Returns true if the expression is considered emitted at the start of a function.
290    pub const fn needs_pre_emit(&self) -> bool {
291        match *self {
292            Self::Literal(_)
293            | Self::Constant(_)
294            | Self::Override(_)
295            | Self::ZeroValue(_)
296            | Self::FunctionArgument(_)
297            | Self::GlobalVariable(_)
298            | Self::LocalVariable(_) => true,
299            _ => false,
300        }
301    }
302
303    /// Return true if this expression is a dynamic array/vector/matrix index,
304    /// for [`Access`].
305    ///
306    /// This method returns true if this expression is a dynamically computed
307    /// index, and as such can only be used to index matrices when they appear
308    /// behind a pointer. See the documentation for [`Access`] for details.
309    ///
310    /// Note, this does not check the _type_ of the given expression. It's up to
311    /// the caller to establish that the `Access` expression is well-typed
312    /// through other means, like [`ResolveContext`].
313    ///
314    /// [`Access`]: crate::Expression::Access
315    /// [`ResolveContext`]: crate::proc::ResolveContext
316    pub const fn is_dynamic_index(&self) -> bool {
317        match *self {
318            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
319            _ => true,
320        }
321    }
322}
323
324impl crate::Function {
325    /// Return the global variable being accessed by the expression `pointer`.
326    ///
327    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
328    /// expressions that ultimately access some part of a `GlobalVariable`,
329    /// return a handle for that global.
330    ///
331    /// If the expression does not ultimately access a global variable, return
332    /// `None`.
333    pub fn originating_global(
334        &self,
335        mut pointer: crate::Handle<crate::Expression>,
336    ) -> Option<crate::Handle<crate::GlobalVariable>> {
337        loop {
338            pointer = match self.expressions[pointer] {
339                crate::Expression::Access { base, .. } => base,
340                crate::Expression::AccessIndex { base, .. } => base,
341                crate::Expression::GlobalVariable(handle) => return Some(handle),
342                crate::Expression::LocalVariable(_) => return None,
343                crate::Expression::FunctionArgument(_) => return None,
344                // There are no other expressions that produce pointer values.
345                _ => unreachable!(),
346            }
347        }
348    }
349}
350
351impl crate::SampleLevel {
352    pub const fn implicit_derivatives(&self) -> bool {
353        match *self {
354            Self::Auto | Self::Bias(_) => true,
355            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
356        }
357    }
358}
359
360impl crate::Binding {
361    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
362        match *self {
363            crate::Binding::BuiltIn(built_in) => Some(built_in),
364            Self::Location { .. } => None,
365        }
366    }
367}
368
369impl super::SwizzleComponent {
370    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
371
372    pub const fn index(&self) -> u32 {
373        match *self {
374            Self::X => 0,
375            Self::Y => 1,
376            Self::Z => 2,
377            Self::W => 3,
378        }
379    }
380    pub const fn from_index(idx: u32) -> Self {
381        match idx {
382            0 => Self::X,
383            1 => Self::Y,
384            2 => Self::Z,
385            _ => Self::W,
386        }
387    }
388}
389
390impl super::ImageClass {
391    pub const fn is_multisampled(self) -> bool {
392        match self {
393            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
394            crate::ImageClass::Storage { .. } => false,
395            crate::ImageClass::External => false,
396        }
397    }
398
399    pub const fn is_mipmapped(self) -> bool {
400        match self {
401            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
402            crate::ImageClass::Storage { .. } => false,
403            crate::ImageClass::External => false,
404        }
405    }
406
407    pub const fn is_depth(self) -> bool {
408        matches!(self, crate::ImageClass::Depth { .. })
409    }
410}
411
412impl crate::Module {
413    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
414        GlobalCtx {
415            types: &self.types,
416            constants: &self.constants,
417            overrides: &self.overrides,
418            global_expressions: &self.global_expressions,
419        }
420    }
421
422    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
423        compare_types(lhs, rhs, &self.types)
424    }
425}
426
/// Reasons why [`GlobalCtx::eval_expr_to_u32_from`] could not produce a `u32`.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not evaluate to a `U32` or `I32` literal.
    NonConst,
    /// The expression evaluated to an `I32` literal whose value could not be
    /// converted to `u32`.
    Negative,
}
432
/// Shared references to the module-level arenas of a [`crate::Module`].
///
/// Obtained from [`Module::to_ctx`](crate::Module::to_ctx). Being a bundle of
/// shared references, it is `Copy`.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's types.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's overrides.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The module's global expressions; constant initializers are looked up here.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
440
441impl GlobalCtx<'_> {
442    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
443    #[cfg_attr(
444        not(any(
445            feature = "glsl-in",
446            feature = "spv-in",
447            feature = "wgsl-in",
448            glsl_out,
449            hlsl_out,
450            msl_out,
451            wgsl_out
452        )),
453        allow(dead_code)
454    )]
455    pub(super) fn eval_expr_to_u32(
456        &self,
457        handle: crate::Handle<crate::Expression>,
458    ) -> Result<u32, U32EvalError> {
459        self.eval_expr_to_u32_from(handle, self.global_expressions)
460    }
461
462    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
463    pub(super) fn eval_expr_to_u32_from(
464        &self,
465        handle: crate::Handle<crate::Expression>,
466        arena: &crate::Arena<crate::Expression>,
467    ) -> Result<u32, U32EvalError> {
468        match self.eval_expr_to_literal_from(handle, arena) {
469            Some(crate::Literal::U32(value)) => Ok(value),
470            Some(crate::Literal::I32(value)) => {
471                value.try_into().map_err(|_| U32EvalError::Negative)
472            }
473            _ => Err(U32EvalError::NonConst),
474        }
475    }
476
477    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `bool`.
478    #[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))]
479    pub(super) fn eval_expr_to_bool(
480        &self,
481        handle: crate::Handle<crate::Expression>,
482    ) -> Option<bool> {
483        self.eval_expr_to_bool_from(handle, self.global_expressions)
484    }
485
486    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
487    #[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))]
488    pub(super) fn eval_expr_to_bool_from(
489        &self,
490        handle: crate::Handle<crate::Expression>,
491        arena: &crate::Arena<crate::Expression>,
492    ) -> Option<bool> {
493        match self.eval_expr_to_literal_from(handle, arena) {
494            Some(crate::Literal::Bool(value)) => Some(value),
495            _ => None,
496        }
497    }
498
499    #[expect(dead_code)]
500    pub(crate) fn eval_expr_to_literal(
501        &self,
502        handle: crate::Handle<crate::Expression>,
503    ) -> Option<crate::Literal> {
504        self.eval_expr_to_literal_from(handle, self.global_expressions)
505    }
506
507    pub(super) fn eval_expr_to_literal_from(
508        &self,
509        handle: crate::Handle<crate::Expression>,
510        arena: &crate::Arena<crate::Expression>,
511    ) -> Option<crate::Literal> {
512        fn get(
513            gctx: GlobalCtx,
514            handle: crate::Handle<crate::Expression>,
515            arena: &crate::Arena<crate::Expression>,
516        ) -> Option<crate::Literal> {
517            match arena[handle] {
518                crate::Expression::Literal(literal) => Some(literal),
519                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
520                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
521                    _ => None,
522                },
523                _ => None,
524            }
525        }
526        match arena[handle] {
527            crate::Expression::Constant(c) => {
528                get(*self, self.constants[c].init, self.global_expressions)
529            }
530            _ => get(*self, handle, arena),
531        }
532    }
533
534    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
535        compare_types(lhs, rhs, self.types)
536    }
537}
538
/// Errors returned by [`ArraySize::resolve`](crate::ArraySize::resolve).
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The override-based length evaluated to zero, or to an `I32` value that
    /// could not be converted to `u32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override has no initializer, or its initializer did not evaluate
    /// to a `U32` or `I32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
546
547impl crate::ArraySize {
548    /// Return the number of elements that `size` represents, if known at code generation time.
549    ///
550    /// If `size` is override-based, return an error unless the override's
551    /// initializer is a fully evaluated constant expression. You can call
552    /// [`pipeline_constants::process_overrides`] to supply values for a
553    /// module's overrides and ensure their initializers are fully evaluated, as
554    /// this function expects.
555    ///
556    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
557    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
558        match *self {
559            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
560            crate::ArraySize::Pending(handle) => {
561                let Some(expr) = gctx.overrides[handle].init else {
562                    return Err(ResolveArraySizeError::NonConstArrayLength);
563                };
564                let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
565                    U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
566                    U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
567                })?;
568
569                if length == 0 {
570                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
571                }
572
573                Ok(IndexableLength::Known(length))
574            }
575            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
576        }
577    }
578}
579
580/// Return an iterator over the individual components assembled by a
581/// `Compose` expression.
582///
583/// Given `ty` and `components` from an `Expression::Compose`, return an
584/// iterator over the components of the resulting value.
585///
586/// Normally, this would just be an iterator over `components`. However,
587/// `Compose` expressions can concatenate vectors, in which case the i'th
588/// value being composed is not generally the i'th element of `components`.
589/// This function consults `ty` to decide if this concatenation is occurring,
590/// and returns an iterator that produces the components of the result of
591/// the `Compose` expression in either case.
592pub fn flatten_compose<'arenas>(
593    ty: crate::Handle<crate::Type>,
594    components: &'arenas [crate::Handle<crate::Expression>],
595    expressions: &'arenas crate::Arena<crate::Expression>,
596    types: &'arenas crate::UniqueArena<crate::Type>,
597) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
598    // Returning `impl Iterator` is a bit tricky. We may or may not
599    // want to flatten the components, but we have to settle on a
600    // single concrete type to return. This function returns a single
601    // iterator chain that handles both the flattening and
602    // non-flattening cases.
603    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
604        (size as usize, true)
605    } else {
606        (components.len(), false)
607    };
608
609    /// Flatten `Compose` expressions if `is_vector` is true.
610    fn flatten_compose<'c>(
611        component: &'c crate::Handle<crate::Expression>,
612        is_vector: bool,
613        expressions: &'c crate::Arena<crate::Expression>,
614    ) -> &'c [crate::Handle<crate::Expression>] {
615        if is_vector {
616            if let crate::Expression::Compose {
617                ty: _,
618                components: ref subcomponents,
619            } = expressions[*component]
620            {
621                return subcomponents;
622            }
623        }
624        core::slice::from_ref(component)
625    }
626
627    /// Flatten `Splat` expressions if `is_vector` is true.
628    fn flatten_splat<'c>(
629        component: &'c crate::Handle<crate::Expression>,
630        is_vector: bool,
631        expressions: &'c crate::Arena<crate::Expression>,
632    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
633        let mut expr = *component;
634        let mut count = 1;
635        if is_vector {
636            if let crate::Expression::Splat { size, value } = expressions[expr] {
637                expr = value;
638                count = size as usize;
639            }
640        }
641        core::iter::repeat_n(expr, count)
642    }
643
644    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
645    // flatten up to two levels of `Compose` expressions.
646    //
647    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
648    // `Splat` expressions. Fortunately, the operand of a `Splat` must
649    // be a scalar, so we can stop there.
650    components
651        .iter()
652        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
653        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
654        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
655        .take(size)
656}
657
658impl super::ShaderStage {
659    pub const fn compute_like(self) -> bool {
660        match self {
661            Self::Vertex | Self::Fragment => false,
662            Self::Compute | Self::Task | Self::Mesh => true,
663        }
664    }
665}
666
/// A 3x3 `f32` matrix must report a size of 48 bytes.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let size = mat3x3_f32.size(module.to_ctx());
    assert_eq!(size, 48);
}
680
681impl crate::Module {
682    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
683    /// and by validators. This only validates the output variable itself, and not the
684    /// vertex and primitive output types.
685    ///
686    /// The output contains the extracted mesh stage info, with overrides unset,
687    /// and then the overrides separately. This is because the overrides should be
688    /// treated as expressions elsewhere, but that requires mutably modifying the
689    /// module and the expressions should only be created at parse time, not validation
690    /// time.
691    #[allow(clippy::type_complexity)]
692    pub fn analyze_mesh_shader_info(
693        &self,
694        gv: crate::Handle<crate::GlobalVariable>,
695    ) -> (
696        crate::MeshStageInfo,
697        [Option<crate::Handle<crate::Override>>; 2],
698        Option<crate::WithSpan<crate::valid::EntryPointError>>,
699    ) {
700        use crate::span::AddSpan;
701        use crate::valid::EntryPointError;
702        #[derive(Default)]
703        struct OutError {
704            pub inner: Option<EntryPointError>,
705        }
706        impl OutError {
707            pub fn set(&mut self, err: EntryPointError) {
708                if self.inner.is_none() {
709                    self.inner = Some(err);
710                }
711            }
712        }
713
714        // Used to temporarily initialize stuff
715        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
716        let mut output = crate::MeshStageInfo {
717            topology: crate::MeshOutputTopology::Triangles,
718            max_vertices: 0,
719            max_vertices_override: None,
720            max_primitives: 0,
721            max_primitives_override: None,
722            vertex_output_type: null_type,
723            primitive_output_type: null_type,
724            output_variable: gv,
725        };
726        // Stores the error to output, if any.
727        let mut error = OutError::default();
728        let r#type = &self.types[self.global_variables[gv].ty].inner;
729
730        let mut topology = output.topology;
731        // Max, max override, type
732        let mut vertex_info = (0, None, null_type);
733        let mut primitive_info = (0, None, null_type);
734
735        match r#type {
736            &crate::TypeInner::Struct { ref members, .. } => {
737                let mut builtins = crate::FastHashSet::default();
738                for member in members {
739                    match member.binding {
740                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
741                            // Must have type u32
742                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
743                                error.set(EntryPointError::BadMeshOutputVariableField);
744                            }
745                            // Each builtin should only occur once
746                            if builtins.contains(&crate::BuiltIn::VertexCount) {
747                                error.set(EntryPointError::BadMeshOutputVariableType);
748                            }
749                            builtins.insert(crate::BuiltIn::VertexCount);
750                        }
751                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
752                            // Must have type u32
753                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
754                                error.set(EntryPointError::BadMeshOutputVariableField);
755                            }
756                            // Each builtin should only occur once
757                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
758                                error.set(EntryPointError::BadMeshOutputVariableType);
759                            }
760                            builtins.insert(crate::BuiltIn::PrimitiveCount);
761                        }
762                        Some(crate::Binding::BuiltIn(
763                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
764                        )) => {
765                            let ty = &self.types[member.ty].inner;
766                            // Analyze the array type to determine size and vertex/primitive type
767                            let (a, b, c) = match ty {
768                                &crate::TypeInner::Array { base, size, .. } => {
769                                    let ty = base;
770                                    let (max, max_override) = match size {
771                                        crate::ArraySize::Constant(a) => (a.get(), None),
772                                        crate::ArraySize::Pending(o) => (0, Some(o)),
773                                        crate::ArraySize::Dynamic => {
774                                            error.set(EntryPointError::BadMeshOutputVariableField);
775                                            (0, None)
776                                        }
777                                    };
778                                    (max, max_override, ty)
779                                }
780                                _ => {
781                                    error.set(EntryPointError::BadMeshOutputVariableField);
782                                    (0, None, null_type)
783                                }
784                            };
785                            if matches!(
786                                member.binding,
787                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
788                            ) {
789                                // Primitives require special analysis to determine topology
790                                primitive_info = (a, b, c);
791                                match self.types[c].inner {
792                                    crate::TypeInner::Struct { ref members, .. } => {
793                                        for member in members {
794                                            match member.binding {
795                                                Some(crate::Binding::BuiltIn(
796                                                    crate::BuiltIn::PointIndex,
797                                                )) => {
798                                                    topology = crate::MeshOutputTopology::Points;
799                                                }
800                                                Some(crate::Binding::BuiltIn(
801                                                    crate::BuiltIn::LineIndices,
802                                                )) => {
803                                                    topology = crate::MeshOutputTopology::Lines;
804                                                }
805                                                Some(crate::Binding::BuiltIn(
806                                                    crate::BuiltIn::TriangleIndices,
807                                                )) => {
808                                                    topology = crate::MeshOutputTopology::Triangles;
809                                                }
810                                                _ => (),
811                                            }
812                                        }
813                                    }
814                                    _ => (),
815                                }
816                                // Each builtin should only occur once
817                                if builtins.contains(&crate::BuiltIn::Primitives) {
818                                    error.set(EntryPointError::BadMeshOutputVariableType);
819                                }
820                                builtins.insert(crate::BuiltIn::Primitives);
821                            } else {
822                                vertex_info = (a, b, c);
823                                // Each builtin should only occur once
824                                if builtins.contains(&crate::BuiltIn::Vertices) {
825                                    error.set(EntryPointError::BadMeshOutputVariableType);
826                                }
827                                builtins.insert(crate::BuiltIn::Vertices);
828                            }
829                        }
830                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
831                    }
832                }
833                output = crate::MeshStageInfo {
834                    topology,
835                    max_vertices: vertex_info.0,
836                    max_vertices_override: None,
837                    vertex_output_type: vertex_info.2,
838                    max_primitives: primitive_info.0,
839                    max_primitives_override: None,
840                    primitive_output_type: primitive_info.2,
841                    ..output
842                }
843            }
844            _ => error.set(EntryPointError::BadMeshOutputVariableType),
845        }
846        (
847            output,
848            [vertex_info.1, primitive_info.1],
849            error
850                .inner
851                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
852        )
853    }
854}