naga/proc/overloads/
utils.rs

//! Utility functions for constructing [`List`] overload sets.
//!
//! [`List`]: crate::proc::overloads::list::List
4
5use crate::ir;
6use crate::proc::overloads::list::List;
7use crate::proc::overloads::rule::{Conclusion, Rule};
8use crate::proc::TypeResolution;
9
10use alloc::vec::Vec;
11
12/// Produce all the floating-point [`ir::Scalar`]s.
13///
14/// Note that `F32` must appear before other sizes; this is how we
15/// represent conversion rank.
16pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
17    [
18        ir::Scalar::ABSTRACT_FLOAT,
19        ir::Scalar::F32,
20        ir::Scalar::F16,
21        ir::Scalar::F64,
22    ]
23    .into_iter()
24}
25
26/// Produce all the floating-point [`ir::Scalar`]s, but omit
27/// abstract types, for #7405.
28pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {
29    [ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64].into_iter()
30}
31
32/// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as
33/// their scalar.
34pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
35    [
36        ir::TypeInner::Scalar(scalar),
37        ir::TypeInner::Vector {
38            size: ir::VectorSize::Bi,
39            scalar,
40        },
41        ir::TypeInner::Vector {
42            size: ir::VectorSize::Tri,
43            scalar,
44        },
45        ir::TypeInner::Vector {
46            size: ir::VectorSize::Quad,
47            scalar,
48        },
49    ]
50    .into_iter()
51}
52
53/// Construct a [`Rule`] for an operation with the given
54/// argument types and return type.
55pub fn rule<const N: usize>(args: [ir::TypeInner; N], ret: ir::TypeInner) -> Rule {
56    Rule {
57        arguments: Vec::from_iter(args.into_iter().map(TypeResolution::Value)),
58        conclusion: Conclusion::Value(ret),
59    }
60}
61
62/// Construct a [`List`] from the given rules.
63pub fn list(rules: impl Iterator<Item = Rule>) -> List {
64    List::from_rules(rules.collect())
65}
66
/// Return the cartesian product of two iterators.
///
/// Every item of `left` is paired with every item of `right`, with `left`
/// varying slowest: `pairs([a, b], [x, y])` yields
/// `(a, x), (a, y), (b, x), (b, y)`. If either input is empty, so is the
/// result.
///
/// `right` is cloned once per `left` item so it can be walked again from the
/// start. Each `left` item is cloned into every pair it appears in.
pub fn pairs<T: Clone, U>(
    left: impl Iterator<Item = T>,
    right: impl Iterator<Item = U> + Clone,
) -> impl Iterator<Item = (T, U)> {
    let pair_with_every_right = move |first: T| {
        let seconds = right.clone();
        seconds.map(move |second| (first.clone(), second))
    };
    left.flat_map(pair_with_every_right)
}
74
/// Return the cartesian product of three iterators.
///
/// Items are produced with `left` varying slowest and `right` fastest, like
/// nested loops `for t in left { for u in middle { for v in right { .. } } }`.
/// If any input is empty, so is the result.
///
/// `middle` is cloned once per `left` item, and `right` once per
/// `(left, middle)` combination, so each can be re-walked. The `left` and
/// `middle` items are cloned into every triple they appear in.
pub fn triples<T: Clone, U: Clone, V>(
    left: impl Iterator<Item = T>,
    middle: impl Iterator<Item = U> + Clone,
    right: impl Iterator<Item = V> + Clone,
) -> impl Iterator<Item = (T, U, V)> {
    left.flat_map(move |first| {
        let inner_right = right.clone();
        let combine_with_second = move |second: U| {
            let first = first.clone();
            inner_right
                .clone()
                .map(move |third| (first.clone(), second.clone(), third))
        };
        middle.clone().flat_map(combine_with_second)
    })
}