naga/proc/
typifier.rs

1use alloc::{format, string::String};
2
3use thiserror::Error;
4
5use crate::{
6    arena::{Arena, Handle, UniqueArena},
7    common::ForDebugWithTypes,
8    ir,
9};
10
/// The result of computing an expression's type.
///
/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
/// the (Naga) type it ascribes to some expression.
///
/// You might expect such a function to simply return a `Handle<Type>`. However,
/// we want type resolution to be a read-only process, and that would limit the
/// possible results to types already present in the expression's associated
/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
/// not certain to be present.
///
/// So instead, type resolution returns a `TypeResolution` enum: either a
/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
/// free-floating [`TypeInner`]. This extends the range to cover anything that
/// can be represented with a `TypeInner` referring to the existing arena.
///
/// What sorts of expressions can have types not available in the arena?
///
/// -   An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
///     [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
///     and `Matrix` represent their element and column types implicitly, not
///     via a handle, there may not be a suitable type in the expression's
///     associated arena. Instead, resolving such an expression returns a
///     `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
///     `Vector`.
///
/// -   Similarly, the type of an [`Access`] or [`AccessIndex`] expression
///     applied to a *pointer to* a vector or matrix must produce a *pointer to*
///     a scalar or vector type. These cannot be represented with a
///     [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
///     arena, and as before, we cannot assume that a suitable scalar or vector
///     type is there. So we take things one step further and provide
///     [`TypeInner::ValuePointer`], specifically for the case of pointers to
///     scalars or vectors. This type fits in a `TypeInner` and is exactly
///     equivalent to a `Pointer` to a `Vector` or `Scalar`.
///
/// So, for example, the type of an `Access` expression applied to a value of type:
///
/// ```ignore
/// TypeInner::Matrix { columns, rows, scalar }
/// ```
///
/// is:
///
/// ```ignore
/// TypeResolution::Value(TypeInner::Vector {
///     size: rows,
///     scalar,
/// })
/// ```
///
/// and the type of an access to a pointer of address space `space` to such a
/// matrix is:
///
/// ```ignore
/// TypeResolution::Value(TypeInner::ValuePointer {
///     size: Some(rows),
///     scalar,
///     space,
/// })
/// ```
///
/// Note that the element scalar (its kind and width together) is carried over
/// unchanged from the matrix into the resulting vector or value pointer.
///
/// [`Handle`]: TypeResolution::Handle
/// [`Value`]: TypeResolution::Value
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
///
/// [`TypeInner`]: crate::TypeInner
/// [`Matrix`]: crate::TypeInner::Matrix
/// [`Pointer`]: crate::TypeInner::Pointer
/// [`Scalar`]: crate::TypeInner::Scalar
/// [`ValuePointer`]: crate::TypeInner::ValuePointer
/// [`Vector`]: crate::TypeInner::Vector
///
/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum TypeResolution {
    /// A type stored in the associated arena.
    Handle(Handle<crate::Type>),

    /// A free-floating [`TypeInner`], representing a type that may not be
    /// available in the associated arena. However, the `TypeInner` itself may
    /// contain `Handle<Type>` values referring to types from the arena.
    ///
    /// The inner type must only be one of the following variants:
    /// - TypeInner::Pointer
    /// - TypeInner::ValuePointer
    /// - TypeInner::Matrix (generated by matrix multiplication)
    /// - TypeInner::CooperativeMatrix (generated by multiplying two
    ///   cooperative matrices)
    /// - TypeInner::Vector
    /// - TypeInner::Scalar
    /// - TypeInner::Array (accepted by the `Clone` implementation;
    ///   NOTE(review): confirm which resolutions actually produce it)
    ///
    /// Cloning a `Value` that holds any other variant panics; see the
    /// `Clone` implementation for `TypeResolution`.
    ///
    /// [`TypeInner`]: crate::TypeInner
    Value(crate::TypeInner),
}
111
112impl TypeResolution {
113    pub const fn handle(&self) -> Option<Handle<crate::Type>> {
114        match *self {
115            Self::Handle(handle) => Some(handle),
116            Self::Value(_) => None,
117        }
118    }
119
120    pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
121        match *self {
122            Self::Handle(handle) => &arena[handle].inner,
123            Self::Value(ref inner) => inner,
124        }
125    }
126}
127
128// Clone is only implemented for numeric variants of `TypeInner`.
129impl Clone for TypeResolution {
130    fn clone(&self) -> Self {
131        use crate::TypeInner as Ti;
132        match *self {
133            Self::Handle(handle) => Self::Handle(handle),
134            Self::Value(ref v) => Self::Value(match *v {
135                Ti::Scalar(scalar) => Ti::Scalar(scalar),
136                Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
137                Ti::Matrix {
138                    rows,
139                    columns,
140                    scalar,
141                } => Ti::Matrix {
142                    rows,
143                    columns,
144                    scalar,
145                },
146                Ti::CooperativeMatrix {
147                    columns,
148                    rows,
149                    scalar,
150                    role,
151                } => Ti::CooperativeMatrix {
152                    columns,
153                    rows,
154                    scalar,
155                    role,
156                },
157                Ti::Pointer { base, space } => Ti::Pointer { base, space },
158                Ti::ValuePointer {
159                    size,
160                    scalar,
161                    space,
162                } => Ti::ValuePointer {
163                    size,
164                    scalar,
165                    space,
166                },
167                Ti::Array { base, size, stride } => Ti::Array { base, size, stride },
168                _ => unreachable!("Unexpected clone type: {:?}", v),
169            }),
170        }
171    }
172}
173
/// An error produced while resolving the type of an expression.
///
/// The `Handle<Expression>` carried by most variants identifies the *operand*
/// whose type was unsuitable, not the expression being resolved.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
    /// The `index` of an `AccessIndex` expression is not smaller than the
    /// number of vector components, matrix columns, or struct members of the
    /// type being indexed. `expr` is the expression that was indexed.
    #[error("Index {index} is out of bounds for expression {expr:?}")]
    OutOfBoundsIndex {
        expr: Handle<crate::Expression>,
        index: u32,
    },
    /// The type of `expr` cannot be indexed at all. `indexed` is `true` when
    /// the access was an `AccessIndex` expression and `false` when it was an
    /// `Access` expression.
    #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
    InvalidAccess {
        expr: Handle<crate::Expression>,
        indexed: bool,
    },
    /// The indexed expression was a pointer, but the type `ty` it points to
    /// cannot be indexed. `indexed` has the same meaning as in
    /// [`ResolveError::InvalidAccess`].
    #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
    InvalidSubAccess {
        ty: Handle<crate::Type>,
        indexed: bool,
    },
    /// An operand that must be a scalar is not one; reported, for example,
    /// for the `value` of a `Splat` expression.
    #[error("Invalid scalar {0:?}")]
    InvalidScalar(Handle<crate::Expression>),
    /// An operand that must be a vector is not one; reported, for example,
    /// for the `vector` of a `Swizzle` expression.
    #[error("Invalid vector {0:?}")]
    InvalidVector(Handle<crate::Expression>),
    /// An operand that must be a pointer or value pointer is not one;
    /// reported, for example, for the `pointer` of a `Load` expression.
    #[error("Invalid pointer {0:?}")]
    InvalidPointer(Handle<crate::Expression>),
    /// An operand that must be an image is not one; reported for the `image`
    /// of `ImageSample`, `ImageLoad`, and `ImageQuery` expressions.
    #[error("Invalid image {0:?}")]
    InvalidImage(Handle<crate::Expression>),
    /// A function with the given `name` is not defined.
    #[error("Function {name} not defined")]
    FunctionNotDefined { name: String },
    /// A function has no return type.
    // NOTE(review): presumably raised when resolving a call result; the use
    // site is not in the portion of `resolve` shown here -- confirm.
    #[error("Function without return type")]
    FunctionReturnsVoid,
    /// The operand types of an operator or relational function do not fit
    /// together. The string is a debug rendering of the offending operation.
    #[error("Incompatible operands: {0}")]
    IncompatibleOperands(String),
    /// A `FunctionArgument` expression refers to an argument index that is
    /// not present in the context's argument list.
    #[error("Function argument {0} doesn't exist")]
    FunctionArgumentNotFound(u32),
    /// A required special type is missing from the module; converted from
    /// `crate::proc::MissingSpecialType`.
    #[error("Special type is not registered within the module")]
    MissingSpecialType,
    /// No overload of a math builtin accepts the given argument types. The
    /// string is the debug rendering of the builtin.
    #[error("Call to builtin {0} has incorrect or ambiguous arguments")]
    BuiltinArgumentsInvalid(String),
}
212
213impl From<crate::proc::MissingSpecialType> for ResolveError {
214    fn from(_unit_struct: crate::proc::MissingSpecialType) -> Self {
215        ResolveError::MissingSpecialType
216    }
217}
218
/// Borrowed, read-only views of everything [`ResolveContext::resolve`] consults
/// when computing the type of an expression.
///
/// All fields are shared references, so building a context never copies or
/// mutates module data.
#[expect(missing_debug_implementations, reason = "would be way too verbose?")]
pub struct ResolveContext<'a> {
    /// Constants; supplies the type of `Expression::Constant`.
    pub constants: &'a Arena<crate::Constant>,
    /// Overrides; supplies the type of `Expression::Override`.
    pub overrides: &'a Arena<crate::Override>,
    /// The type arena that every `Handle<Type>` seen during resolution
    /// refers to.
    pub types: &'a UniqueArena<crate::Type>,
    /// Special types registered in the module; consulted when resolving
    /// `Expression::Math`.
    pub special_types: &'a crate::SpecialTypes,
    /// Global variables; supplies the type and address space for
    /// `Expression::GlobalVariable`.
    pub global_vars: &'a Arena<crate::GlobalVariable>,
    /// Local variables of the function being processed; supplies the type
    /// for `Expression::LocalVariable`.
    pub local_vars: &'a Arena<crate::LocalVariable>,
    /// The functions arena.
    pub functions: &'a Arena<crate::Function>,
    /// Arguments of the function being processed; indexed by
    /// `Expression::FunctionArgument`.
    pub arguments: &'a [crate::FunctionArgument],
}
230
231impl<'a> ResolveContext<'a> {
232    /// Initialize a resolve context from the module.
233    pub const fn with_locals(
234        module: &'a crate::Module,
235        local_vars: &'a Arena<crate::LocalVariable>,
236        arguments: &'a [crate::FunctionArgument],
237    ) -> Self {
238        Self {
239            constants: &module.constants,
240            overrides: &module.overrides,
241            types: &module.types,
242            special_types: &module.special_types,
243            global_vars: &module.global_variables,
244            local_vars,
245            functions: &module.functions,
246            arguments,
247        }
248    }
249
250    /// Determine the type of `expr`.
251    ///
252    /// The `past` argument must be a closure that can resolve the types of any
253    /// expressions that `expr` refers to. These can be gathered by caching the
254    /// results of prior calls to `resolve`, perhaps as done by the
255    /// [`front::Typifier`] utility type.
256    ///
257    /// Type resolution is a read-only process: this method takes `self` by
258    /// shared reference. However, this means that we cannot add anything to
259    /// `self.types` that we might need to describe `expr`. To work around this,
260    /// this method returns a [`TypeResolution`], rather than simply returning a
261    /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
262    /// details.
263    ///
264    /// [`front::Typifier`]: crate::front::Typifier
265    pub fn resolve(
266        &self,
267        expr: &crate::Expression,
268        past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
269    ) -> Result<TypeResolution, ResolveError> {
270        use crate::TypeInner as Ti;
271        let types = self.types;
272        Ok(match *expr {
273            crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
274                // Arrays and matrices can only be indexed dynamically behind a
275                // pointer, but that's a validation error, not a type error, so
276                // go ahead provide a type here.
277                Ti::Array { base, .. } => TypeResolution::Handle(base),
278                Ti::Matrix { rows, scalar, .. } => {
279                    TypeResolution::Value(Ti::Vector { size: rows, scalar })
280                }
281                Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
282                Ti::ValuePointer {
283                    size: Some(_),
284                    scalar,
285                    space,
286                } => TypeResolution::Value(Ti::ValuePointer {
287                    size: None,
288                    scalar,
289                    space,
290                }),
291                Ti::Pointer { base, space } => {
292                    TypeResolution::Value(match types[base].inner {
293                        Ti::Array { base, .. } => Ti::Pointer { base, space },
294                        Ti::Vector { size: _, scalar } => Ti::ValuePointer {
295                            size: None,
296                            scalar,
297                            space,
298                        },
299                        // Matrices are only dynamically indexed behind a pointer
300                        Ti::Matrix {
301                            columns: _,
302                            rows,
303                            scalar,
304                        } => Ti::ValuePointer {
305                            size: Some(rows),
306                            scalar,
307                            space,
308                        },
309                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
310                        ref other => {
311                            log::error!("Access sub-type {other:?}");
312                            return Err(ResolveError::InvalidSubAccess {
313                                ty: base,
314                                indexed: false,
315                            });
316                        }
317                    })
318                }
319                Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
320                ref other => {
321                    log::error!("Access type {other:?}");
322                    return Err(ResolveError::InvalidAccess {
323                        expr: base,
324                        indexed: false,
325                    });
326                }
327            },
328            crate::Expression::AccessIndex { base, index } => {
329                match *past(base)?.inner_with(types) {
330                    Ti::Vector { size, scalar } => {
331                        if index >= size as u32 {
332                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
333                        }
334                        TypeResolution::Value(Ti::Scalar(scalar))
335                    }
336                    Ti::Matrix {
337                        columns,
338                        rows,
339                        scalar,
340                    } => {
341                        if index >= columns as u32 {
342                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
343                        }
344                        TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
345                    }
346                    Ti::Array { base, .. } => TypeResolution::Handle(base),
347                    Ti::Struct { ref members, .. } => {
348                        let member = members
349                            .get(index as usize)
350                            .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
351                        TypeResolution::Handle(member.ty)
352                    }
353                    Ti::ValuePointer {
354                        size: Some(size),
355                        scalar,
356                        space,
357                    } => {
358                        if index >= size as u32 {
359                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
360                        }
361                        TypeResolution::Value(Ti::ValuePointer {
362                            size: None,
363                            scalar,
364                            space,
365                        })
366                    }
367                    Ti::Pointer {
368                        base: ty_base,
369                        space,
370                    } => TypeResolution::Value(match types[ty_base].inner {
371                        Ti::Array { base, .. } => Ti::Pointer { base, space },
372                        Ti::Vector { size, scalar } => {
373                            if index >= size as u32 {
374                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
375                            }
376                            Ti::ValuePointer {
377                                size: None,
378                                scalar,
379                                space,
380                            }
381                        }
382                        Ti::Matrix {
383                            rows,
384                            columns,
385                            scalar,
386                        } => {
387                            if index >= columns as u32 {
388                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
389                            }
390                            Ti::ValuePointer {
391                                size: Some(rows),
392                                scalar,
393                                space,
394                            }
395                        }
396                        Ti::Struct { ref members, .. } => {
397                            let member = members
398                                .get(index as usize)
399                                .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
400                            Ti::Pointer {
401                                base: member.ty,
402                                space,
403                            }
404                        }
405                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
406                        ref other => {
407                            log::error!("Access index sub-type {other:?}");
408                            return Err(ResolveError::InvalidSubAccess {
409                                ty: ty_base,
410                                indexed: true,
411                            });
412                        }
413                    }),
414                    Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
415                    ref other => {
416                        log::error!("Access index type {other:?}");
417                        return Err(ResolveError::InvalidAccess {
418                            expr: base,
419                            indexed: true,
420                        });
421                    }
422                }
423            }
424            crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
425                Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
426                ref other => {
427                    log::error!("Scalar type {other:?}");
428                    return Err(ResolveError::InvalidScalar(value));
429                }
430            },
431            crate::Expression::Swizzle {
432                size,
433                vector,
434                pattern: _,
435            } => match *past(vector)?.inner_with(types) {
436                Ti::Vector { size: _, scalar } => {
437                    TypeResolution::Value(Ti::Vector { size, scalar })
438                }
439                ref other => {
440                    log::error!("Vector type {other:?}");
441                    return Err(ResolveError::InvalidVector(vector));
442                }
443            },
444            crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
445            crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
446            crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
447            crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
448            crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
449            crate::Expression::FunctionArgument(index) => {
450                let arg = self
451                    .arguments
452                    .get(index as usize)
453                    .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
454                TypeResolution::Handle(arg.ty)
455            }
456            crate::Expression::GlobalVariable(h) => {
457                let var = &self.global_vars[h];
458                if var.space == crate::AddressSpace::Handle {
459                    TypeResolution::Handle(var.ty)
460                } else {
461                    TypeResolution::Value(Ti::Pointer {
462                        base: var.ty,
463                        space: var.space,
464                    })
465                }
466            }
467            crate::Expression::LocalVariable(h) => {
468                let var = &self.local_vars[h];
469                TypeResolution::Value(Ti::Pointer {
470                    base: var.ty,
471                    space: crate::AddressSpace::Function,
472                })
473            }
474            crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
475                Ti::Pointer { base, space: _ } => {
476                    if let Ti::Atomic(scalar) = types[base].inner {
477                        TypeResolution::Value(Ti::Scalar(scalar))
478                    } else {
479                        TypeResolution::Handle(base)
480                    }
481                }
482                Ti::ValuePointer {
483                    size,
484                    scalar,
485                    space: _,
486                } => TypeResolution::Value(match size {
487                    Some(size) => Ti::Vector { size, scalar },
488                    None => Ti::Scalar(scalar),
489                }),
490                ref other => {
491                    log::error!("Pointer {pointer:?} type {other:?}");
492                    return Err(ResolveError::InvalidPointer(pointer));
493                }
494            },
495            crate::Expression::ImageSample {
496                image,
497                gather: Some(_),
498                ..
499            } => match *past(image)?.inner_with(types) {
500                Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
501                    scalar: crate::Scalar {
502                        kind: match class {
503                            crate::ImageClass::Sampled { kind, multi: _ } => kind,
504                            _ => crate::ScalarKind::Float,
505                        },
506                        width: 4,
507                    },
508                    size: crate::VectorSize::Quad,
509                }),
510                ref other => {
511                    log::error!("Image type {other:?}");
512                    return Err(ResolveError::InvalidImage(image));
513                }
514            },
515            crate::Expression::ImageSample { image, .. }
516            | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
517                Ti::Image { class, .. } => TypeResolution::Value(match class {
518                    crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
519                    crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
520                        scalar: crate::Scalar { kind, width: 4 },
521                        size: crate::VectorSize::Quad,
522                    },
523                    crate::ImageClass::Storage { format, .. } => Ti::Vector {
524                        scalar: format.into(),
525                        size: crate::VectorSize::Quad,
526                    },
527                    crate::ImageClass::External => Ti::Vector {
528                        scalar: crate::Scalar::F32,
529                        size: crate::VectorSize::Quad,
530                    },
531                }),
532                ref other => {
533                    log::error!("Image type {other:?}");
534                    return Err(ResolveError::InvalidImage(image));
535                }
536            },
537            crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
538                crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
539                    Ti::Image { dim, .. } => match dim {
540                        crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
541                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
542                            size: crate::VectorSize::Bi,
543                            scalar: crate::Scalar::U32,
544                        },
545                        crate::ImageDimension::D3 => Ti::Vector {
546                            size: crate::VectorSize::Tri,
547                            scalar: crate::Scalar::U32,
548                        },
549                    },
550                    ref other => {
551                        log::error!("Image type {other:?}");
552                        return Err(ResolveError::InvalidImage(image));
553                    }
554                },
555                crate::ImageQuery::NumLevels
556                | crate::ImageQuery::NumLayers
557                | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
558            }),
559            crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
560            crate::Expression::Binary { op, left, right } => match op {
561                crate::BinaryOperator::Add
562                | crate::BinaryOperator::Subtract
563                | crate::BinaryOperator::Divide
564                | crate::BinaryOperator::Modulo => past(left)?.clone(),
565                crate::BinaryOperator::Multiply => {
566                    let (res_left, res_right) = (past(left)?, past(right)?);
567                    match (res_left.inner_with(types), res_right.inner_with(types)) {
568                        (
569                            &Ti::Matrix {
570                                columns: _,
571                                rows,
572                                scalar,
573                            },
574                            &Ti::Matrix { columns, .. },
575                        ) => TypeResolution::Value(Ti::Matrix {
576                            columns,
577                            rows,
578                            scalar,
579                        }),
580                        (
581                            &Ti::Matrix {
582                                columns: _,
583                                rows,
584                                scalar,
585                            },
586                            &Ti::Vector { .. },
587                        ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
588                        (
589                            &Ti::Vector { .. },
590                            &Ti::Matrix {
591                                columns,
592                                rows: _,
593                                scalar,
594                            },
595                        ) => TypeResolution::Value(Ti::Vector {
596                            size: columns,
597                            scalar,
598                        }),
599                        (&Ti::Scalar { .. }, _) => res_right.clone(),
600                        (_, &Ti::Scalar { .. }) => res_left.clone(),
601                        (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
602                        (
603                            &Ti::CooperativeMatrix {
604                                columns: _,
605                                rows,
606                                scalar,
607                                role,
608                            },
609                            &Ti::CooperativeMatrix { columns, .. },
610                        ) => TypeResolution::Value(Ti::CooperativeMatrix {
611                            columns,
612                            rows,
613                            scalar,
614                            role,
615                        }),
616                        (tl, tr) => {
617                            return Err(ResolveError::IncompatibleOperands(format!(
618                                "{tl:?} * {tr:?}"
619                            )))
620                        }
621                    }
622                }
623                crate::BinaryOperator::Equal
624                | crate::BinaryOperator::NotEqual
625                | crate::BinaryOperator::Less
626                | crate::BinaryOperator::LessEqual
627                | crate::BinaryOperator::Greater
628                | crate::BinaryOperator::GreaterEqual => {
629                    // These accept scalars or vectors.
630                    let scalar = crate::Scalar::BOOL;
631                    let inner = match *past(left)?.inner_with(types) {
632                        Ti::Scalar { .. } => Ti::Scalar(scalar),
633                        Ti::Vector { size, .. } => Ti::Vector { size, scalar },
634                        ref other => {
635                            return Err(ResolveError::IncompatibleOperands(format!(
636                                "{op:?}({other:?}, _)"
637                            )))
638                        }
639                    };
640                    TypeResolution::Value(inner)
641                }
642                crate::BinaryOperator::LogicalAnd | crate::BinaryOperator::LogicalOr => {
643                    // These accept scalars only.
644                    let bool = Ti::Scalar(crate::Scalar::BOOL);
645                    let ty = past(left)?.inner_with(types);
646                    if *ty == bool {
647                        TypeResolution::Value(bool)
648                    } else {
649                        return Err(ResolveError::IncompatibleOperands(format!(
650                            "{op:?}({:?}, _)",
651                            ty.for_debug(types),
652                        )));
653                    }
654                }
655                crate::BinaryOperator::And
656                | crate::BinaryOperator::ExclusiveOr
657                | crate::BinaryOperator::InclusiveOr
658                | crate::BinaryOperator::ShiftLeft
659                | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
660            },
661            crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
662            crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
663            crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
664            crate::Expression::Select { accept, .. } => past(accept)?.clone(),
665            crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
666            crate::Expression::Relational { fun, argument } => match fun {
667                crate::RelationalFunction::All | crate::RelationalFunction::Any => {
668                    TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
669                }
670                crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
671                    match *past(argument)?.inner_with(types) {
672                        Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
673                        Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
674                            scalar: crate::Scalar::BOOL,
675                            size,
676                        }),
677                        ref other => {
678                            return Err(ResolveError::IncompatibleOperands(format!(
679                                "{fun:?}({other:?})"
680                            )))
681                        }
682                    }
683                }
684            },
685            crate::Expression::Math {
686                fun,
687                arg,
688                arg1,
689                arg2: _,
690                arg3: _,
691            } => {
692                use crate::proc::OverloadSet as _;
693
694                let mut overloads = fun.overloads();
695                log::debug!(
696                    "initial overloads for {fun:?}, {:#?}",
697                    overloads.for_debug(types)
698                );
699
700                // If any argument is not a constant expression, then no
701                // overloads that accept abstract values should be considered.
702                // `OverloadSet::concrete_only` is supposed to help impose this
703                // restriction. However, no `MathFunction` accepts a mix of
704                // abstract and concrete arguments, so we don't need to worry
705                // about that here.
706
707                let res_arg = past(arg)?;
708                overloads = overloads.arg(0, res_arg.inner_with(types), types);
709                log::debug!(
710                    "overloads after arg 0 of type {:?}: {:#?}",
711                    res_arg.for_debug(types),
712                    overloads.for_debug(types)
713                );
714
715                if let Some(arg1) = arg1 {
716                    let res_arg1 = past(arg1)?;
717                    overloads = overloads.arg(1, res_arg1.inner_with(types), types);
718                    log::debug!(
719                        "overloads after arg 1 of type {:?}: {:#?}",
720                        res_arg1.for_debug(types),
721                        overloads.for_debug(types)
722                    );
723                }
724
725                if overloads.is_empty() {
726                    return Err(ResolveError::BuiltinArgumentsInvalid(format!("{fun:?}")));
727                }
728
729                let rule = overloads.most_preferred();
730
731                rule.conclusion.into_resolution(self.special_types)?
732            }
733            crate::Expression::As {
734                expr,
735                kind,
736                convert,
737            } => match *past(expr)?.inner_with(types) {
738                Ti::Scalar(crate::Scalar { width, .. }) => {
739                    TypeResolution::Value(Ti::Scalar(crate::Scalar {
740                        kind,
741                        width: convert.unwrap_or(width),
742                    }))
743                }
744                Ti::Vector {
745                    size,
746                    scalar: crate::Scalar { kind: _, width },
747                } => TypeResolution::Value(Ti::Vector {
748                    size,
749                    scalar: crate::Scalar {
750                        kind,
751                        width: convert.unwrap_or(width),
752                    },
753                }),
754                Ti::Matrix {
755                    columns,
756                    rows,
757                    mut scalar,
758                } => {
759                    if let Some(width) = convert {
760                        scalar.width = width;
761                    }
762                    TypeResolution::Value(Ti::Matrix {
763                        columns,
764                        rows,
765                        scalar,
766                    })
767                }
768                ref other => {
769                    return Err(ResolveError::IncompatibleOperands(format!(
770                        "{other:?} as {kind:?}"
771                    )))
772                }
773            },
774            crate::Expression::CallResult(function) => {
775                let result = self.functions[function]
776                    .result
777                    .as_ref()
778                    .ok_or(ResolveError::FunctionReturnsVoid)?;
779                TypeResolution::Handle(result.ty)
780            }
781            crate::Expression::ArrayLength(_) => {
782                TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
783            }
784            crate::Expression::RayQueryProceedResult => {
785                TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
786            }
787            crate::Expression::RayQueryGetIntersection { .. } => {
788                let result = self
789                    .special_types
790                    .ray_intersection
791                    .ok_or(ResolveError::MissingSpecialType)?;
792                TypeResolution::Handle(result)
793            }
794            crate::Expression::RayQueryVertexPositions { .. } => {
795                let result = self
796                    .special_types
797                    .ray_vertex_return
798                    .ok_or(ResolveError::MissingSpecialType)?;
799                TypeResolution::Handle(result)
800            }
801            crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
802                scalar: crate::Scalar::U32,
803                size: crate::VectorSize::Quad,
804            }),
805            crate::Expression::CooperativeLoad {
806                columns,
807                rows,
808                role,
809                ref data,
810            } => {
811                let scalar = past(data.pointer)?
812                    .inner_with(types)
813                    .pointer_base_type()
814                    .and_then(|tr| tr.inner_with(types).scalar())
815                    .ok_or(ResolveError::InvalidPointer(data.pointer))?;
816                TypeResolution::Value(Ti::CooperativeMatrix {
817                    columns,
818                    rows,
819                    scalar,
820                    role,
821                })
822            }
823            crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(),
824        })
825    }
826}
827
828/// Compare two types.
829///
830/// This is the most general way of comparing two types, as it can distinguish
831/// two structs with different names but the same members. For other ways, see
832/// [`TypeInner::non_struct_equivalent`] and [`TypeInner::eq`].
833///
834/// In Naga code, this is usually called via the like-named methods on [`Module`],
835/// [`GlobalCtx`], and `BlockContext`.
836///
837/// [`TypeInner::non_struct_equivalent`]: crate::ir::TypeInner::non_struct_equivalent
838/// [`TypeInner::eq`]: crate::ir::TypeInner
839/// [`Module`]: crate::ir::Module
840/// [`GlobalCtx`]: crate::proc::GlobalCtx
841pub fn compare_types(
842    lhs: &TypeResolution,
843    rhs: &TypeResolution,
844    types: &UniqueArena<crate::Type>,
845) -> bool {
846    match lhs {
847        &TypeResolution::Handle(lhs_handle)
848            if matches!(
849                types[lhs_handle],
850                ir::Type {
851                    inner: ir::TypeInner::Struct { .. },
852                    ..
853                }
854            ) =>
855        {
856            // Structs can only be in the arena, not in a TypeResolution::Value
857            rhs.handle()
858                .is_some_and(|rhs_handle| lhs_handle == rhs_handle)
859        }
860        _ => lhs
861            .inner_with(types)
862            .non_struct_equivalent(rhs.inner_with(types), types),
863    }
864}
865
/// Pins the in-memory size of `ResolveError` at 32 bytes, so that any change
/// to the error type that alters its size fails this test.
#[test]
fn test_error_size() {
    const EXPECTED_BYTES: usize = 32;
    let actual_bytes = size_of::<ResolveError>();
    assert_eq!(actual_bytes, EXPECTED_BYTES);
}