naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28    concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35    fn from(format: super::StorageFormat) -> Self {
36        use super::{ScalarKind as Sk, StorageFormat as Sf};
37        let kind = match format {
38            Sf::R8Unorm => Sk::Float,
39            Sf::R8Snorm => Sk::Float,
40            Sf::R8Uint => Sk::Uint,
41            Sf::R8Sint => Sk::Sint,
42            Sf::R16Uint => Sk::Uint,
43            Sf::R16Sint => Sk::Sint,
44            Sf::R16Float => Sk::Float,
45            Sf::Rg8Unorm => Sk::Float,
46            Sf::Rg8Snorm => Sk::Float,
47            Sf::Rg8Uint => Sk::Uint,
48            Sf::Rg8Sint => Sk::Sint,
49            Sf::R32Uint => Sk::Uint,
50            Sf::R32Sint => Sk::Sint,
51            Sf::R32Float => Sk::Float,
52            Sf::Rg16Uint => Sk::Uint,
53            Sf::Rg16Sint => Sk::Sint,
54            Sf::Rg16Float => Sk::Float,
55            Sf::Rgba8Unorm => Sk::Float,
56            Sf::Rgba8Snorm => Sk::Float,
57            Sf::Rgba8Uint => Sk::Uint,
58            Sf::Rgba8Sint => Sk::Sint,
59            Sf::Bgra8Unorm => Sk::Float,
60            Sf::Rgb10a2Uint => Sk::Uint,
61            Sf::Rgb10a2Unorm => Sk::Float,
62            Sf::Rg11b10Ufloat => Sk::Float,
63            Sf::R64Uint => Sk::Uint,
64            Sf::Rg32Uint => Sk::Uint,
65            Sf::Rg32Sint => Sk::Sint,
66            Sf::Rg32Float => Sk::Float,
67            Sf::Rgba16Uint => Sk::Uint,
68            Sf::Rgba16Sint => Sk::Sint,
69            Sf::Rgba16Float => Sk::Float,
70            Sf::Rgba32Uint => Sk::Uint,
71            Sf::Rgba32Sint => Sk::Sint,
72            Sf::Rgba32Float => Sk::Float,
73            Sf::R16Unorm => Sk::Float,
74            Sf::R16Snorm => Sk::Float,
75            Sf::Rg16Unorm => Sk::Float,
76            Sf::Rg16Snorm => Sk::Float,
77            Sf::Rgba16Unorm => Sk::Float,
78            Sf::Rgba16Snorm => Sk::Float,
79        };
80        let width = match format {
81            Sf::R64Uint => 8,
82            _ => 4,
83        };
84        super::Scalar { kind, width }
85    }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
89pub enum HashableLiteral {
90    F64(u64),
91    F32(u32),
92    F16(u16),
93    U16(u16),
94    I16(i16),
95    U32(u32),
96    I32(i32),
97    U64(u64),
98    I64(i64),
99    Bool(bool),
100    AbstractInt(i64),
101    AbstractFloat(u64),
102}
103
104impl From<crate::Literal> for HashableLiteral {
105    fn from(l: crate::Literal) -> Self {
106        match l {
107            crate::Literal::F64(v) => Self::F64(v.to_bits()),
108            crate::Literal::F32(v) => Self::F32(v.to_bits()),
109            crate::Literal::F16(v) => Self::F16(v.to_bits()),
110            crate::Literal::U16(v) => Self::U16(v),
111            crate::Literal::I16(v) => Self::I16(v),
112            crate::Literal::U32(v) => Self::U32(v),
113            crate::Literal::I32(v) => Self::I32(v),
114            crate::Literal::U64(v) => Self::U64(v),
115            crate::Literal::I64(v) => Self::I64(v),
116            crate::Literal::Bool(v) => Self::Bool(v),
117            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
118            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
119        }
120    }
121}
122
123impl crate::Literal {
124    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
125        match (value, scalar.kind, scalar.width) {
126            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
127            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
128            (value, crate::ScalarKind::Float, 2) => {
129                Some(Self::F16(half::f16::from_f32_const(value as _)))
130            }
131            (value, crate::ScalarKind::Uint, 2) => Some(Self::U16(value as _)),
132            (value, crate::ScalarKind::Sint, 2) => Some(Self::I16(value as _)),
133            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
134            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
135            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
136            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
137            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
138            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
139            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
140            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
141            _ => None,
142        }
143    }
144
145    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
146        Self::new(0, scalar)
147    }
148
149    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
150        Self::new(1, scalar)
151    }
152
153    pub const fn minus_one(scalar: crate::Scalar) -> Option<Self> {
154        match (scalar.kind, scalar.width) {
155            (crate::ScalarKind::Float, 8) => Some(Self::F64(-1.0)),
156            (crate::ScalarKind::Float, 4) => Some(Self::F32(-1.0)),
157            (crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))),
158            (crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)),
159            (crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)),
160            (crate::ScalarKind::Sint, 2) => Some(Self::I16(-1)),
161            (crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)),
162            _ => None,
163        }
164    }
165
166    pub const fn width(&self) -> crate::Bytes {
167        match *self {
168            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
169            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
170            Self::F16(_) | Self::U16(_) | Self::I16(_) => 2,
171            Self::Bool(_) => crate::BOOL_WIDTH,
172            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
173        }
174    }
175    pub const fn scalar(&self) -> crate::Scalar {
176        match *self {
177            Self::F64(_) => crate::Scalar::F64,
178            Self::F32(_) => crate::Scalar::F32,
179            Self::F16(_) => crate::Scalar::F16,
180            Self::U16(_) => crate::Scalar::U16,
181            Self::I16(_) => crate::Scalar::I16,
182            Self::U32(_) => crate::Scalar::U32,
183            Self::I32(_) => crate::Scalar::I32,
184            Self::U64(_) => crate::Scalar::U64,
185            Self::I64(_) => crate::Scalar::I64,
186            Self::Bool(_) => crate::Scalar::BOOL,
187            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
188            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
189        }
190    }
191    pub const fn scalar_kind(&self) -> crate::ScalarKind {
192        self.scalar().kind
193    }
194    pub const fn ty_inner(&self) -> crate::TypeInner {
195        crate::TypeInner::Scalar(self.scalar())
196    }
197}
198
199impl TryFrom<crate::Literal> for u32 {
200    type Error = ConstValueError;
201
202    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
203        match value {
204            crate::Literal::U16(value) => Ok(value as u32),
205            crate::Literal::I16(value) => value.try_into().map_err(|_| ConstValueError::Negative),
206            crate::Literal::U32(value) => Ok(value),
207            crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
208            _ => Err(ConstValueError::InvalidType),
209        }
210    }
211}
212
213impl TryFrom<crate::Literal> for bool {
214    type Error = ConstValueError;
215
216    fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
217        match value {
218            crate::Literal::Bool(value) => Ok(value),
219            _ => Err(ConstValueError::InvalidType),
220        }
221    }
222}
223
224impl super::AddressSpace {
225    pub fn access(self) -> crate::StorageAccess {
226        use crate::StorageAccess as Sa;
227        match self {
228            crate::AddressSpace::Function
229            | crate::AddressSpace::Private
230            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
231            crate::AddressSpace::Uniform => Sa::LOAD,
232            crate::AddressSpace::Storage { access } => access,
233            crate::AddressSpace::Handle => Sa::LOAD,
234            crate::AddressSpace::Immediate => Sa::LOAD,
235            // TaskPayload isn't always writable, but this is checked for elsewhere,
236            // when not using multiple payloads and matching the entry payload is checked.
237            crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
238            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
239                Sa::LOAD | Sa::STORE
240            }
241        }
242    }
243}
244
245impl super::MathFunction {
246    pub const fn argument_count(&self) -> usize {
247        match *self {
248            // comparison
249            Self::Abs => 1,
250            Self::Min => 2,
251            Self::Max => 2,
252            Self::Clamp => 3,
253            Self::Saturate => 1,
254            // trigonometry
255            Self::Cos => 1,
256            Self::Cosh => 1,
257            Self::Sin => 1,
258            Self::Sinh => 1,
259            Self::Tan => 1,
260            Self::Tanh => 1,
261            Self::Acos => 1,
262            Self::Asin => 1,
263            Self::Atan => 1,
264            Self::Atan2 => 2,
265            Self::Asinh => 1,
266            Self::Acosh => 1,
267            Self::Atanh => 1,
268            Self::Radians => 1,
269            Self::Degrees => 1,
270            // decomposition
271            Self::Ceil => 1,
272            Self::Floor => 1,
273            Self::Round => 1,
274            Self::Fract => 1,
275            Self::Trunc => 1,
276            Self::Modf => 1,
277            Self::Frexp => 1,
278            Self::Ldexp => 2,
279            // exponent
280            Self::Exp => 1,
281            Self::Exp2 => 1,
282            Self::Log => 1,
283            Self::Log2 => 1,
284            Self::Pow => 2,
285            // geometry
286            Self::Dot => 2,
287            Self::Dot4I8Packed => 2,
288            Self::Dot4U8Packed => 2,
289            Self::Outer => 2,
290            Self::Cross => 2,
291            Self::Distance => 2,
292            Self::Length => 1,
293            Self::Normalize => 1,
294            Self::FaceForward => 3,
295            Self::Reflect => 2,
296            Self::Refract => 3,
297            // computational
298            Self::Sign => 1,
299            Self::Fma => 3,
300            Self::Mix => 3,
301            Self::Step => 2,
302            Self::SmoothStep => 3,
303            Self::Sqrt => 1,
304            Self::InverseSqrt => 1,
305            Self::Inverse => 1,
306            Self::Transpose => 1,
307            Self::Determinant => 1,
308            Self::QuantizeToF16 => 1,
309            // bits
310            Self::CountTrailingZeros => 1,
311            Self::CountLeadingZeros => 1,
312            Self::CountOneBits => 1,
313            Self::ReverseBits => 1,
314            Self::ExtractBits => 3,
315            Self::InsertBits => 4,
316            Self::FirstTrailingBit => 1,
317            Self::FirstLeadingBit => 1,
318            // data packing
319            Self::Pack4x8snorm => 1,
320            Self::Pack4x8unorm => 1,
321            Self::Pack2x16snorm => 1,
322            Self::Pack2x16unorm => 1,
323            Self::Pack2x16float => 1,
324            Self::Pack4xI8 => 1,
325            Self::Pack4xU8 => 1,
326            Self::Pack4xI8Clamp => 1,
327            Self::Pack4xU8Clamp => 1,
328            // data unpacking
329            Self::Unpack4x8snorm => 1,
330            Self::Unpack4x8unorm => 1,
331            Self::Unpack2x16snorm => 1,
332            Self::Unpack2x16unorm => 1,
333            Self::Unpack2x16float => 1,
334            Self::Unpack4xI8 => 1,
335            Self::Unpack4xU8 => 1,
336        }
337    }
338}
339
340impl crate::Expression {
341    /// Returns true if the expression is considered emitted at the start of a function.
342    pub const fn needs_pre_emit(&self) -> bool {
343        match *self {
344            Self::Literal(_)
345            | Self::Constant(_)
346            | Self::Override(_)
347            | Self::ZeroValue(_)
348            | Self::FunctionArgument(_)
349            | Self::GlobalVariable(_)
350            | Self::LocalVariable(_) => true,
351            _ => false,
352        }
353    }
354
355    /// Return true if this expression is a dynamic array/vector/matrix index,
356    /// for [`Access`].
357    ///
358    /// This method returns true if this expression is a dynamically computed
359    /// index, and as such can only be used to index matrices when they appear
360    /// behind a pointer. See the documentation for [`Access`] for details.
361    ///
362    /// Note, this does not check the _type_ of the given expression. It's up to
363    /// the caller to establish that the `Access` expression is well-typed
364    /// through other means, like [`ResolveContext`].
365    ///
366    /// [`Access`]: crate::Expression::Access
367    /// [`ResolveContext`]: crate::proc::ResolveContext
368    pub const fn is_dynamic_index(&self) -> bool {
369        match *self {
370            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
371            _ => true,
372        }
373    }
374}
375
376impl crate::Function {
377    /// Return the global variable being accessed by the expression `pointer`.
378    ///
379    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
380    /// expressions that ultimately access some part of a `GlobalVariable`,
381    /// return a handle for that global.
382    ///
383    /// If the expression does not ultimately access a global variable, return
384    /// `None`.
385    pub fn originating_global(
386        &self,
387        mut pointer: crate::Handle<crate::Expression>,
388    ) -> Option<crate::Handle<crate::GlobalVariable>> {
389        loop {
390            pointer = match self.expressions[pointer] {
391                crate::Expression::Access { base, .. } => base,
392                crate::Expression::AccessIndex { base, .. } => base,
393                crate::Expression::GlobalVariable(handle) => return Some(handle),
394                // Other expressions are not on this path to a global.
395                _ => return None,
396            }
397        }
398    }
399}
400
401impl crate::SampleLevel {
402    pub const fn implicit_derivatives(&self) -> bool {
403        match *self {
404            Self::Auto | Self::Bias(_) => true,
405            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
406        }
407    }
408}
409
410impl crate::Binding {
411    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
412        match *self {
413            crate::Binding::BuiltIn(built_in) => Some(built_in),
414            Self::Location { .. } => None,
415        }
416    }
417}
418
419impl super::SwizzleComponent {
420    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
421
422    pub const fn index(&self) -> u32 {
423        match *self {
424            Self::X => 0,
425            Self::Y => 1,
426            Self::Z => 2,
427            Self::W => 3,
428        }
429    }
430    pub const fn from_index(idx: u32) -> Self {
431        match idx {
432            0 => Self::X,
433            1 => Self::Y,
434            2 => Self::Z,
435            _ => Self::W,
436        }
437    }
438}
439
440impl super::ImageClass {
441    pub const fn is_multisampled(self) -> bool {
442        match self {
443            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
444            crate::ImageClass::Storage { .. } => false,
445            crate::ImageClass::External => false,
446        }
447    }
448
449    pub const fn is_mipmapped(self) -> bool {
450        match self {
451            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
452            crate::ImageClass::Storage { .. } => false,
453            crate::ImageClass::External => false,
454        }
455    }
456
457    pub const fn is_depth(self) -> bool {
458        matches!(self, crate::ImageClass::Depth { .. })
459    }
460}
461
462impl crate::Module {
463    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
464        GlobalCtx {
465            types: &self.types,
466            constants: &self.constants,
467            overrides: &self.overrides,
468            global_expressions: &self.global_expressions,
469        }
470    }
471
472    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
473        compare_types(lhs, rhs, &self.types)
474    }
475}
476
/// Error produced when an expression cannot be read as a constant value of
/// the requested type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstValueError {
    /// The expression did not reduce to a scalar literal, i.e. it is not a
    /// literal, a scalar zero value, or a constant initialized by one of those.
    NonConst,
    /// The value is negative and so cannot be represented by the unsigned
    /// target type.
    Negative,
    /// The literal's type cannot be converted to the target type.
    InvalidType,
}

