naga/proc/
mod.rs

/*!
[`Module`](super::Module) processing functionality.
*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
/// A [`Literal`](crate::Literal) in a form that implements `Eq` and `Hash`.
///
/// Floating-point payloads are stored as their raw bit patterns (see the
/// `From<crate::Literal>` impl below), so two float literals compare equal
/// here only when their bits are identical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of an `f64` literal.
    F64(u64),
    /// Bit pattern of an `f32` literal.
    F32(u32),
    /// Bit pattern of an `f16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an abstract-float literal's 64-bit payload.
    AbstractFloat(u64),
}
101
102impl From<crate::Literal> for HashableLiteral {
103    fn from(l: crate::Literal) -> Self {
104        match l {
105            crate::Literal::F64(v) => Self::F64(v.to_bits()),
106            crate::Literal::F32(v) => Self::F32(v.to_bits()),
107            crate::Literal::F16(v) => Self::F16(v.to_bits()),
108            crate::Literal::U32(v) => Self::U32(v),
109            crate::Literal::I32(v) => Self::I32(v),
110            crate::Literal::U64(v) => Self::U64(v),
111            crate::Literal::I64(v) => Self::I64(v),
112            crate::Literal::Bool(v) => Self::Bool(v),
113            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115        }
116    }
117}
118
119impl crate::Literal {
120    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121        match (value, scalar.kind, scalar.width) {
122            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124            (value, crate::ScalarKind::Float, 2) => {
125                Some(Self::F16(half::f16::from_f32_const(value as _)))
126            }
127            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
128            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
129            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
130            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
131            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
132            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
133            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
134            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
135            _ => None,
136        }
137    }
138
139    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
140        Self::new(0, scalar)
141    }
142
143    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
144        Self::new(1, scalar)
145    }
146
147    pub const fn minus_one(scalar: crate::Scalar) -> Option<Self> {
148        match (scalar.kind, scalar.width) {
149            (crate::ScalarKind::Float, 8) => Some(Self::F64(-1.0)),
150            (crate::ScalarKind::Float, 4) => Some(Self::F32(-1.0)),
151            (crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))),
152            (crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)),
153            (crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)),
154            (crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)),
155            _ => None,
156        }
157    }
158
159    pub const fn width(&self) -> crate::Bytes {
160        match *self {
161            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
162            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
163            Self::F16(_) => 2,
164            Self::Bool(_) => crate::BOOL_WIDTH,
165            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
166        }
167    }
168    pub const fn scalar(&self) -> crate::Scalar {
169        match *self {
170            Self::F64(_) => crate::Scalar::F64,
171            Self::F32(_) => crate::Scalar::F32,
172            Self::F16(_) => crate::Scalar::F16,
173            Self::U32(_) => crate::Scalar::U32,
174            Self::I32(_) => crate::Scalar::I32,
175            Self::U64(_) => crate::Scalar::U64,
176            Self::I64(_) => crate::Scalar::I64,
177            Self::Bool(_) => crate::Scalar::BOOL,
178            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
179            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
180        }
181    }
182    pub const fn scalar_kind(&self) -> crate::ScalarKind {
183        self.scalar().kind
184    }
185    pub const fn ty_inner(&self) -> crate::TypeInner {
186        crate::TypeInner::Scalar(self.scalar())
187    }
188}
189
190impl TryFrom<crate::Literal> for u32 {
191    type Error = ConstValueError;
192
193    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
194        match value {
195            crate::Literal::U32(value) => Ok(value),
196            crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
197            _ => Err(ConstValueError::InvalidType),
198        }
199    }
200}
201
202impl TryFrom<crate::Literal> for bool {
203    type Error = ConstValueError;
204
205    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
206        match value {
207            crate::Literal::Bool(value) => Ok(value),
208            _ => Err(ConstValueError::InvalidType),
209        }
210    }
211}
212
213impl super::AddressSpace {
214    pub fn access(self) -> crate::StorageAccess {
215        use crate::StorageAccess as Sa;
216        match self {
217            crate::AddressSpace::Function
218            | crate::AddressSpace::Private
219            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
220            crate::AddressSpace::Uniform => Sa::LOAD,
221            crate::AddressSpace::Storage { access } => access,
222            crate::AddressSpace::Handle => Sa::LOAD,
223            crate::AddressSpace::Immediate => Sa::LOAD,
224            // TaskPayload isn't always writable, but this is checked for elsewhere,
225            // when not using multiple payloads and matching the entry payload is checked.
226            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
227            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
228                Sa::LOAD | Sa::STORE
229            }
230        }
231    }
232}
233
234impl super::MathFunction {
235    pub const fn argument_count(&self) -> usize {
236        match *self {
237            // comparison
238            Self::Abs => 1,
239            Self::Min => 2,
240            Self::Max => 2,
241            Self::Clamp => 3,
242            Self::Saturate => 1,
243            // trigonometry
244            Self::Cos => 1,
245            Self::Cosh => 1,
246            Self::Sin => 1,
247            Self::Sinh => 1,
248            Self::Tan => 1,
249            Self::Tanh => 1,
250            Self::Acos => 1,
251            Self::Asin => 1,
252            Self::Atan => 1,
253            Self::Atan2 => 2,
254            Self::Asinh => 1,
255            Self::Acosh => 1,
256            Self::Atanh => 1,
257            Self::Radians => 1,
258            Self::Degrees => 1,
259            // decomposition
260            Self::Ceil => 1,
261            Self::Floor => 1,
262            Self::Round => 1,
263            Self::Fract => 1,
264            Self::Trunc => 1,
265            Self::Modf => 1,
266            Self::Frexp => 1,
267            Self::Ldexp => 2,
268            // exponent
269            Self::Exp => 1,
270            Self::Exp2 => 1,
271            Self::Log => 1,
272            Self::Log2 => 1,
273            Self::Pow => 2,
274            // geometry
275            Self::Dot => 2,
276            Self::Dot4I8Packed => 2,
277            Self::Dot4U8Packed => 2,
278            Self::Outer => 2,
279            Self::Cross => 2,
280            Self::Distance => 2,
281            Self::Length => 1,
282            Self::Normalize => 1,
283            Self::FaceForward => 3,
284            Self::Reflect => 2,
285            Self::Refract => 3,
286            // computational
287            Self::Sign => 1,
288            Self::Fma => 3,
289            Self::Mix => 3,
290            Self::Step => 2,
291            Self::SmoothStep => 3,
292            Self::Sqrt => 1,
293            Self::InverseSqrt => 1,
294            Self::Inverse => 1,
295            Self::Transpose => 1,
296            Self::Determinant => 1,
297            Self::QuantizeToF16 => 1,
298            // bits
299            Self::CountTrailingZeros => 1,
300            Self::CountLeadingZeros => 1,
301            Self::CountOneBits => 1,
302            Self::ReverseBits => 1,
303            Self::ExtractBits => 3,
304            Self::InsertBits => 4,
305            Self::FirstTrailingBit => 1,
306            Self::FirstLeadingBit => 1,
307            // data packing
308            Self::Pack4x8snorm => 1,
309            Self::Pack4x8unorm => 1,
310            Self::Pack2x16snorm => 1,
311            Self::Pack2x16unorm => 1,
312            Self::Pack2x16float => 1,
313            Self::Pack4xI8 => 1,
314            Self::Pack4xU8 => 1,
315            Self::Pack4xI8Clamp => 1,
316            Self::Pack4xU8Clamp => 1,
317            // data unpacking
318            Self::Unpack4x8snorm => 1,
319            Self::Unpack4x8unorm => 1,
320            Self::Unpack2x16snorm => 1,
321            Self::Unpack2x16unorm => 1,
322            Self::Unpack2x16float => 1,
323            Self::Unpack4xI8 => 1,
324            Self::Unpack4xU8 => 1,
325        }
326    }
327}
328
329impl crate::Expression {
330    /// Returns true if the expression is considered emitted at the start of a function.
331    pub const fn needs_pre_emit(&self) -> bool {
332        match *self {
333            Self::Literal(_)
334            | Self::Constant(_)
335            | Self::Override(_)
336            | Self::ZeroValue(_)
337            | Self::FunctionArgument(_)
338            | Self::GlobalVariable(_)
339            | Self::LocalVariable(_) => true,
340            _ => false,
341        }
342    }
343
344    /// Return true if this expression is a dynamic array/vector/matrix index,
345    /// for [`Access`].
346    ///
347    /// This method returns true if this expression is a dynamically computed
348    /// index, and as such can only be used to index matrices when they appear
349    /// behind a pointer. See the documentation for [`Access`] for details.
350    ///
351    /// Note, this does not check the _type_ of the given expression. It's up to
352    /// the caller to establish that the `Access` expression is well-typed
353    /// through other means, like [`ResolveContext`].
354    ///
355    /// [`Access`]: crate::Expression::Access
356    /// [`ResolveContext`]: crate::proc::ResolveContext
357    pub const fn is_dynamic_index(&self) -> bool {
358        match *self {
359            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
360            _ => true,
361        }
362    }
363}
364
365impl crate::Function {
366    /// Return the global variable being accessed by the expression `pointer`.
367    ///
368    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
369    /// expressions that ultimately access some part of a `GlobalVariable`,
370    /// return a handle for that global.
371    ///
372    /// If the expression does not ultimately access a global variable, return
373    /// `None`.
374    pub fn originating_global(
375        &self,
376        mut pointer: crate::Handle<crate::Expression>,
377    ) -> Option<crate::Handle<crate::GlobalVariable>> {
378        loop {
379            pointer = match self.expressions[pointer] {
380                crate::Expression::Access { base, .. } => base,
381                crate::Expression::AccessIndex { base, .. } => base,
382                crate::Expression::GlobalVariable(handle) => return Some(handle),
383                crate::Expression::LocalVariable(_) => return None,
384                crate::Expression::FunctionArgument(_) => return None,
385                // There are no other expressions that produce pointer values.
386                _ => unreachable!(),
387            }
388        }
389    }
390}
391
392impl crate::SampleLevel {
393    pub const fn implicit_derivatives(&self) -> bool {
394        match *self {
395            Self::Auto | Self::Bias(_) => true,
396            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
397        }
398    }
399}
400
401impl crate::Binding {
402    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
403        match *self {
404            crate::Binding::BuiltIn(built_in) => Some(built_in),
405            Self::Location { .. } => None,
406        }
407    }
408}
409
410impl super::SwizzleComponent {
411    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
412
413    pub const fn index(&self) -> u32 {
414        match *self {
415            Self::X => 0,
416            Self::Y => 1,
417            Self::Z => 2,
418            Self::W => 3,
419        }
420    }
421    pub const fn from_index(idx: u32) -> Self {
422        match idx {
423            0 => Self::X,
424            1 => Self::Y,
425            2 => Self::Z,
426            _ => Self::W,
427        }
428    }
429}
430
431impl super::ImageClass {
432    pub const fn is_multisampled(self) -> bool {
433        match self {
434            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
435            crate::ImageClass::Storage { .. } => false,
436            crate::ImageClass::External => false,
437        }
438    }
439
440    pub const fn is_mipmapped(self) -> bool {
441        match self {
442            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
443            crate::ImageClass::Storage { .. } => false,
444            crate::ImageClass::External => false,
445        }
446    }
447
448    pub const fn is_depth(self) -> bool {
449        matches!(self, crate::ImageClass::Depth { .. })
450    }
451}
452
453impl crate::Module {
454    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
455        GlobalCtx {
456            types: &self.types,
457            constants: &self.constants,
458            overrides: &self.overrides,
459            global_expressions: &self.global_expressions,
460        }
461    }
462
463    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
464        compare_types(lhs, rhs, &self.types)
465    }
466}
467
/// Reasons an expression could not be turned into a requested constant value,
/// as reported by `GlobalCtx::get_const_val` and the `TryFrom<crate::Literal>`
/// impls in this module.
#[derive(Debug)]
pub enum ConstValueError {
    /// The expression is not a literal, a scalar `ZeroValue`, or a constant
    /// whose initializer is one of those.
    NonConst,
    /// A signed literal was negative where an unsigned value was requested.
    Negative,
    /// The literal's type cannot be converted to the requested type.
    InvalidType,
}
474
475impl From<core::convert::Infallible> for ConstValueError {
476    fn from(_: core::convert::Infallible) -> Self {
477        unreachable!()
478    }
479}
480
/// Shared borrows of the module-level arenas needed to inspect types and
/// constant expressions without borrowing a whole [`Module`](crate::Module).
///
/// Obtain one with [`Module::to_ctx`](crate::Module::to_ctx).
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's named constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's overrides.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// Module-scope expressions; constant initializers are looked up here.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
488
489impl GlobalCtx<'_> {
490    /// Try to evaluate the expression in `self.global_expressions` using its `handle`
491    /// and return it as a `T: TryFrom<ir::Literal>`.
492    ///
493    /// This currently only evaluates scalar expressions. If adding support for vectors,
494    /// consider changing `valid::expression::validate_constant_shift_amounts` to use that
495    /// support.
496    #[cfg_attr(
497        not(any(
498            feature = "glsl-in",
499            feature = "spv-in",
500            feature = "wgsl-in",
501            glsl_out,
502            hlsl_out,
503            msl_out,
504            wgsl_out
505        )),
506        allow(dead_code)
507    )]
508    pub(super) fn get_const_val<T, E>(
509        &self,
510        handle: crate::Handle<crate::Expression>,
511    ) -> Result<T, ConstValueError>
512    where
513        T: TryFrom<crate::Literal, Error = E>,
514        E: Into<ConstValueError>,
515    {
516        self.get_const_val_from(handle, self.global_expressions)
517    }
518
519    pub(super) fn get_const_val_from<T, E>(
520        &self,
521        handle: crate::Handle<crate::Expression>,
522        arena: &crate::Arena<crate::Expression>,
523    ) -> Result<T, ConstValueError>
524    where
525        T: TryFrom<crate::Literal, Error = E>,
526        E: Into<ConstValueError>,
527    {
528        fn get(
529            gctx: GlobalCtx,
530            handle: crate::Handle<crate::Expression>,
531            arena: &crate::Arena<crate::Expression>,
532        ) -> Option<crate::Literal> {
533            match arena[handle] {
534                crate::Expression::Literal(literal) => Some(literal),
535                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
536                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
537                    _ => None,
538                },
539                _ => None,
540            }
541        }
542        let value = match arena[handle] {
543            crate::Expression::Constant(c) => {
544                get(*self, self.constants[c].init, self.global_expressions)
545            }
546            _ => get(*self, handle, arena),
547        };
548        match value {
549            Some(v) => v.try_into().map_err(Into::into),
550            None => Err(ConstValueError::NonConst),
551        }
552    }
553
554    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
555        compare_types(lhs, rhs, self.types)
556    }
557}
558
/// Errors returned by [`ArraySize::resolve`](crate::ArraySize::resolve).
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The override-based length evaluated to zero or a negative number, or
    /// to a value that cannot be converted to `u32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override has no initializer, or its initializer is not yet a
    /// literal, scalar zero value, or constant.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
566
567impl crate::ArraySize {
568    /// Return the number of elements that `size` represents, if known at code generation time.
569    ///
570    /// If `size` is override-based, return an error unless the override's
571    /// initializer is a fully evaluated constant expression. You can call
572    /// [`pipeline_constants::process_overrides`] to supply values for a
573    /// module's overrides and ensure their initializers are fully evaluated, as
574    /// this function expects.
575    ///
576    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
577    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
578        match *self {
579            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
580            crate::ArraySize::Pending(handle) => {
581                let Some(expr) = gctx.overrides[handle].init else {
582                    return Err(ResolveArraySizeError::NonConstArrayLength);
583                };
584                let length = gctx.get_const_val(expr).map_err(|err| match err {
585                    ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
586                    ConstValueError::Negative | ConstValueError::InvalidType => {
587                        ResolveArraySizeError::ExpectedPositiveArrayLength
588                    }
589                })?;
590
591                if length == 0 {
592                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
593                }
594
595                Ok(IndexableLength::Known(length))
596            }
597            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
598        }
599    }
600}
601
/// Return an iterator over the individual components assembled by a
/// `Compose` expression.
///
/// Given `ty` and `components` from an `Expression::Compose`, return an
/// iterator over the components of the resulting value.
///
/// Normally, this would just be an iterator over `components`. However,
/// `Compose` expressions can concatenate vectors, in which case the i'th
/// value being composed is not generally the i'th element of `components`.
/// This function consults `ty` to decide if this concatenation is occurring,
/// and returns an iterator that produces the components of the result of
/// the `Compose` expression in either case.
///
/// For example, when `ty` is a four-component vector and `components` is
/// `[vec2(a, b), c, d]` (the inner `vec2` being a `Compose`), the iterator
/// yields `a, b, c, d`. When `ty` is not a vector, `components` are yielded
/// unchanged. In both cases, at most the vector size (or `components.len()`)
/// handles are produced.
pub fn flatten_compose<'arenas>(
    ty: crate::Handle<crate::Type>,
    components: &'arenas [crate::Handle<crate::Expression>],
    expressions: &'arenas crate::Arena<crate::Expression>,
    types: &'arenas crate::UniqueArena<crate::Type>,
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
    // Returning `impl Iterator` is a bit tricky. We may or may not
    // want to flatten the components, but we have to settle on a
    // single concrete type to return. This function returns a single
    // iterator chain that handles both the flattening and
    // non-flattening cases.
    //
    // `size` bounds the number of handles produced: the vector's component
    // count when `ty` is a vector, otherwise the number of `components`.
    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
        (size as usize, true)
    } else {
        (components.len(), false)
    };

    /// Flatten `Compose` expressions if `is_vector` is true.
    fn flatten_compose<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> &'c [crate::Handle<crate::Expression>] {
        if is_vector {
            if let crate::Expression::Compose {
                ty: _,
                components: ref subcomponents,
            } = expressions[*component]
            {
                return subcomponents;
            }
        }
        // Not flattened: yield the component itself as a one-element slice so
        // both paths share the same return type.
        core::slice::from_ref(component)
    }

    /// Flatten `Splat` expressions if `is_vector` is true.
    fn flatten_splat<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
        // Default: the component itself, once. A `Splat` instead repeats its
        // operand `size` times.
        let mut expr = *component;
        let mut count = 1;
        if is_vector {
            if let crate::Expression::Splat { size, value } = expressions[expr] {
                expr = value;
                count = size as usize;
            }
        }
        core::iter::repeat_n(expr, count)
    }

    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
    // flatten up to two levels of `Compose` expressions.
    //
    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
    // `Splat` expressions. Fortunately, the operand of a `Splat` must
    // be a scalar, so we can stop there.
    components
        .iter()
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
        .take(size)
}
679
680impl super::ShaderStage {
681    pub const fn compute_like(self) -> bool {
682        match self {
683            Self::Vertex | Self::Fragment => false,
684            Self::Compute | Self::Task | Self::Mesh => true,
685            Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => false,
686        }
687    }
688
689    /// Mesh or task shader
690    pub const fn mesh_like(self) -> bool {
691        match self {
692            Self::Task | Self::Mesh => true,
693            _ => false,
694        }
695    }
696}
697
#[test]
fn test_matrix_size() {
    // A 3x3 `f32` matrix is expected to occupy 48 bytes: 3 columns x 16 bytes.
    let module = crate::Module::default();
    let mat3x3f = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let size = mat3x3f.size(module.to_ctx());
    assert_eq!(size, 48);
}
711
712impl crate::Module {
713    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
714    /// and by validators. This only validates the output variable itself, and not the
715    /// vertex and primitive output types.
716    ///
717    /// The output contains the extracted mesh stage info, with overrides unset,
718    /// and then the overrides separately. This is because the overrides should be
719    /// treated as expressions elsewhere, but that requires mutably modifying the
720    /// module and the expressions should only be created at parse time, not validation
721    /// time.
722    #[allow(clippy::type_complexity)]
723    pub fn analyze_mesh_shader_info(
724        &self,
725        gv: crate::Handle<crate::GlobalVariable>,
726    ) -> (
727        crate::MeshStageInfo,
728        [Option<crate::Handle<crate::Override>>; 2],
729        Option<crate::WithSpan<crate::valid::EntryPointError>>,
730    ) {
731        use crate::span::AddSpan;
732        use crate::valid::EntryPointError;
733        #[derive(Default)]
734        struct OutError {
735            pub inner: Option<EntryPointError>,
736        }
737        impl OutError {
738            pub fn set(&mut self, err: EntryPointError) {
739                if self.inner.is_none() {
740                    self.inner = Some(err);
741                }
742            }
743        }
744
745        // Used to temporarily initialize stuff
746        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
747        let mut output = crate::MeshStageInfo {
748            topology: crate::MeshOutputTopology::Triangles,
749            max_vertices: 0,
750            max_vertices_override: None,
751            max_primitives: 0,
752            max_primitives_override: None,
753            vertex_output_type: null_type,
754            primitive_output_type: null_type,
755            output_variable: gv,
756        };
757        // Stores the error to output, if any.
758        let mut error = OutError::default();
759        let r#type = &self.types[self.global_variables[gv].ty].inner;
760
761        let mut topology = output.topology;
762        // Max, max override, type
763        let mut vertex_info = (0, None, null_type);
764        let mut primitive_info = (0, None, null_type);
765
766        match r#type {
767            &crate::TypeInner::Struct { ref members, .. } => {
768                let mut builtins = crate::FastHashSet::default();
769                for member in members {
770                    match member.binding {
771                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
772                            // Must have type u32
773                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
774                                error.set(EntryPointError::BadMeshOutputVariableField);
775                            }
776                            // Each builtin should only occur once
777                            if builtins.contains(&crate::BuiltIn::VertexCount) {
778                                error.set(EntryPointError::BadMeshOutputVariableType);
779                            }
780                            builtins.insert(crate::BuiltIn::VertexCount);
781                        }
782                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
783                            // Must have type u32
784                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
785                                error.set(EntryPointError::BadMeshOutputVariableField);
786                            }
787                            // Each builtin should only occur once
788                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
789                                error.set(EntryPointError::BadMeshOutputVariableType);
790                            }
791                            builtins.insert(crate::BuiltIn::PrimitiveCount);
792                        }
793                        Some(crate::Binding::BuiltIn(
794                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
795                        )) => {
796                            let ty = &self.types[member.ty].inner;
797                            // Analyze the array type to determine size and vertex/primitive type
798                            let (a, b, c) = match ty {
799                                &crate::TypeInner::Array { base, size, .. } => {
800                                    let ty = base;
801                                    let (max, max_override) = match size {
802                                        crate::ArraySize::Constant(a) => (a.get(), None),
803                                        crate::ArraySize::Pending(o) => (0, Some(o)),
804                                        crate::ArraySize::Dynamic => {
805                                            error.set(EntryPointError::BadMeshOutputVariableField);
806                                            (0, None)
807                                        }
808                                    };
809                                    (max, max_override, ty)
810                                }
811                                _ => {
812                                    error.set(EntryPointError::BadMeshOutputVariableField);
813                                    (0, None, null_type)
814                                }
815                            };
816                            if matches!(
817                                member.binding,
818                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
819                            ) {
820                                // Primitives require special analysis to determine topology
821                                primitive_info = (a, b, c);
822                                match self.types[c].inner {
823                                    crate::TypeInner::Struct { ref members, .. } => {
824                                        for member in members {
825                                            match member.binding {
826                                                Some(crate::Binding::BuiltIn(
827                                                    crate::BuiltIn::PointIndex,
828                                                )) => {
829                                                    topology = crate::MeshOutputTopology::Points;
830                                                }
831                                                Some(crate::Binding::BuiltIn(
832                                                    crate::BuiltIn::LineIndices,
833                                                )) => {
834                                                    topology = crate::MeshOutputTopology::Lines;
835                                                }
836                                                Some(crate::Binding::BuiltIn(
837                                                    crate::BuiltIn::TriangleIndices,
838                                                )) => {
839                                                    topology = crate::MeshOutputTopology::Triangles;
840                                                }
841                                                _ => (),
842                                            }
843                                        }
844                                    }
845                                    _ => (),
846                                }
847                                // Each builtin should only occur once
848                                if builtins.contains(&crate::BuiltIn::Primitives) {
849                                    error.set(EntryPointError::BadMeshOutputVariableType);
850                                }
851                                builtins.insert(crate::BuiltIn::Primitives);
852                            } else {
853                                vertex_info = (a, b, c);
854                                // Each builtin should only occur once
855                                if builtins.contains(&crate::BuiltIn::Vertices) {
856                                    error.set(EntryPointError::BadMeshOutputVariableType);
857                                }
858                                builtins.insert(crate::BuiltIn::Vertices);
859                            }
860                        }
861                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
862                    }
863                }
864                output = crate::MeshStageInfo {
865                    topology,
866                    max_vertices: vertex_info.0,
867                    max_vertices_override: None,
868                    vertex_output_type: vertex_info.2,
869                    max_primitives: primitive_info.0,
870                    max_primitives_override: None,
871                    primitive_output_type: primitive_info.2,
872                    ..output
873                }
874            }
875            _ => error.set(EntryPointError::BadMeshOutputVariableType),
876        }
877        (
878            output,
879            [vertex_info.1, primitive_info.1],
880            error
881                .inner
882                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
883        )
884    }
885
886    pub fn uses_mesh_shaders(&self) -> bool {
887        let binding_uses_mesh = |b: &crate::Binding| {
888            matches!(
889                b,
890                crate::Binding::BuiltIn(
891                    crate::BuiltIn::MeshTaskSize
892                        | crate::BuiltIn::CullPrimitive
893                        | crate::BuiltIn::PointIndex
894                        | crate::BuiltIn::LineIndices
895                        | crate::BuiltIn::TriangleIndices
896                        | crate::BuiltIn::VertexCount
897                        | crate::BuiltIn::Vertices
898                        | crate::BuiltIn::PrimitiveCount
899                        | crate::BuiltIn::Primitives,
900                ) | crate::Binding::Location {
901                    per_primitive: true,
902                    ..
903                }
904            )
905        };
906        for (_, ty) in self.types.iter() {
907            match ty.inner {
908                crate::TypeInner::Struct { ref members, .. } => {
909                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
910                        if binding_uses_mesh(binding) {
911                            return true;
912                        }
913                    }
914                }
915                _ => (),
916            }
917        }
918        for ep in &self.entry_points {
919            if matches!(
920                ep.stage,
921                crate::ShaderStage::Mesh | crate::ShaderStage::Task
922            ) {
923                return true;
924            }
925            for binding in ep
926                .function
927                .arguments
928                .iter()
929                .filter_map(|arg| arg.binding.as_ref())
930                .chain(
931                    ep.function
932                        .result
933                        .iter()
934                        .filter_map(|res| res.binding.as_ref()),
935                )
936            {
937                if binding_uses_mesh(binding) {
938                    return true;
939                }
940            }
941        }
942        if self
943            .global_variables
944            .iter()
945            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
946        {
947            return true;
948        }
949        false
950    }
951}
952
953impl crate::MeshOutputTopology {
954    pub const fn to_builtin(self) -> crate::BuiltIn {
955        match self {
956            Self::Points => crate::BuiltIn::PointIndex,
957            Self::Lines => crate::BuiltIn::LineIndices,
958            Self::Triangles => crate::BuiltIn::TriangleIndices,
959        }
960    }
961}
962
963impl crate::AddressSpace {
964    pub const fn is_workgroup_like(self) -> bool {
965        matches!(self, Self::WorkGroup | Self::TaskPayload)
966    }
967}