naga/proc/overloads/
scalar_set.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//! A set of scalar types, represented as a bitset.

use crate::ir::Scalar;
use crate::proc::overloads::one_bits_iter::OneBitsIter;

/// Define [`ScalarSet`] and its supporting tables from a list of names.
///
/// Each name in the input list is used both as an associated constant of
/// [`Scalar`] (`Scalar::$scalar`) and as the name of a flag on `ScalarSet`.
/// A name's position in the list fixes its bit number: the first name gets
/// bit 0, the second bit 1, and so on.
///
/// The macro expands to:
/// - `ScalarSetBits`, a private enum whose discriminants are the bit numbers,
///   plus a trailing `Count` variant equal to the number of names listed;
/// - `SCALARS_FOR_BITS`, a table mapping each bit number back to its `Scalar`;
/// - the `ScalarSet` bitflags type, backed by a `u16`, so the list can hold
///   at most 16 names;
/// - `ScalarSet::singleton`.
macro_rules! define_scalar_set {
    { $( $scalar:ident, )* } => {
        /// An enum used to assign distinct bit numbers to [`ScalarSet`] elements.
        #[expect(non_camel_case_types, clippy::upper_case_acronyms)]
        #[repr(u32)]
        enum ScalarSetBits {
            $( $scalar, )*
            /// Not a scalar: its discriminant equals the number of names
            /// listed above, and it sizes `SCALARS_FOR_BITS`.
            Count,
        }

        /// A table mapping bit numbers to the [`Scalar`] values they represent.
        static SCALARS_FOR_BITS: [Scalar; ScalarSetBits::Count as usize] = [
            $(
                Scalar::$scalar,
            )*
        ];

        bitflags::bitflags! {
            /// A set of scalar types.
            ///
            /// This represents a set of [`Scalar`] types.
            ///
            /// The Naga IR conversion rules arrange scalar types into a
            /// lattice. The scalar types' bit values are chosen such that, if
            /// A is convertible to B, then A's bit value is less than B's.
            #[derive(Copy, Clone, Debug)]
            pub(crate) struct ScalarSet: u16 {
                $(
                    // Bit number = this name's position in the macro input,
                    // via the `ScalarSetBits` discriminant.
                    const $scalar = 1 << (ScalarSetBits::$scalar as u32);
                )*
            }
        }

        impl ScalarSet {
            /// Return the set of scalars containing only `scalar`.
            ///
            /// Scalars that were not listed in the macro input have no bit,
            /// so they map to the empty set.
            #[expect(dead_code)]
            pub const fn singleton(scalar: Scalar) -> Self {
                match scalar {
                    $(
                        Scalar::$scalar => Self::$scalar,
                    )*
                    // Any scalar without a bit in `ScalarSet`.
                    _ => Self::empty(),
                }
            }
        }
    }
}

define_scalar_set! {
    // Scalar types must be listed here in an order such that, if A is
    // convertible to B, then A appears before B.
    //
    // In the concrete types, the 32-bit types *must* appear before
    // other sizes, since that is how we represent conversion rank
    // (hence `F32` deliberately precedes `F16` below).
    //
    // Each entry's position in this list is its bit number in `ScalarSet`
    // (the first entry is bit 0). Reordering entries therefore changes the
    // result of `ScalarSet::most_general_scalar`, which returns the member
    // with the lowest bit number. `ScalarSet` is backed by a `u16`, so this
    // list can hold at most 16 entries.
    ABSTRACT_INT, ABSTRACT_FLOAT,
    I32, I64,
    U32, U64,
    F32, F16, F64,
    BOOL,
}

