naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
/// A literal value in a form that can be hashed and compared for equality.
///
/// Integer and boolean variants hold their value directly. The
/// floating-point variants (`F64`, `F32`, `F16`, `AbstractFloat`) hold an
/// unsigned integer of the same width instead of a float, which is what
/// allows `Eq` and `Hash` to be derived.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HashableLiteral {
    /// 64-bit float, stored as a `u64`.
    F64(u64),
    /// 32-bit float, stored as a `u32`.
    F32(u32),
    /// 16-bit float, stored as a `u16`.
    F16(u16),
    /// 32-bit unsigned integer.
    U32(u32),
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit unsigned integer.
    U64(u64),
    /// 64-bit signed integer.
    I64(i64),
    /// Boolean.
    Bool(bool),
    /// Abstract integer.
    AbstractInt(i64),
    /// Abstract float, stored as a `u64`.
    AbstractFloat(u64),
}
101
102impl From<crate::Literal> for HashableLiteral {
103    fn from(l: crate::Literal) -> Self {
104        match l {
105            crate::Literal::F64(v) => Self::F64(v.to_bits()),
106            crate::Literal::F32(v) => Self::F32(v.to_bits()),
107            crate::Literal::F16(v) => Self::F16(v.to_bits()),
108            crate::Literal::U32(v) => Self::U32(v),
109            crate::Literal::I32(v) => Self::I32(v),
110            crate::Literal::U64(v) => Self::U64(v),
111            crate::Literal::I64(v) => Self::I64(v),
112            crate::Literal::Bool(v) => Self::Bool(v),
113            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115        }
116    }
117}
118
119impl crate::Literal {
120    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121        match (value, scalar.kind, scalar.width) {
122            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
125            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
126            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
127            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
128            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
129            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
130            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
131            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
132            _ => None,
133        }
134    }
135
136    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
137        Self::new(0, scalar)
138    }
139
140    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
141        Self::new(1, scalar)
142    }
143
144    pub const fn width(&self) -> crate::Bytes {
145        match *self {
146            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
147            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
148            Self::F16(_) => 2,
149            Self::Bool(_) => crate::BOOL_WIDTH,
150            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
151        }
152    }
153    pub const fn scalar(&self) -> crate::Scalar {
154        match *self {
155            Self::F64(_) => crate::Scalar::F64,
156            Self::F32(_) => crate::Scalar::F32,
157            Self::F16(_) => crate::Scalar::F16,
158            Self::U32(_) => crate::Scalar::U32,
159            Self::I32(_) => crate::Scalar::I32,
160            Self::U64(_) => crate::Scalar::U64,
161            Self::I64(_) => crate::Scalar::I64,
162            Self::Bool(_) => crate::Scalar::BOOL,
163            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
164            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
165        }
166    }
167    pub const fn scalar_kind(&self) -> crate::ScalarKind {
168        self.scalar().kind
169    }
170    pub const fn ty_inner(&self) -> crate::TypeInner {
171        crate::TypeInner::Scalar(self.scalar())
172    }
173}
174
175impl super::AddressSpace {
176    pub fn access(self) -> crate::StorageAccess {
177        use crate::StorageAccess as Sa;
178        match self {
179            crate::AddressSpace::Function
180            | crate::AddressSpace::Private
181            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
182            crate::AddressSpace::Uniform => Sa::LOAD,
183            crate::AddressSpace::Storage { access } => access,
184            crate::AddressSpace::Handle => Sa::LOAD,
185            crate::AddressSpace::PushConstant => Sa::LOAD,
186            // TaskPayload isn't always writable, but this is checked for elsewhere,
187            // when not using multiple payloads and matching the entry payload is checked.
188            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
189        }
190    }
191}
192
193impl super::MathFunction {
194    pub const fn argument_count(&self) -> usize {
195        match *self {
196            // comparison
197            Self::Abs => 1,
198            Self::Min => 2,
199            Self::Max => 2,
200            Self::Clamp => 3,
201            Self::Saturate => 1,
202            // trigonometry
203            Self::Cos => 1,
204            Self::Cosh => 1,
205            Self::Sin => 1,
206            Self::Sinh => 1,
207            Self::Tan => 1,
208            Self::Tanh => 1,
209            Self::Acos => 1,
210            Self::Asin => 1,
211            Self::Atan => 1,
212            Self::Atan2 => 2,
213            Self::Asinh => 1,
214            Self::Acosh => 1,
215            Self::Atanh => 1,
216            Self::Radians => 1,
217            Self::Degrees => 1,
218            // decomposition
219            Self::Ceil => 1,
220            Self::Floor => 1,
221            Self::Round => 1,
222            Self::Fract => 1,
223            Self::Trunc => 1,
224            Self::Modf => 1,
225            Self::Frexp => 1,
226            Self::Ldexp => 2,
227            // exponent
228            Self::Exp => 1,
229            Self::Exp2 => 1,
230            Self::Log => 1,
231            Self::Log2 => 1,
232            Self::Pow => 2,
233            // geometry
234            Self::Dot => 2,
235            Self::Dot4I8Packed => 2,
236            Self::Dot4U8Packed => 2,
237            Self::Outer => 2,
238            Self::Cross => 2,
239            Self::Distance => 2,
240            Self::Length => 1,
241            Self::Normalize => 1,
242            Self::FaceForward => 3,
243            Self::Reflect => 2,
244            Self::Refract => 3,
245            // computational
246            Self::Sign => 1,
247            Self::Fma => 3,
248            Self::Mix => 3,
249            Self::Step => 2,
250            Self::SmoothStep => 3,
251            Self::Sqrt => 1,
252            Self::InverseSqrt => 1,
253            Self::Inverse => 1,
254            Self::Transpose => 1,
255            Self::Determinant => 1,
256            Self::QuantizeToF16 => 1,
257            // bits
258            Self::CountTrailingZeros => 1,
259            Self::CountLeadingZeros => 1,
260            Self::CountOneBits => 1,
261            Self::ReverseBits => 1,
262            Self::ExtractBits => 3,
263            Self::InsertBits => 4,
264            Self::FirstTrailingBit => 1,
265            Self::FirstLeadingBit => 1,
266            // data packing
267            Self::Pack4x8snorm => 1,
268            Self::Pack4x8unorm => 1,
269            Self::Pack2x16snorm => 1,
270            Self::Pack2x16unorm => 1,
271            Self::Pack2x16float => 1,
272            Self::Pack4xI8 => 1,
273            Self::Pack4xU8 => 1,
274            Self::Pack4xI8Clamp => 1,
275            Self::Pack4xU8Clamp => 1,
276            // data unpacking
277            Self::Unpack4x8snorm => 1,
278            Self::Unpack4x8unorm => 1,
279            Self::Unpack2x16snorm => 1,
280            Self::Unpack2x16unorm => 1,
281            Self::Unpack2x16float => 1,
282            Self::Unpack4xI8 => 1,
283            Self::Unpack4xU8 => 1,
284        }
285    }
286}
287
288impl crate::Expression {
289    /// Returns true if the expression is considered emitted at the start of a function.
290    pub const fn needs_pre_emit(&self) -> bool {
291        match *self {
292            Self::Literal(_)
293            | Self::Constant(_)
294            | Self::Override(_)
295            | Self::ZeroValue(_)
296            | Self::FunctionArgument(_)
297            | Self::GlobalVariable(_)
298            | Self::LocalVariable(_) => true,
299            _ => false,
300        }
301    }
302
303    /// Return true if this expression is a dynamic array/vector/matrix index,
304    /// for [`Access`].
305    ///
306    /// This method returns true if this expression is a dynamically computed
307    /// index, and as such can only be used to index matrices when they appear
308    /// behind a pointer. See the documentation for [`Access`] for details.
309    ///
310    /// Note, this does not check the _type_ of the given expression. It's up to
311    /// the caller to establish that the `Access` expression is well-typed
312    /// through other means, like [`ResolveContext`].
313    ///
314    /// [`Access`]: crate::Expression::Access
315    /// [`ResolveContext`]: crate::proc::ResolveContext
316    pub const fn is_dynamic_index(&self) -> bool {
317        match *self {
318            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
319            _ => true,
320        }
321    }
322}
323
324impl crate::Function {
325    /// Return the global variable being accessed by the expression `pointer`.
326    ///
327    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
328    /// expressions that ultimately access some part of a `GlobalVariable`,
329    /// return a handle for that global.
330    ///
331    /// If the expression does not ultimately access a global variable, return
332    /// `None`.
333    pub fn originating_global(
334        &self,
335        mut pointer: crate::Handle<crate::Expression>,
336    ) -> Option<crate::Handle<crate::GlobalVariable>> {
337        loop {
338            pointer = match self.expressions[pointer] {
339                crate::Expression::Access { base, .. } => base,
340                crate::Expression::AccessIndex { base, .. } => base,
341                crate::Expression::GlobalVariable(handle) => return Some(handle),
342                crate::Expression::LocalVariable(_) => return None,
343                crate::Expression::FunctionArgument(_) => return None,
344                // There are no other expressions that produce pointer values.
345                _ => unreachable!(),
346            }
347        }
348    }
349}
350
351impl crate::SampleLevel {
352    pub const fn implicit_derivatives(&self) -> bool {
353        match *self {
354            Self::Auto | Self::Bias(_) => true,
355            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
356        }
357    }
358}
359
360impl crate::Binding {
361    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
362        match *self {
363            crate::Binding::BuiltIn(built_in) => Some(built_in),
364            Self::Location { .. } => None,
365        }
366    }
367}
368
369impl super::SwizzleComponent {
370    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
371
372    pub const fn index(&self) -> u32 {
373        match *self {
374            Self::X => 0,
375            Self::Y => 1,
376            Self::Z => 2,
377            Self::W => 3,
378        }
379    }
380    pub const fn from_index(idx: u32) -> Self {
381        match idx {
382            0 => Self::X,
383            1 => Self::Y,
384            2 => Self::Z,
385            _ => Self::W,
386        }
387    }
388}
389
390impl super::ImageClass {
391    pub const fn is_multisampled(self) -> bool {
392        match self {
393            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
394            crate::ImageClass::Storage { .. } => false,
395            crate::ImageClass::External => false,
396        }
397    }
398
399    pub const fn is_mipmapped(self) -> bool {
400        match self {
401            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
402            crate::ImageClass::Storage { .. } => false,
403            crate::ImageClass::External => false,
404        }
405    }
406
407    pub const fn is_depth(self) -> bool {
408        matches!(self, crate::ImageClass::Depth { .. })
409    }
410}
411
412impl crate::Module {
413    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
414        GlobalCtx {
415            types: &self.types,
416            constants: &self.constants,
417            overrides: &self.overrides,
418            global_expressions: &self.global_expressions,
419        }
420    }
421
422    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
423        compare_types(lhs, rhs, &self.types)
424    }
425}
426
427#[derive(Debug)]
428pub(super) enum U32EvalError {
429    NonConst,
430    Negative,
431}
432
433#[derive(Clone, Copy)]
434pub struct GlobalCtx<'a> {
435    pub types: &'a crate::UniqueArena<crate::Type>,
436    pub constants: &'a crate::Arena<crate::Constant>,
437    pub overrides: &'a crate::Arena<crate::Override>,
438    pub global_expressions: &'a crate::Arena<crate::Expression>,
439}
440
441impl GlobalCtx<'_> {
442    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
443    #[allow(dead_code)]
444    pub(super) fn eval_expr_to_u32(
445        &self,
446        handle: crate::Handle<crate::Expression>,
447    ) -> Result<u32, U32EvalError> {
448        self.eval_expr_to_u32_from(handle, self.global_expressions)
449    }
450
451    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
452    pub(super) fn eval_expr_to_u32_from(
453        &self,
454        handle: crate::Handle<crate::Expression>,
455        arena: &crate::Arena<crate::Expression>,
456    ) -> Result<u32, U32EvalError> {
457        match self.eval_expr_to_literal_from(handle, arena) {
458            Some(crate::Literal::U32(value)) => Ok(value),
459            Some(crate::Literal::I32(value)) => {
460                value.try_into().map_err(|_| U32EvalError::Negative)
461            }
462            _ => Err(U32EvalError::NonConst),
463        }
464    }
465
466    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
467    #[allow(dead_code)]
468    pub(super) fn eval_expr_to_bool_from(
469        &self,
470        handle: crate::Handle<crate::Expression>,
471        arena: &crate::Arena<crate::Expression>,
472    ) -> Option<bool> {
473        match self.eval_expr_to_literal_from(handle, arena) {
474            Some(crate::Literal::Bool(value)) => Some(value),
475            _ => None,
476        }
477    }
478
479    #[allow(dead_code)]
480    pub(crate) fn eval_expr_to_literal(
481        &self,
482        handle: crate::Handle<crate::Expression>,
483    ) -> Option<crate::Literal> {
484        self.eval_expr_to_literal_from(handle, self.global_expressions)
485    }
486
487    pub(super) fn eval_expr_to_literal_from(
488        &self,
489        handle: crate::Handle<crate::Expression>,
490        arena: &crate::Arena<crate::Expression>,
491    ) -> Option<crate::Literal> {
492        fn get(
493            gctx: GlobalCtx,
494            handle: crate::Handle<crate::Expression>,
495            arena: &crate::Arena<crate::Expression>,
496        ) -> Option<crate::Literal> {
497            match arena[handle] {
498                crate::Expression::Literal(literal) => Some(literal),
499                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
500                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
501                    _ => None,
502                },
503                _ => None,
504            }
505        }
506        match arena[handle] {
507            crate::Expression::Constant(c) => {
508                get(*self, self.constants[c].init, self.global_expressions)
509            }
510            _ => get(*self, handle, arena),
511        }
512    }
513
514    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
515        compare_types(lhs, rhs, self.types)
516    }
517}
518
519#[derive(Error, Debug, Clone, Copy, PartialEq)]
520pub enum ResolveArraySizeError {
521    #[error("array element count must be positive (> 0)")]
522    ExpectedPositiveArrayLength,
523    #[error("internal: array size override has not been resolved")]
524    NonConstArrayLength,
525}
526
527impl crate::ArraySize {
528    /// Return the number of elements that `size` represents, if known at code generation time.
529    ///
530    /// If `size` is override-based, return an error unless the override's
531    /// initializer is a fully evaluated constant expression. You can call
532    /// [`pipeline_constants::process_overrides`] to supply values for a
533    /// module's overrides and ensure their initializers are fully evaluated, as
534    /// this function expects.
535    ///
536    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
537    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
538        match *self {
539            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
540            crate::ArraySize::Pending(handle) => {
541                let Some(expr) = gctx.overrides[handle].init else {
542                    return Err(ResolveArraySizeError::NonConstArrayLength);
543                };
544                let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
545                    U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
546                    U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
547                })?;
548
549                if length == 0 {
550                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
551                }
552
553                Ok(IndexableLength::Known(length))
554            }
555            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
556        }
557    }
558}
559
560/// Return an iterator over the individual components assembled by a
561/// `Compose` expression.
562///
563/// Given `ty` and `components` from an `Expression::Compose`, return an
564/// iterator over the components of the resulting value.
565///
566/// Normally, this would just be an iterator over `components`. However,
567/// `Compose` expressions can concatenate vectors, in which case the i'th
568/// value being composed is not generally the i'th element of `components`.
569/// This function consults `ty` to decide if this concatenation is occurring,
570/// and returns an iterator that produces the components of the result of
571/// the `Compose` expression in either case.
572pub fn flatten_compose<'arenas>(
573    ty: crate::Handle<crate::Type>,
574    components: &'arenas [crate::Handle<crate::Expression>],
575    expressions: &'arenas crate::Arena<crate::Expression>,
576    types: &'arenas crate::UniqueArena<crate::Type>,
577) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
578    // Returning `impl Iterator` is a bit tricky. We may or may not
579    // want to flatten the components, but we have to settle on a
580    // single concrete type to return. This function returns a single
581    // iterator chain that handles both the flattening and
582    // non-flattening cases.
583    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
584        (size as usize, true)
585    } else {
586        (components.len(), false)
587    };
588
589    /// Flatten `Compose` expressions if `is_vector` is true.
590    fn flatten_compose<'c>(
591        component: &'c crate::Handle<crate::Expression>,
592        is_vector: bool,
593        expressions: &'c crate::Arena<crate::Expression>,
594    ) -> &'c [crate::Handle<crate::Expression>] {
595        if is_vector {
596            if let crate::Expression::Compose {
597                ty: _,
598                components: ref subcomponents,
599            } = expressions[*component]
600            {
601                return subcomponents;
602            }
603        }
604        core::slice::from_ref(component)
605    }
606
607    /// Flatten `Splat` expressions if `is_vector` is true.
608    fn flatten_splat<'c>(
609        component: &'c crate::Handle<crate::Expression>,
610        is_vector: bool,
611        expressions: &'c crate::Arena<crate::Expression>,
612    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
613        let mut expr = *component;
614        let mut count = 1;
615        if is_vector {
616            if let crate::Expression::Splat { size, value } = expressions[expr] {
617                expr = value;
618                count = size as usize;
619            }
620        }
621        core::iter::repeat_n(expr, count)
622    }
623
624    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
625    // flatten up to two levels of `Compose` expressions.
626    //
627    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
628    // `Splat` expressions. Fortunately, the operand of a `Splat` must
629    // be a scalar, so we can stop there.
630    components
631        .iter()
632        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
633        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
634        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
635        .take(size)
636}
637
638impl super::ShaderStage {
639    pub const fn compute_like(self) -> bool {
640        match self {
641            Self::Vertex | Self::Fragment => false,
642            Self::Compute | Self::Task | Self::Mesh => true,
643        }
644    }
645}
646
/// A 3x3 `f32` matrix is expected to report a size of 48 bytes.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    assert_eq!(mat3x3.size(module.to_ctx()), 48);
}
660
661impl crate::Module {
662    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
663    /// and by validators. This only validates the output variable itself, and not the
664    /// vertex and primitive output types.
665    ///
666    /// The output contains the extracted mesh stage info, with overrides unset,
667    /// and then the overrides separately. This is because the overrides should be
668    /// treated as expressions elsewhere, but that requires mutably modifying the
669    /// module and the expressions should only be created at parse time, not validation
670    /// time.
671    #[allow(clippy::type_complexity)]
672    pub fn analyze_mesh_shader_info(
673        &self,
674        gv: crate::Handle<crate::GlobalVariable>,
675    ) -> (
676        crate::MeshStageInfo,
677        [Option<crate::Handle<crate::Override>>; 2],
678        Option<crate::WithSpan<crate::valid::EntryPointError>>,
679    ) {
680        use crate::span::AddSpan;
681        use crate::valid::EntryPointError;
682        #[derive(Default)]
683        struct OutError {
684            pub inner: Option<EntryPointError>,
685        }
686        impl OutError {
687            pub fn set(&mut self, err: EntryPointError) {
688                if self.inner.is_none() {
689                    self.inner = Some(err);
690                }
691            }
692        }
693
694        // Used to temporarily initialize stuff
695        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
696        let mut output = crate::MeshStageInfo {
697            topology: crate::MeshOutputTopology::Triangles,
698            max_vertices: 0,
699            max_vertices_override: None,
700            max_primitives: 0,
701            max_primitives_override: None,
702            vertex_output_type: null_type,
703            primitive_output_type: null_type,
704            output_variable: gv,
705        };
706        // Stores the error to output, if any.
707        let mut error = OutError::default();
708        let r#type = &self.types[self.global_variables[gv].ty].inner;
709
710        let mut topology = output.topology;
711        // Max, max override, type
712        let mut vertex_info = (0, None, null_type);
713        let mut primitive_info = (0, None, null_type);
714
715        match r#type {
716            &crate::TypeInner::Struct { ref members, .. } => {
717                let mut builtins = crate::FastHashSet::default();
718                for member in members {
719                    match member.binding {
720                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
721                            // Must have type u32
722                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
723                                error.set(EntryPointError::BadMeshOutputVariableField);
724                            }
725                            // Each builtin should only occur once
726                            if builtins.contains(&crate::BuiltIn::VertexCount) {
727                                error.set(EntryPointError::BadMeshOutputVariableType);
728                            }
729                            builtins.insert(crate::BuiltIn::VertexCount);
730                        }
731                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
732                            // Must have type u32
733                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
734                                error.set(EntryPointError::BadMeshOutputVariableField);
735                            }
736                            // Each builtin should only occur once
737                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
738                                error.set(EntryPointError::BadMeshOutputVariableType);
739                            }
740                            builtins.insert(crate::BuiltIn::PrimitiveCount);
741                        }
742                        Some(crate::Binding::BuiltIn(
743                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
744                        )) => {
745                            let ty = &self.types[member.ty].inner;
746                            // Analyze the array type to determine size and vertex/primitive type
747                            let (a, b, c) = match ty {
748                                &crate::TypeInner::Array { base, size, .. } => {
749                                    let ty = base;
750                                    let (max, max_override) = match size {
751                                        crate::ArraySize::Constant(a) => (a.get(), None),
752                                        crate::ArraySize::Pending(o) => (0, Some(o)),
753                                        crate::ArraySize::Dynamic => {
754                                            error.set(EntryPointError::BadMeshOutputVariableField);
755                                            (0, None)
756                                        }
757                                    };
758                                    (max, max_override, ty)
759                                }
760                                _ => {
761                                    error.set(EntryPointError::BadMeshOutputVariableField);
762                                    (0, None, null_type)
763                                }
764                            };
765                            if matches!(
766                                member.binding,
767                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
768                            ) {
769                                // Primitives require special analysis to determine topology
770                                primitive_info = (a, b, c);
771                                match self.types[c].inner {
772                                    crate::TypeInner::Struct { ref members, .. } => {
773                                        for member in members {
774                                            match member.binding {
775                                                Some(crate::Binding::BuiltIn(
776                                                    crate::BuiltIn::PointIndex,
777                                                )) => {
778                                                    topology = crate::MeshOutputTopology::Points;
779                                                }
780                                                Some(crate::Binding::BuiltIn(
781                                                    crate::BuiltIn::LineIndices,
782                                                )) => {
783                                                    topology = crate::MeshOutputTopology::Lines;
784                                                }
785                                                Some(crate::Binding::BuiltIn(
786                                                    crate::BuiltIn::TriangleIndices,
787                                                )) => {
788                                                    topology = crate::MeshOutputTopology::Triangles;
789                                                }
790                                                _ => (),
791                                            }
792                                        }
793                                    }
794                                    _ => (),
795                                }
796                                // Each builtin should only occur once
797                                if builtins.contains(&crate::BuiltIn::Primitives) {
798                                    error.set(EntryPointError::BadMeshOutputVariableType);
799                                }
800                                builtins.insert(crate::BuiltIn::Primitives);
801                            } else {
802                                vertex_info = (a, b, c);
803                                // Each builtin should only occur once
804                                if builtins.contains(&crate::BuiltIn::Vertices) {
805                                    error.set(EntryPointError::BadMeshOutputVariableType);
806                                }
807                                builtins.insert(crate::BuiltIn::Vertices);
808                            }
809                        }
810                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
811                    }
812                }
813                output = crate::MeshStageInfo {
814                    topology,
815                    max_vertices: vertex_info.0,
816                    max_vertices_override: None,
817                    vertex_output_type: vertex_info.2,
818                    max_primitives: primitive_info.0,
819                    max_primitives_override: None,
820                    primitive_output_type: primitive_info.2,
821                    ..output
822                }
823            }
824            _ => error.set(EntryPointError::BadMeshOutputVariableType),
825        }
826        (
827            output,
828            [vertex_info.1, primitive_info.1],
829            error
830                .inner
831                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
832        )
833    }
834}