naga/proc/overloads/
utils.rs

1//! Utility functions for constructing [`List`] overload sets.
2//!
3//! [`List`]: crate::proc::overloads::list::List
4
5use crate::ir;
6use crate::proc::overloads::list::List;
7use crate::proc::overloads::rule::{Conclusion, Rule};
8use crate::proc::TypeResolution;
9
10use alloc::vec::Vec;
11
12/// Produce all vector sizes.
13pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
14    static SIZES: [ir::VectorSize; 3] = [
15        ir::VectorSize::Bi,
16        ir::VectorSize::Tri,
17        ir::VectorSize::Quad,
18    ];
19
20    SIZES.iter().cloned()
21}
22
23/// Produce all the floating-point [`ir::Scalar`]s.
24///
25/// Note that `F32` must appear before other sizes; this is how we
26/// represent conversion rank.
27pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
28    [
29        ir::Scalar::ABSTRACT_FLOAT,
30        ir::Scalar::F32,
31        ir::Scalar::F16,
32        ir::Scalar::F64,
33    ]
34    .into_iter()
35}
36
37/// Produce all the floating-point [`ir::Scalar`]s, but omit
38/// abstract types, for #7405.
39pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {
40    [ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64].into_iter()
41}
42
43/// Produce all concrete integer [`ir::Scalar`]s.
44///
45/// Note that `I32` and `U32` must come first; this is how we
46/// represent conversion rank.
47pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
48    [
49        ir::Scalar::I32,
50        ir::Scalar::U32,
51        ir::Scalar::I64,
52        ir::Scalar::U64,
53    ]
54    .into_iter()
55}
56
57/// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as
58/// their scalar.
59pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
60    [
61        ir::TypeInner::Scalar(scalar),
62        ir::TypeInner::Vector {
63            size: ir::VectorSize::Bi,
64            scalar,
65        },
66        ir::TypeInner::Vector {
67            size: ir::VectorSize::Tri,
68            scalar,
69        },
70        ir::TypeInner::Vector {
71            size: ir::VectorSize::Quad,
72            scalar,
73        },
74    ]
75    .into_iter()
76}
77
78/// Construct a [`Rule`] for an operation with the given
79/// argument types and return type.
80pub fn rule<const N: usize>(args: [ir::TypeInner; N], ret: ir::TypeInner) -> Rule {
81    Rule {
82        arguments: Vec::from_iter(args.into_iter().map(TypeResolution::Value)),
83        conclusion: Conclusion::Value(ret),
84    }
85}
86
87/// Construct a [`List`] from the given rules.
88pub fn list(rules: impl Iterator<Item = Rule>) -> List {
89    List::from_rules(rules.collect())
90}
91
/// Return the cartesian product of two iterators.
///
/// Pairs come out in row-major order: for each item of `left`, every item
/// of `right` in turn. `right` is cloned once per `left` item so it can be
/// replayed, and each `left` item is cloned once per pair it appears in.
pub fn pairs<T: Clone, U>(
    left: impl Iterator<Item = T>,
    right: impl Iterator<Item = U> + Clone,
) -> impl Iterator<Item = (T, U)> {
    left.flat_map(move |first| {
        let seconds = right.clone();
        seconds.map(move |second| (first.clone(), second))
    })
}
99
/// Return the cartesian product of three iterators.
///
/// Triples come out in lexicographic order: `left` varies slowest and
/// `right` fastest. `middle` is replayed (cloned) once per `left` item, and
/// `right` is replayed once per (`left`, `middle`) combination.
pub fn triples<T: Clone, U: Clone, V>(
    left: impl Iterator<Item = T>,
    middle: impl Iterator<Item = U> + Clone,
    right: impl Iterator<Item = V> + Clone,
) -> impl Iterator<Item = (T, U, V)> {
    left.flat_map(move |outer| {
        let right_for_outer = right.clone();
        let middles = middle.clone();
        middles.flat_map(move |mid| {
            // The innermost closure must own its own copy of `outer`,
            // because this closure runs once per `mid`.
            let outer_here = outer.clone();
            right_for_outer
                .clone()
                .map(move |last| (outer_here.clone(), mid.clone(), last))
        })
    })
}