impl ScalarSet {
    /// Return the set of scalars to which `scalar` can be automatically
    /// converted.
    ///
    /// Each concrete scalar converts only to itself. `ABSTRACT_INT` converts
    /// to every integer and floating-point type, abstract or concrete, and
    /// `ABSTRACT_FLOAT` converts to every floating-point type, abstract or
    /// concrete. Any other scalar yields the empty set.
    pub fn convertible_from(scalar: Scalar) -> Self {
        use Scalar as Sc;
        match scalar {
            Sc::I32 => Self::I32,
            Sc::I64 => Self::I64,
            Sc::U32 => Self::U32,
            Sc::U64 => Self::U64,
            Sc::F16 => Self::F16,
            Sc::F32 => Self::F32,
            Sc::F64 => Self::F64,
            Sc::BOOL => Self::BOOL,
            Sc::ABSTRACT_INT => Self::INTEGER | Self::FLOAT,
            Sc::ABSTRACT_FLOAT => Self::FLOAT,
            // Scalars with no bit in `ScalarSet`.
            _ => Self::empty(),
        }
    }

    /// Return the lowest-ranked member of `self` as a [`Scalar`].
    ///
    /// This is the member with the lowest bit number, i.e. the one listed
    /// earliest in the `define_scalar_set!` invocation.
    ///
    /// # Panics
    ///
    /// Panics if `self` is empty.
    pub fn most_general_scalar(self) -> Scalar {
        // For an empty set, `trailing_zeros` returns the full bit width of
        // `self.bits()` (16). That index is out of bounds for
        // `SCALARS_FOR_BITS`, so the lookup fails and we panic.
        let lowest = self.bits().trailing_zeros() as usize;
        *SCALARS_FOR_BITS
            .get(lowest)
            .expect("ScalarSet::most_general_scalar called on an empty set")
    }

    /// Return an iterator over this set's members.
    ///
    /// Members are produced as singleton sets, in order from most general to
    /// least.
    pub fn members(self) -> impl Iterator<Item = ScalarSet> {
        OneBitsIter::new(u64::from(self.bits())).map(|bit| {
            // Each `bit` comes from a `u16`, so narrowing back is lossless;
            // `from_bits` only rejects bits that name no scalar type.
            Self::from_bits(bit as u16).expect("a member bit of a ScalarSet is a valid ScalarSet")
        })
    }

    /// All floating-point scalar types, abstract and concrete.
    pub const FLOAT: Self = Self::ABSTRACT_FLOAT
        .union(Self::F16)
        .union(Self::F32)
        .union(Self::F64);

    /// All integer scalar types, signed and unsigned, abstract and concrete.
    pub const INTEGER: Self = Self::ABSTRACT_INT
        .union(Self::I32)
        .union(Self::I64)
        .union(Self::U32)
        .union(Self::U64);

    /// All integer and floating-point scalar types (every member except `BOOL`).
    pub const NUMERIC: Self = Self::FLOAT.union(Self::INTEGER);
    /// The abstract scalar types: `ABSTRACT_INT` and `ABSTRACT_FLOAT`.
    pub const ABSTRACT: Self = Self::ABSTRACT_INT.union(Self::ABSTRACT_FLOAT);
    /// Every scalar type that is not abstract.
    pub const CONCRETE: Self = Self::all().difference(Self::ABSTRACT);
    /// The concrete integer types.
    pub const CONCRETE_INTEGER: Self = Self::INTEGER.intersection(Self::CONCRETE);
    /// The concrete floating-point types.
    pub const CONCRETE_FLOAT: Self = Self::FLOAT.intersection(Self::CONCRETE);

    /// Floating-point scalars, with the abstract floats omitted for
    /// #7405.
    pub const FLOAT_ABSTRACT_UNIMPLEMENTED: Self = Self::CONCRETE_FLOAT;
}

/// Build a [`ScalarSet`] from a `|`-separated list of its flag names.
///
/// For example, `scalar_set!(I32 | U32)` is the set containing `I32` and
/// `U32`, and an invocation with no names produces the empty set. The
/// expansion only calls `ScalarSet::empty` and `ScalarSet::union`, so it
/// is usable in `const` contexts.
macro_rules! scalar_set {
    ( $( $member:ident )|* ) => {{
        $crate::proc::overloads::scalar_set::ScalarSet::empty()
            $( .union($crate::proc::overloads::scalar_set::ScalarSet::$member) )*
    }};
}

pub(in crate::proc::overloads) use scalar_set;