naga/proc/overloads/
utils.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
//! Utility functions for constructing [`List`] overload sets.
//!
//! [`List`]: crate::proc::overloads::list::List

use crate::ir;
use crate::proc::overloads::list::List;
use crate::proc::overloads::rule::{Conclusion, Rule};
use crate::proc::TypeResolution;

use alloc::vec::Vec;

/// Iterate over every [`ir::VectorSize`], in the order `Bi`, `Tri`, `Quad`.
///
/// The returned iterator is `Clone`, so it can be replayed by adaptors
/// such as [`pairs`] and [`triples`].
pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
    [
        ir::VectorSize::Bi,
        ir::VectorSize::Tri,
        ir::VectorSize::Quad,
    ]
    .into_iter()
}

/// Iterate over every floating-point [`ir::Scalar`], abstract float included.
///
/// The order is significant: `F32` must precede the other concrete
/// widths, because position in this sequence is how conversion rank
/// is represented.
pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
    let ranked = [
        ir::Scalar::ABSTRACT_FLOAT,
        ir::Scalar::F32,
        ir::Scalar::F16,
        ir::Scalar::F64,
    ];
    ranked.into_iter()
}

/// Iterate over the concrete floating-point [`ir::Scalar`]s only.
///
/// This is [`float_scalars`] without the abstract float type, for
/// #7405. `F32` still comes first.
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {
    let concrete = [ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64];
    concrete.into_iter()
}

/// Produce all concrete integer [`ir::Scalar`]s.
///
/// Note that `I32` and `U32` must come first; this is how we
/// represent conversion rank.
///
/// Like the other scalar enumerators in this module, the returned
/// iterator is `Clone`, so it can be passed as the repeated (`right` or
/// `middle`) argument of [`pairs`] and [`triples`].
pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
    [
        ir::Scalar::I32,
        ir::Scalar::U32,
        ir::Scalar::I64,
        ir::Scalar::U64,
    ]
    .into_iter()
}

/// Produce the scalar [`ir::TypeInner`] for `scalar`, followed by the
/// vector [`ir::TypeInner`]s of size `Bi`, `Tri`, and `Quad` that have
/// `scalar` as their component type.
pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
    // Build a vector type of the requested size over `scalar`.
    let vector_of = |size: ir::VectorSize| ir::TypeInner::Vector { size, scalar };

    [
        ir::TypeInner::Scalar(scalar),
        vector_of(ir::VectorSize::Bi),
        vector_of(ir::VectorSize::Tri),
        vector_of(ir::VectorSize::Quad),
    ]
    .into_iter()
}

/// Build a [`Rule`] describing one overload of an operation.
///
/// Each entry of `args` becomes an argument of the rule, wrapped in
/// [`TypeResolution::Value`]; `ret` becomes the rule's conclusion,
/// wrapped in [`Conclusion::Value`].
pub fn rule<const N: usize>(args: [ir::TypeInner; N], ret: ir::TypeInner) -> Rule {
    let arguments: Vec<_> = args.into_iter().map(TypeResolution::Value).collect();
    let conclusion = Conclusion::Value(ret);

    Rule {
        arguments,
        conclusion,
    }
}

/// Build a [`List`] overload set out of `rules`.
///
/// The iterator is drained and the collected rules are passed to
/// [`List::from_rules`].
pub fn list(rules: impl Iterator<Item = Rule>) -> List {
    let collected = rules.collect();
    List::from_rules(collected)
}

/// Return the cartesian product of two iterators.
///
/// Pairs are produced in left-major order: every item of `right` is
/// paired with the first item of `left`, then with the second, and so
/// on. `right` is cloned once per item of `left`, which is why it must
/// be `Clone`; each `left` item is cloned once per pair it appears in.
pub fn pairs<T: Clone, U>(
    left: impl Iterator<Item = T>,
    right: impl Iterator<Item = U> + Clone,
) -> impl Iterator<Item = (T, U)> {
    left.flat_map(move |first| {
        // Start a fresh pass over `right` for this `left` item.
        let seconds = right.clone();
        seconds.map(move |second| (first.clone(), second))
    })
}

/// Return the cartesian product of three iterators.
///
/// Triples are produced with `left` varying slowest and `right` varying
/// fastest. `middle` and `right` are restarted by cloning, which is why
/// both must be `Clone`.
pub fn triples<T: Clone, U: Clone, V>(
    left: impl Iterator<Item = T>,
    middle: impl Iterator<Item = U> + Clone,
    right: impl Iterator<Item = V> + Clone,
) -> impl Iterator<Item = (T, U, V)> {
    left.flat_map(move |first| {
        // The middle-level closure below takes ownership of what it captures,
        // so give it its own copy of `right` for this `left` item.
        let thirds = right.clone();
        let seconds = middle.clone();
        seconds.flat_map(move |second| {
            // Likewise, the innermost closure needs an owned `first`.
            let first = first.clone();
            let thirds = thirds.clone();
            thirds.map(move |third| (first.clone(), second.clone(), third))
        })
    })
}