naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
89pub enum HashableLiteral {
90    F64(u64),
91    F32(u32),
92    F16(u16),
93    U32(u32),
94    I32(i32),
95    U64(u64),
96    I64(i64),
97    Bool(bool),
98    AbstractInt(i64),
99    AbstractFloat(u64),
100}
101
102impl From<crate::Literal> for HashableLiteral {
103    fn from(l: crate::Literal) -> Self {
104        match l {
105            crate::Literal::F64(v) => Self::F64(v.to_bits()),
106            crate::Literal::F32(v) => Self::F32(v.to_bits()),
107            crate::Literal::F16(v) => Self::F16(v.to_bits()),
108            crate::Literal::U32(v) => Self::U32(v),
109            crate::Literal::I32(v) => Self::I32(v),
110            crate::Literal::U64(v) => Self::U64(v),
111            crate::Literal::I64(v) => Self::I64(v),
112            crate::Literal::Bool(v) => Self::Bool(v),
113            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115        }
116    }
117}
118
119impl crate::Literal {
120    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121        match (value, scalar.kind, scalar.width) {
122            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124            (value, crate::ScalarKind::Float, 2) => {
125                Some(Self::F16(half::f16::from_f32_const(value as _)))
126            }
127            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
128            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
129            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
130            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
131            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
132            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
133            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
134            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
135            _ => None,
136        }
137    }
138
139    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
140        Self::new(0, scalar)
141    }
142
143    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
144        Self::new(1, scalar)
145    }
146
147    pub const fn width(&self) -> crate::Bytes {
148        match *self {
149            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
150            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
151            Self::F16(_) => 2,
152            Self::Bool(_) => crate::BOOL_WIDTH,
153            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
154        }
155    }
156    pub const fn scalar(&self) -> crate::Scalar {
157        match *self {
158            Self::F64(_) => crate::Scalar::F64,
159            Self::F32(_) => crate::Scalar::F32,
160            Self::F16(_) => crate::Scalar::F16,
161            Self::U32(_) => crate::Scalar::U32,
162            Self::I32(_) => crate::Scalar::I32,
163            Self::U64(_) => crate::Scalar::U64,
164            Self::I64(_) => crate::Scalar::I64,
165            Self::Bool(_) => crate::Scalar::BOOL,
166            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
167            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
168        }
169    }
170    pub const fn scalar_kind(&self) -> crate::ScalarKind {
171        self.scalar().kind
172    }
173    pub const fn ty_inner(&self) -> crate::TypeInner {
174        crate::TypeInner::Scalar(self.scalar())
175    }
176}
177
178impl TryFrom<crate::Literal> for u32 {
179    type Error = ConstValueError;
180
181    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
182        match value {
183            crate::Literal::U32(value) => Ok(value),
184            crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
185            _ => Err(ConstValueError::InvalidType),
186        }
187    }
188}
189
190impl TryFrom<crate::Literal> for bool {
191    type Error = ConstValueError;
192
193    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
194        match value {
195            crate::Literal::Bool(value) => Ok(value),
196            _ => Err(ConstValueError::InvalidType),
197        }
198    }
199}
200
201impl super::AddressSpace {
202    pub fn access(self) -> crate::StorageAccess {
203        use crate::StorageAccess as Sa;
204        match self {
205            crate::AddressSpace::Function
206            | crate::AddressSpace::Private
207            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
208            crate::AddressSpace::Uniform => Sa::LOAD,
209            crate::AddressSpace::Storage { access } => access,
210            crate::AddressSpace::Handle => Sa::LOAD,
211            crate::AddressSpace::Immediate => Sa::LOAD,
212            // TaskPayload isn't always writable, but this is checked for elsewhere,
213            // when not using multiple payloads and matching the entry payload is checked.
214            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
215            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
216                Sa::LOAD | Sa::STORE
217            }
218        }
219    }
220}
221
222impl super::MathFunction {
223    pub const fn argument_count(&self) -> usize {
224        match *self {
225            // comparison
226            Self::Abs => 1,
227            Self::Min => 2,
228            Self::Max => 2,
229            Self::Clamp => 3,
230            Self::Saturate => 1,
231            // trigonometry
232            Self::Cos => 1,
233            Self::Cosh => 1,
234            Self::Sin => 1,
235            Self::Sinh => 1,
236            Self::Tan => 1,
237            Self::Tanh => 1,
238            Self::Acos => 1,
239            Self::Asin => 1,
240            Self::Atan => 1,
241            Self::Atan2 => 2,
242            Self::Asinh => 1,
243            Self::Acosh => 1,
244            Self::Atanh => 1,
245            Self::Radians => 1,
246            Self::Degrees => 1,
247            // decomposition
248            Self::Ceil => 1,
249            Self::Floor => 1,
250            Self::Round => 1,
251            Self::Fract => 1,
252            Self::Trunc => 1,
253            Self::Modf => 1,
254            Self::Frexp => 1,
255            Self::Ldexp => 2,
256            // exponent
257            Self::Exp => 1,
258            Self::Exp2 => 1,
259            Self::Log => 1,
260            Self::Log2 => 1,
261            Self::Pow => 2,
262            // geometry
263            Self::Dot => 2,
264            Self::Dot4I8Packed => 2,
265            Self::Dot4U8Packed => 2,
266            Self::Outer => 2,
267            Self::Cross => 2,
268            Self::Distance => 2,
269            Self::Length => 1,
270            Self::Normalize => 1,
271            Self::FaceForward => 3,
272            Self::Reflect => 2,
273            Self::Refract => 3,
274            // computational
275            Self::Sign => 1,
276            Self::Fma => 3,
277            Self::Mix => 3,
278            Self::Step => 2,
279            Self::SmoothStep => 3,
280            Self::Sqrt => 1,
281            Self::InverseSqrt => 1,
282            Self::Inverse => 1,
283            Self::Transpose => 1,
284            Self::Determinant => 1,
285            Self::QuantizeToF16 => 1,
286            // bits
287            Self::CountTrailingZeros => 1,
288            Self::CountLeadingZeros => 1,
289            Self::CountOneBits => 1,
290            Self::ReverseBits => 1,
291            Self::ExtractBits => 3,
292            Self::InsertBits => 4,
293            Self::FirstTrailingBit => 1,
294            Self::FirstLeadingBit => 1,
295            // data packing
296            Self::Pack4x8snorm => 1,
297            Self::Pack4x8unorm => 1,
298            Self::Pack2x16snorm => 1,
299            Self::Pack2x16unorm => 1,
300            Self::Pack2x16float => 1,
301            Self::Pack4xI8 => 1,
302            Self::Pack4xU8 => 1,
303            Self::Pack4xI8Clamp => 1,
304            Self::Pack4xU8Clamp => 1,
305            // data unpacking
306            Self::Unpack4x8snorm => 1,
307            Self::Unpack4x8unorm => 1,
308            Self::Unpack2x16snorm => 1,
309            Self::Unpack2x16unorm => 1,
310            Self::Unpack2x16float => 1,
311            Self::Unpack4xI8 => 1,
312            Self::Unpack4xU8 => 1,
313        }
314    }
315}
316
317impl crate::Expression {
318    /// Returns true if the expression is considered emitted at the start of a function.
319    pub const fn needs_pre_emit(&self) -> bool {
320        match *self {
321            Self::Literal(_)
322            | Self::Constant(_)
323            | Self::Override(_)
324            | Self::ZeroValue(_)
325            | Self::FunctionArgument(_)
326            | Self::GlobalVariable(_)
327            | Self::LocalVariable(_) => true,
328            _ => false,
329        }
330    }
331
332    /// Return true if this expression is a dynamic array/vector/matrix index,
333    /// for [`Access`].
334    ///
335    /// This method returns true if this expression is a dynamically computed
336    /// index, and as such can only be used to index matrices when they appear
337    /// behind a pointer. See the documentation for [`Access`] for details.
338    ///
339    /// Note, this does not check the _type_ of the given expression. It's up to
340    /// the caller to establish that the `Access` expression is well-typed
341    /// through other means, like [`ResolveContext`].
342    ///
343    /// [`Access`]: crate::Expression::Access
344    /// [`ResolveContext`]: crate::proc::ResolveContext
345    pub const fn is_dynamic_index(&self) -> bool {
346        match *self {
347            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
348            _ => true,
349        }
350    }
351}
352
353impl crate::Function {
354    /// Return the global variable being accessed by the expression `pointer`.
355    ///
356    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
357    /// expressions that ultimately access some part of a `GlobalVariable`,
358    /// return a handle for that global.
359    ///
360    /// If the expression does not ultimately access a global variable, return
361    /// `None`.
362    pub fn originating_global(
363        &self,
364        mut pointer: crate::Handle<crate::Expression>,
365    ) -> Option<crate::Handle<crate::GlobalVariable>> {
366        loop {
367            pointer = match self.expressions[pointer] {
368                crate::Expression::Access { base, .. } => base,
369                crate::Expression::AccessIndex { base, .. } => base,
370                crate::Expression::GlobalVariable(handle) => return Some(handle),
371                crate::Expression::LocalVariable(_) => return None,
372                crate::Expression::FunctionArgument(_) => return None,
373                // There are no other expressions that produce pointer values.
374                _ => unreachable!(),
375            }
376        }
377    }
378}
379
380impl crate::SampleLevel {
381    pub const fn implicit_derivatives(&self) -> bool {
382        match *self {
383            Self::Auto | Self::Bias(_) => true,
384            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
385        }
386    }
387}
388
389impl crate::Binding {
390    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
391        match *self {
392            crate::Binding::BuiltIn(built_in) => Some(built_in),
393            Self::Location { .. } => None,
394        }
395    }
396}
397
398impl super::SwizzleComponent {
399    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
400
401    pub const fn index(&self) -> u32 {
402        match *self {
403            Self::X => 0,
404            Self::Y => 1,
405            Self::Z => 2,
406            Self::W => 3,
407        }
408    }
409    pub const fn from_index(idx: u32) -> Self {
410        match idx {
411            0 => Self::X,
412            1 => Self::Y,
413            2 => Self::Z,
414            _ => Self::W,
415        }
416    }
417}
418
419impl super::ImageClass {
420    pub const fn is_multisampled(self) -> bool {
421        match self {
422            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
423            crate::ImageClass::Storage { .. } => false,
424            crate::ImageClass::External => false,
425        }
426    }
427
428    pub const fn is_mipmapped(self) -> bool {
429        match self {
430            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
431            crate::ImageClass::Storage { .. } => false,
432            crate::ImageClass::External => false,
433        }
434    }
435
436    pub const fn is_depth(self) -> bool {
437        matches!(self, crate::ImageClass::Depth { .. })
438    }
439}
440
441impl crate::Module {
442    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
443        GlobalCtx {
444            types: &self.types,
445            constants: &self.constants,
446            overrides: &self.overrides,
447            global_expressions: &self.global_expressions,
448        }
449    }
450
451    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
452        compare_types(lhs, rhs, &self.types)
453    }
454}
455
/// Reasons a constant expression could not be read as a particular value.
#[derive(Debug)]
pub enum ConstValueError {
    /// The expression is not a literal or a scalar zero value, nor a
    /// constant whose initializer is one of those.
    NonConst,
    /// A negative signed integer was found where an unsigned value was wanted.
    Negative,
    /// The literal's type cannot be converted to the requested type.
    InvalidType,
}