impl From<core::convert::Infallible> for ConstValueError {
    /// Lets infallible `TryFrom<Literal>` conversions (such as the blanket
    /// identity conversion) satisfy an `E: Into<ConstValueError>` bound.
    ///
    /// `Infallible` has no values, so the empty `match` proves at compile time
    /// that this can never run. No runtime `unreachable!()` is needed.
    fn from(never: core::convert::Infallible) -> Self {
        match never {}
    }
}
489
490#[derive(Clone, Copy, Debug)]
491pub struct GlobalCtx<'a> {
492    pub types: &'a crate::UniqueArena<crate::Type>,
493    pub constants: &'a crate::Arena<crate::Constant>,
494    pub overrides: &'a crate::Arena<crate::Override>,
495    pub global_expressions: &'a crate::Arena<crate::Expression>,
496}
497
498impl GlobalCtx<'_> {
499    /// Try to evaluate the expression in `self.global_expressions` using its `handle`
500    /// and return it as a `T: TryFrom<ir::Literal>`.
501    ///
502    /// This currently only evaluates scalar expressions. If adding support for vectors,
503    /// consider changing `valid::expression::validate_constant_shift_amounts` to use that
504    /// support.
505    #[cfg_attr(
506        not(any(
507            feature = "glsl-in",
508            feature = "spv-in",
509            feature = "wgsl-in",
510            glsl_out,
511            hlsl_out,
512            msl_out,
513            wgsl_out
514        )),
515        allow(dead_code)
516    )]
517    pub(super) fn get_const_val<T, E>(
518        &self,
519        handle: crate::Handle<crate::Expression>,
520    ) -> Result<T, ConstValueError>
521    where
522        T: TryFrom<crate::Literal, Error = E>,
523        E: Into<ConstValueError>,
524    {
525        self.get_const_val_from(handle, self.global_expressions)
526    }
527
528    pub(super) fn get_const_val_from<T, E>(
529        &self,
530        handle: crate::Handle<crate::Expression>,
531        arena: &crate::Arena<crate::Expression>,
532    ) -> Result<T, ConstValueError>
533    where
534        T: TryFrom<crate::Literal, Error = E>,
535        E: Into<ConstValueError>,
536    {
537        fn get(
538            gctx: GlobalCtx,
539            handle: crate::Handle<crate::Expression>,
540            arena: &crate::Arena<crate::Expression>,
541        ) -> Option<crate::Literal> {
542            match arena[handle] {
543                crate::Expression::Literal(literal) => Some(literal),
544                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
545                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
546                    _ => None,
547                },
548                _ => None,
549            }
550        }
551        let value = match arena[handle] {
552            crate::Expression::Constant(c) => {
553                get(*self, self.constants[c].init, self.global_expressions)
554            }
555            _ => get(*self, handle, arena),
556        };
557        match value {
558            Some(v) => v.try_into().map_err(Into::into),
559            None => Err(ConstValueError::NonConst),
560        }
561    }
562
563    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
564        compare_types(lhs, rhs, self.types)
565    }
566}
567
568#[derive(Error, Debug, Clone, Copy, PartialEq)]
569pub enum ResolveArraySizeError {
570    #[error("array element count must be positive (> 0)")]
571    ExpectedPositiveArrayLength,
572    #[error("internal: array size override has not been resolved")]
573    NonConstArrayLength,
574}
575
576impl crate::ArraySize {
577    /// Return the number of elements that `size` represents, if known at code generation time.
578    ///
579    /// If `size` is override-based, return an error unless the override's
580    /// initializer is a fully evaluated constant expression. You can call
581    /// [`pipeline_constants::process_overrides`] to supply values for a
582    /// module's overrides and ensure their initializers are fully evaluated, as
583    /// this function expects.
584    ///
585    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
586    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
587        match *self {
588            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
589            crate::ArraySize::Pending(handle) => {
590                let Some(expr) = gctx.overrides[handle].init else {
591                    return Err(ResolveArraySizeError::NonConstArrayLength);
592                };
593                let length = gctx.get_const_val(expr).map_err(|err| match err {
594                    ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
595                    ConstValueError::Negative | ConstValueError::InvalidType => {
596                        ResolveArraySizeError::ExpectedPositiveArrayLength
597                    }
598                })?;
599
600                if length == 0 {
601                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
602                }
603
604                Ok(IndexableLength::Known(length))
605            }
606            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
607        }
608    }
609}
610
611/// Return an iterator over the individual components assembled by a
612/// `Compose` expression.
613///
614/// Given `ty` and `components` from an `Expression::Compose`, return an
615/// iterator over the components of the resulting value.
616///
617/// Normally, this would just be an iterator over `components`. However,
618/// `Compose` expressions can concatenate vectors, in which case the i'th
619/// value being composed is not generally the i'th element of `components`.
620/// This function consults `ty` to decide if this concatenation is occurring,
621/// and returns an iterator that produces the components of the result of
622/// the `Compose` expression in either case.
623pub fn flatten_compose<'arenas>(
624    ty: crate::Handle<crate::Type>,
625    components: &'arenas [crate::Handle<crate::Expression>],
626    expressions: &'arenas crate::Arena<crate::Expression>,
627    types: &'arenas crate::UniqueArena<crate::Type>,
628) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
629    // Returning `impl Iterator` is a bit tricky. We may or may not
630    // want to flatten the components, but we have to settle on a
631    // single concrete type to return. This function returns a single
632    // iterator chain that handles both the flattening and
633    // non-flattening cases.
634    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
635        (size as usize, true)
636    } else {
637        (components.len(), false)
638    };
639
640    /// Flatten `Compose` expressions if `is_vector` is true.
641    fn flatten_compose<'c>(
642        component: &'c crate::Handle<crate::Expression>,
643        is_vector: bool,
644        expressions: &'c crate::Arena<crate::Expression>,
645    ) -> &'c [crate::Handle<crate::Expression>] {
646        if is_vector {
647            if let crate::Expression::Compose {
648                ty: _,
649                components: ref subcomponents,
650            } = expressions[*component]
651            {
652                return subcomponents;
653            }
654        }
655        core::slice::from_ref(component)
656    }
657
658    /// Flatten `Splat` expressions if `is_vector` is true.
659    fn flatten_splat<'c>(
660        component: &'c crate::Handle<crate::Expression>,
661        is_vector: bool,
662        expressions: &'c crate::Arena<crate::Expression>,
663    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
664        let mut expr = *component;
665        let mut count = 1;
666        if is_vector {
667            if let crate::Expression::Splat { size, value } = expressions[expr] {
668                expr = value;
669                count = size as usize;
670            }
671        }
672        core::iter::repeat_n(expr, count)
673    }
674
675    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
676    // flatten up to two levels of `Compose` expressions.
677    //
678    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
679    // `Splat` expressions. Fortunately, the operand of a `Splat` must
680    // be a scalar, so we can stop there.
681    components
682        .iter()
683        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
684        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
685        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
686        .take(size)
687}
688
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    // Three columns of 16 bytes each: a `vec3<f32>` column is padded to 16 bytes.
    let expected_size = 48;
    assert_eq!(mat3x3_f32.size(module.to_ctx()), expected_size);
}
702
703impl crate::Module {
704    /// Extracts mesh shader info from a mesh output global variable. Used in frontends
705    /// and by validators. This only validates the output variable itself, and not the
706    /// vertex and primitive output types.
707    ///
708    /// The output contains the extracted mesh stage info, with overrides unset,
709    /// and then the overrides separately. This is because the overrides should be
710    /// treated as expressions elsewhere, but that requires mutably modifying the
711    /// module and the expressions should only be created at parse time, not validation
712    /// time.
713    #[allow(clippy::type_complexity)]
714    pub fn analyze_mesh_shader_info(
715        &self,
716        gv: crate::Handle<crate::GlobalVariable>,
717    ) -> (
718        crate::MeshStageInfo,
719        [Option<crate::Handle<crate::Override>>; 2],
720        Option<crate::WithSpan<crate::valid::EntryPointError>>,
721    ) {
722        use crate::span::AddSpan;
723        use crate::valid::EntryPointError;
724        #[derive(Default)]
725        struct OutError {
726            pub inner: Option<EntryPointError>,
727        }
728        impl OutError {
729            pub fn set(&mut self, err: EntryPointError) {
730                if self.inner.is_none() {
731                    self.inner = Some(err);
732                }
733            }
734        }
735
736        // Used to temporarily initialize stuff
737        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
738        let mut output = crate::MeshStageInfo {
739            topology: crate::MeshOutputTopology::Triangles,
740            max_vertices: 0,
741            max_vertices_override: None,
742            max_primitives: 0,
743            max_primitives_override: None,
744            vertex_output_type: null_type,
745            primitive_output_type: null_type,
746            output_variable: gv,
747        };
748        // Stores the error to output, if any.
749        let mut error = OutError::default();
750        let r#type = &self.types[self.global_variables[gv].ty].inner;
751
752        let mut topology = output.topology;
753        // Max, max override, type
754        let mut vertex_info = (0, None, null_type);
755        let mut primitive_info = (0, None, null_type);
756
757        match r#type {
758            &crate::TypeInner::Struct { ref members, .. } => {
759                let mut builtins = crate::FastHashSet::default();
760                for member in members {
761                    match member.binding {
762                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
763                            // Must have type u32
764                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
765                                error.set(EntryPointError::BadMeshOutputVariableField);
766                            }
767                            // Each builtin should only occur once
768                            if builtins.contains(&crate::BuiltIn::VertexCount) {
769                                error.set(EntryPointError::BadMeshOutputVariableType);
770                            }
771                            builtins.insert(crate::BuiltIn::VertexCount);
772                        }
773                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
774                            // Must have type u32
775                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
776                                error.set(EntryPointError::BadMeshOutputVariableField);
777                            }
778                            // Each builtin should only occur once
779                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
780                                error.set(EntryPointError::BadMeshOutputVariableType);
781                            }
782                            builtins.insert(crate::BuiltIn::PrimitiveCount);
783                        }
784                        Some(crate::Binding::BuiltIn(
785                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
786                        )) => {
787                            let ty = &self.types[member.ty].inner;
788                            // Analyze the array type to determine size and vertex/primitive type
789                            let (a, b, c) = match ty {
790                                &crate::TypeInner::Array { base, size, .. } => {
791                                    let ty = base;
792                                    let (max, max_override) = match size {
793                                        crate::ArraySize::Constant(a) => (a.get(), None),
794                                        crate::ArraySize::Pending(o) => (0, Some(o)),
795                                        crate::ArraySize::Dynamic => {
796                                            error.set(EntryPointError::BadMeshOutputVariableField);
797                                            (0, None)
798                                        }
799                                    };
800                                    (max, max_override, ty)
801                                }
802                                _ => {
803                                    error.set(EntryPointError::BadMeshOutputVariableField);
804                                    (0, None, null_type)
805                                }
806                            };
807                            if matches!(
808                                member.binding,
809                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
810                            ) {
811                                // Primitives require special analysis to determine topology
812                                primitive_info = (a, b, c);
813                                match self.types[c].inner {
814                                    crate::TypeInner::Struct { ref members, .. } => {
815                                        for member in members {
816                                            match member.binding {
817                                                Some(crate::Binding::BuiltIn(
818                                                    crate::BuiltIn::PointIndex,
819                                                )) => {
820                                                    topology = crate::MeshOutputTopology::Points;
821                                                }
822                                                Some(crate::Binding::BuiltIn(
823                                                    crate::BuiltIn::LineIndices,
824                                                )) => {
825                                                    topology = crate::MeshOutputTopology::Lines;
826                                                }
827                                                Some(crate::Binding::BuiltIn(
828                                                    crate::BuiltIn::TriangleIndices,
829                                                )) => {
830                                                    topology = crate::MeshOutputTopology::Triangles;
831                                                }
832                                                _ => (),
833                                            }
834                                        }
835                                    }
836                                    _ => (),
837                                }
838                                // Each builtin should only occur once
839                                if builtins.contains(&crate::BuiltIn::Primitives) {
840                                    error.set(EntryPointError::BadMeshOutputVariableType);
841                                }
842                                builtins.insert(crate::BuiltIn::Primitives);
843                            } else {
844                                vertex_info = (a, b, c);
845                                // Each builtin should only occur once
846                                if builtins.contains(&crate::BuiltIn::Vertices) {
847                                    error.set(EntryPointError::BadMeshOutputVariableType);
848                                }
849                                builtins.insert(crate::BuiltIn::Vertices);
850                            }
851                        }
852                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
853                    }
854                }
855                output = crate::MeshStageInfo {
856                    topology,
857                    max_vertices: vertex_info.0,
858                    max_vertices_override: None,
859                    vertex_output_type: vertex_info.2,
860                    max_primitives: primitive_info.0,
861                    max_primitives_override: None,
862                    primitive_output_type: primitive_info.2,
863                    ..output
864                }
865            }
866            _ => error.set(EntryPointError::BadMeshOutputVariableType),
867        }
868        (
869            output,
870            [vertex_info.1, primitive_info.1],
871            error
872                .inner
873                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
874        )
875    }
876
877    pub fn uses_mesh_shaders(&self) -> bool {
878        let binding_uses_mesh = |b: &crate::Binding| {
879            matches!(
880                b,
881                crate::Binding::BuiltIn(
882                    crate::BuiltIn::MeshTaskSize
883                        | crate::BuiltIn::CullPrimitive
884                        | crate::BuiltIn::PointIndex
885                        | crate::BuiltIn::LineIndices
886                        | crate::BuiltIn::TriangleIndices
887                        | crate::BuiltIn::VertexCount
888                        | crate::BuiltIn::Vertices
889                        | crate::BuiltIn::PrimitiveCount
890                        | crate::BuiltIn::Primitives,
891                ) | crate::Binding::Location {
892                    per_primitive: true,
893                    ..
894                }
895            )
896        };
897        for (_, ty) in self.types.iter() {
898            match ty.inner {
899                crate::TypeInner::Struct { ref members, .. } => {
900                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
901                        if binding_uses_mesh(binding) {
902                            return true;
903                        }
904                    }
905                }
906                _ => (),
907            }
908        }
909        for ep in &self.entry_points {
910            if matches!(
911                ep.stage,
912                crate::ShaderStage::Mesh | crate::ShaderStage::Task
913            ) {
914                return true;
915            }
916            for binding in ep
917                .function
918                .arguments
919                .iter()
920                .filter_map(|arg| arg.binding.as_ref())
921                .chain(
922                    ep.function
923                        .result
924                        .iter()
925                        .filter_map(|res| res.binding.as_ref()),
926                )
927            {
928                if binding_uses_mesh(binding) {
929                    return true;
930                }
931            }
932        }
933        if self
934            .global_variables
935            .iter()
936            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
937        {
938            return true;
939        }
940        false
941    }
942
943    pub fn uses_ray_tracing(&self, ep_index: Option<usize>) -> RayTracingUses {
944        let mut uses = RayTracingUses::default();
945        // Whether this uses ray tracing (unknown whether the usage is pipelines or ray queries).
946        let mut uses_ray_tracing = self.special_types.ray_desc.is_some();
947
948        uses.queries |= self.special_types.ray_intersection.is_some();
949
950        for (_, &crate::Type { ref inner, .. }) in self.types.iter() {
951            // Backends do not know whether these have vertex return - that is done by us
952            match *inner {
953                crate::TypeInner::AccelerationStructure { .. } => {
954                    uses_ray_tracing = true;
955                }
956                crate::TypeInner::RayQuery { .. } => uses.queries = true,
957                _ => {}
958            }
959        }
960
961        for (index, ep) in self.entry_points.iter().enumerate() {
962            if ep_index.is_some() && ep_index != Some(index) {
963                continue;
964            }
965
966            // if we have a ray tracing pipeline shader we are definitely using
967            // pipelines, otherwise, if we have a ray tracing type, we might
968            // be using it in the shader (which would require ray queries),
969            // so we should use queries.
970            if matches!(
971                ep.stage,
972                crate::ShaderStage::RayGeneration
973                    | crate::ShaderStage::AnyHit
974                    | crate::ShaderStage::ClosestHit
975                    | crate::ShaderStage::Miss
976            ) {
977                uses.pipelines = true;
978            } else {
979                uses.queries |= uses_ray_tracing;
980            }
981        }
982
983        uses
984    }
985}
986
/// Ray-tracing features detected in a module, as returned by
/// `uses_ray_tracing`. Both flags default to `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct RayTracingUses {
    /// Set when an inspected entry point is a ray-tracing pipeline stage
    /// (ray generation, any-hit, closest-hit or miss).
    pub pipelines: bool,
    /// Set when ray queries are, or may be, used.
    pub queries: bool,
}
992
993impl crate::MeshOutputTopology {
994    pub const fn to_builtin(self) -> crate::BuiltIn {
995        match self {
996            Self::Points => crate::BuiltIn::PointIndex,
997            Self::Lines => crate::BuiltIn::LineIndices,
998            Self::Triangles => crate::BuiltIn::TriangleIndices,
999        }
1000    }
1001}
1002
1003impl crate::AddressSpace {
1004    pub const fn is_workgroup_like(self) -> bool {
1005        matches!(self, Self::WorkGroup | Self::TaskPayload)
1006    }
1007}
1008
1009impl TryFrom<crate::ScalarKind> for nt::glsl::GlslScalarKind {
1010    type Error = ();
1011
1012    fn try_from(value: crate::ScalarKind) -> Result<Self, Self::Error> {
1013        Ok(match value {
1014            crate::ScalarKind::Sint => nt::glsl::GlslScalarKind::Sint,
1015            crate::ScalarKind::Uint => nt::glsl::GlslScalarKind::Uint,
1016            crate::ScalarKind::Float => nt::glsl::GlslScalarKind::Float,
1017            _ => return Err(()),
1018        })
1019    }
1020}
1021
1022impl From<crate::VectorSize> for nt::glsl::GlslVectorSize {
1023    fn from(val: crate::VectorSize) -> Self {
1024        match val {
1025            crate::VectorSize::Bi => nt::glsl::GlslVectorSize::Bi,
1026            crate::VectorSize::Tri => nt::glsl::GlslVectorSize::Tri,
1027            crate::VectorSize::Quad => nt::glsl::GlslVectorSize::Quad,
1028        }
1029    }
1030}
1031
1032impl TryFrom<crate::Scalar> for nt::glsl::GlslScalar {
1033    type Error = ();
1034
1035    fn try_from(value: crate::Scalar) -> Result<Self, Self::Error> {
1036        Ok(nt::glsl::GlslScalar {
1037            kind: value.kind.try_into()?,
1038            width: value.width,
1039        })
1040    }
1041}
1042
1043impl TryFrom<&crate::TypeInner> for nt::glsl::GlslUniformType {
1044    type Error = ();
1045    fn try_from(value: &crate::TypeInner) -> Result<Self, Self::Error> {
1046        match *value {
1047            crate::TypeInner::Scalar(scalar) => {
1048                Ok(nt::glsl::GlslUniformType::Scalar(scalar.try_into()?))
1049            }
1050            crate::TypeInner::Vector { size, scalar } => Ok(nt::glsl::GlslUniformType::Vector {
1051                size: size.into(),
1052                scalar: scalar.try_into()?,
1053            }),
1054            crate::TypeInner::Matrix {
1055                columns,
1056                rows,
1057                scalar,
1058            } => Ok(nt::glsl::GlslUniformType::Matrix {
1059                columns: columns.into(),
1060                rows: rows.into(),
1061                scalar: scalar.try_into()?,
1062            }),
1063            _ => Err(()),
1064        }
1065    }
1066}