impl From<core::convert::Infallible> for ConstValueError {
    fn from(never: core::convert::Infallible) -> Self {
        // `Infallible` has no values, so this body can never execute.
        match never {}
    }
}
468
469#[derive(Clone, Copy)]
470pub struct GlobalCtx<'a> {
471    pub types: &'a crate::UniqueArena<crate::Type>,
472    pub constants: &'a crate::Arena<crate::Constant>,
473    pub overrides: &'a crate::Arena<crate::Override>,
474    pub global_expressions: &'a crate::Arena<crate::Expression>,
475}
476
477impl GlobalCtx<'_> {
478    /// Try to evaluate the expression in `self.global_expressions` using its `handle`
479    /// and return it as a `T: TryFrom<ir::Literal>`.
480    ///
481    /// This currently only evaluates scalar expressions. If adding support for vectors,
482    /// consider changing `valid::expression::validate_constant_shift_amounts` to use that
483    /// support.
484    #[cfg_attr(
485        not(any(
486            feature = "glsl-in",
487            feature = "spv-in",
488            feature = "wgsl-in",
489            glsl_out,
490            hlsl_out,
491            msl_out,
492            wgsl_out
493        )),
494        allow(dead_code)
495    )]
496    pub(super) fn get_const_val<T, E>(
497        &self,
498        handle: crate::Handle<crate::Expression>,
499    ) -> Result<T, ConstValueError>
500    where
501        T: TryFrom<crate::Literal, Error = E>,
502        E: Into<ConstValueError>,
503    {
504        self.get_const_val_from(handle, self.global_expressions)
505    }
506
507    pub(super) fn get_const_val_from<T, E>(
508        &self,
509        handle: crate::Handle<crate::Expression>,
510        arena: &crate::Arena<crate::Expression>,
511    ) -> Result<T, ConstValueError>
512    where
513        T: TryFrom<crate::Literal, Error = E>,
514        E: Into<ConstValueError>,
515    {
516        fn get(
517            gctx: GlobalCtx,
518            handle: crate::Handle<crate::Expression>,
519            arena: &crate::Arena<crate::Expression>,
520        ) -> Option<crate::Literal> {
521            match arena[handle] {
522                crate::Expression::Literal(literal) => Some(literal),
523                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
524                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
525                    _ => None,
526                },
527                _ => None,
528            }
529        }
530        let value = match arena[handle] {
531            crate::Expression::Constant(c) => {
532                get(*self, self.constants[c].init, self.global_expressions)
533            }
534            _ => get(*self, handle, arena),
535        };
536        match value {
537            Some(v) => v.try_into().map_err(Into::into),
538            None => Err(ConstValueError::NonConst),
539        }
540    }
541
542    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
543        compare_types(lhs, rhs, self.types)
544    }
545}
546
547#[derive(Error, Debug, Clone, Copy, PartialEq)]
548pub enum ResolveArraySizeError {
549    #[error("array element count must be positive (> 0)")]
550    ExpectedPositiveArrayLength,
551    #[error("internal: array size override has not been resolved")]
552    NonConstArrayLength,
553}
554
555impl crate::ArraySize {
556    /// Return the number of elements that `size` represents, if known at code generation time.
557    ///
558    /// If `size` is override-based, return an error unless the override's
559    /// initializer is a fully evaluated constant expression. You can call
560    /// [`pipeline_constants::process_overrides`] to supply values for a
561    /// module's overrides and ensure their initializers are fully evaluated, as
562    /// this function expects.
563    ///
564    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
565    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
566        match *self {
567            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
568            crate::ArraySize::Pending(handle) => {
569                let Some(expr) = gctx.overrides[handle].init else {
570                    return Err(ResolveArraySizeError::NonConstArrayLength);
571                };
572                let length = gctx.get_const_val(expr).map_err(|err| match err {
573                    ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
574                    ConstValueError::Negative | ConstValueError::InvalidType => {
575                        ResolveArraySizeError::ExpectedPositiveArrayLength
576                    }
577                })?;
578
579                if length == 0 {
580                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
581                }
582
583                Ok(IndexableLength::Known(length))
584            }
585            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
586        }
587    }
588}
589
590/// Return an iterator over the individual components assembled by a
591/// `Compose` expression.
592///
593/// Given `ty` and `components` from an `Expression::Compose`, return an
594/// iterator over the components of the resulting value.
595///
596/// Normally, this would just be an iterator over `components`. However,
597/// `Compose` expressions can concatenate vectors, in which case the i'th
598/// value being composed is not generally the i'th element of `components`.
599/// This function consults `ty` to decide if this concatenation is occurring,
600/// and returns an iterator that produces the components of the result of
601/// the `Compose` expression in either case.
602pub fn flatten_compose<'arenas>(
603    ty: crate::Handle<crate::Type>,
604    components: &'arenas [crate::Handle<crate::Expression>],
605    expressions: &'arenas crate::Arena<crate::Expression>,
606    types: &'arenas crate::UniqueArena<crate::Type>,
607) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
608    // Returning `impl Iterator` is a bit tricky. We may or may not
609    // want to flatten the components, but we have to settle on a
610    // single concrete type to return. This function returns a single
611    // iterator chain that handles both the flattening and
612    // non-flattening cases.
613    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
614        (size as usize, true)
615    } else {
616        (components.len(), false)
617    };
618
619    /// Flatten `Compose` expressions if `is_vector` is true.
620    fn flatten_compose<'c>(
621        component: &'c crate::Handle<crate::Expression>,
622        is_vector: bool,
623        expressions: &'c crate::Arena<crate::Expression>,
624    ) -> &'c [crate::Handle<crate::Expression>] {
625        if is_vector {
626            if let crate::Expression::Compose {
627                ty: _,
628                components: ref subcomponents,
629            } = expressions[*component]
630            {
631                return subcomponents;
632            }
633        }
634        core::slice::from_ref(component)
635    }
636
637    /// Flatten `Splat` expressions if `is_vector` is true.
638    fn flatten_splat<'c>(
639        component: &'c crate::Handle<crate::Expression>,
640        is_vector: bool,
641        expressions: &'c crate::Arena<crate::Expression>,
642    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
643        let mut expr = *component;
644        let mut count = 1;
645        if is_vector {
646            if let crate::Expression::Splat { size, value } = expressions[expr] {
647                expr = value;
648                count = size as usize;
649            }
650        }
651        core::iter::repeat_n(expr, count)
652    }
653
654    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
655    // flatten up to two levels of `Compose` expressions.
656    //
657    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
658    // `Splat` expressions. Fortunately, the operand of a `Splat` must
659    // be a scalar, so we can stop there.
660    components
661        .iter()
662        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
663        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
664        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
665        .take(size)
666}
667
668impl super::ShaderStage {
669    pub const fn compute_like(self) -> bool {
670        match self {
671            Self::Vertex | Self::Fragment => false,
672            Self::Compute | Self::Task | Self::Mesh => true,
673            Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => false,
674        }
675    }
676
677    /// Mesh or task shader
678    pub const fn mesh_like(self) -> bool {
679        match self {
680            Self::Task | Self::Mesh => true,
681            _ => false,
682        }
683    }
684}
685
#[test]
fn test_matrix_size() {
    // A 3x3 `f32` matrix is expected to occupy 48 bytes: three columns of
    // 16 bytes each, so each 12-byte column is padded to 16.
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let expected_bytes = 48;
    assert_eq!(mat3x3.size(module.to_ctx()), expected_bytes);
}
699
700impl crate::Module {
701    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
702    /// and by validators. This only validates the output variable itself, and not the
703    /// vertex and primitive output types.
704    ///
705    /// The output contains the extracted mesh stage info, with overrides unset,
706    /// and then the overrides separately. This is because the overrides should be
707    /// treated as expressions elsewhere, but that requires mutably modifying the
708    /// module and the expressions should only be created at parse time, not validation
709    /// time.
710    #[allow(clippy::type_complexity)]
711    pub fn analyze_mesh_shader_info(
712        &self,
713        gv: crate::Handle<crate::GlobalVariable>,
714    ) -> (
715        crate::MeshStageInfo,
716        [Option<crate::Handle<crate::Override>>; 2],
717        Option<crate::WithSpan<crate::valid::EntryPointError>>,
718    ) {
719        use crate::span::AddSpan;
720        use crate::valid::EntryPointError;
721        #[derive(Default)]
722        struct OutError {
723            pub inner: Option<EntryPointError>,
724        }
725        impl OutError {
726            pub fn set(&mut self, err: EntryPointError) {
727                if self.inner.is_none() {
728                    self.inner = Some(err);
729                }
730            }
731        }
732
733        // Used to temporarily initialize stuff
734        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
735        let mut output = crate::MeshStageInfo {
736            topology: crate::MeshOutputTopology::Triangles,
737            max_vertices: 0,
738            max_vertices_override: None,
739            max_primitives: 0,
740            max_primitives_override: None,
741            vertex_output_type: null_type,
742            primitive_output_type: null_type,
743            output_variable: gv,
744        };
745        // Stores the error to output, if any.
746        let mut error = OutError::default();
747        let r#type = &self.types[self.global_variables[gv].ty].inner;
748
749        let mut topology = output.topology;
750        // Max, max override, type
751        let mut vertex_info = (0, None, null_type);
752        let mut primitive_info = (0, None, null_type);
753
754        match r#type {
755            &crate::TypeInner::Struct { ref members, .. } => {
756                let mut builtins = crate::FastHashSet::default();
757                for member in members {
758                    match member.binding {
759                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
760                            // Must have type u32
761                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
762                                error.set(EntryPointError::BadMeshOutputVariableField);
763                            }
764                            // Each builtin should only occur once
765                            if builtins.contains(&crate::BuiltIn::VertexCount) {
766                                error.set(EntryPointError::BadMeshOutputVariableType);
767                            }
768                            builtins.insert(crate::BuiltIn::VertexCount);
769                        }
770                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
771                            // Must have type u32
772                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
773                                error.set(EntryPointError::BadMeshOutputVariableField);
774                            }
775                            // Each builtin should only occur once
776                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
777                                error.set(EntryPointError::BadMeshOutputVariableType);
778                            }
779                            builtins.insert(crate::BuiltIn::PrimitiveCount);
780                        }
781                        Some(crate::Binding::BuiltIn(
782                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
783                        )) => {
784                            let ty = &self.types[member.ty].inner;
785                            // Analyze the array type to determine size and vertex/primitive type
786                            let (a, b, c) = match ty {
787                                &crate::TypeInner::Array { base, size, .. } => {
788                                    let ty = base;
789                                    let (max, max_override) = match size {
790                                        crate::ArraySize::Constant(a) => (a.get(), None),
791                                        crate::ArraySize::Pending(o) => (0, Some(o)),
792                                        crate::ArraySize::Dynamic => {
793                                            error.set(EntryPointError::BadMeshOutputVariableField);
794                                            (0, None)
795                                        }
796                                    };
797                                    (max, max_override, ty)
798                                }
799                                _ => {
800                                    error.set(EntryPointError::BadMeshOutputVariableField);
801                                    (0, None, null_type)
802                                }
803                            };
804                            if matches!(
805                                member.binding,
806                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
807                            ) {
808                                // Primitives require special analysis to determine topology
809                                primitive_info = (a, b, c);
810                                match self.types[c].inner {
811                                    crate::TypeInner::Struct { ref members, .. } => {
812                                        for member in members {
813                                            match member.binding {
814                                                Some(crate::Binding::BuiltIn(
815                                                    crate::BuiltIn::PointIndex,
816                                                )) => {
817                                                    topology = crate::MeshOutputTopology::Points;
818                                                }
819                                                Some(crate::Binding::BuiltIn(
820                                                    crate::BuiltIn::LineIndices,
821                                                )) => {
822                                                    topology = crate::MeshOutputTopology::Lines;
823                                                }
824                                                Some(crate::Binding::BuiltIn(
825                                                    crate::BuiltIn::TriangleIndices,
826                                                )) => {
827                                                    topology = crate::MeshOutputTopology::Triangles;
828                                                }
829                                                _ => (),
830                                            }
831                                        }
832                                    }
833                                    _ => (),
834                                }
835                                // Each builtin should only occur once
836                                if builtins.contains(&crate::BuiltIn::Primitives) {
837                                    error.set(EntryPointError::BadMeshOutputVariableType);
838                                }
839                                builtins.insert(crate::BuiltIn::Primitives);
840                            } else {
841                                vertex_info = (a, b, c);
842                                // Each builtin should only occur once
843                                if builtins.contains(&crate::BuiltIn::Vertices) {
844                                    error.set(EntryPointError::BadMeshOutputVariableType);
845                                }
846                                builtins.insert(crate::BuiltIn::Vertices);
847                            }
848                        }
849                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
850                    }
851                }
852                output = crate::MeshStageInfo {
853                    topology,
854                    max_vertices: vertex_info.0,
855                    max_vertices_override: None,
856                    vertex_output_type: vertex_info.2,
857                    max_primitives: primitive_info.0,
858                    max_primitives_override: None,
859                    primitive_output_type: primitive_info.2,
860                    ..output
861                }
862            }
863            _ => error.set(EntryPointError::BadMeshOutputVariableType),
864        }
865        (
866            output,
867            [vertex_info.1, primitive_info.1],
868            error
869                .inner
870                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
871        )
872    }
873
874    pub fn uses_mesh_shaders(&self) -> bool {
875        let binding_uses_mesh = |b: &crate::Binding| {
876            matches!(
877                b,
878                crate::Binding::BuiltIn(
879                    crate::BuiltIn::MeshTaskSize
880                        | crate::BuiltIn::CullPrimitive
881                        | crate::BuiltIn::PointIndex
882                        | crate::BuiltIn::LineIndices
883                        | crate::BuiltIn::TriangleIndices
884                        | crate::BuiltIn::VertexCount
885                        | crate::BuiltIn::Vertices
886                        | crate::BuiltIn::PrimitiveCount
887                        | crate::BuiltIn::Primitives,
888                ) | crate::Binding::Location {
889                    per_primitive: true,
890                    ..
891                }
892            )
893        };
894        for (_, ty) in self.types.iter() {
895            match ty.inner {
896                crate::TypeInner::Struct { ref members, .. } => {
897                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
898                        if binding_uses_mesh(binding) {
899                            return true;
900                        }
901                    }
902                }
903                _ => (),
904            }
905        }
906        for ep in &self.entry_points {
907            if matches!(
908                ep.stage,
909                crate::ShaderStage::Mesh | crate::ShaderStage::Task
910            ) {
911                return true;
912            }
913            for binding in ep
914                .function
915                .arguments
916                .iter()
917                .filter_map(|arg| arg.binding.as_ref())
918                .chain(
919                    ep.function
920                        .result
921                        .iter()
922                        .filter_map(|res| res.binding.as_ref()),
923                )
924            {
925                if binding_uses_mesh(binding) {
926                    return true;
927                }
928            }
929        }
930        if self
931            .global_variables
932            .iter()
933            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
934        {
935            return true;
936        }
937        false
938    }
939}
940
941impl crate::MeshOutputTopology {
942    pub const fn to_builtin(self) -> crate::BuiltIn {
943        match self {
944            Self::Points => crate::BuiltIn::PointIndex,
945            Self::Lines => crate::BuiltIn::LineIndices,
946            Self::Triangles => crate::BuiltIn::TriangleIndices,
947        }
948    }
949}
950
951impl crate::AddressSpace {
952    pub const fn is_workgroup_like(self) -> bool {
953        matches!(self, Self::WorkGroup | Self::TaskPayload)
954    }
955}