naga/proc/overloads/
scalar_set.rs

1//! A set of scalar types, represented as a bitset.
2
3use crate::ir::Scalar;
4use crate::proc::overloads::one_bits_iter::OneBitsIter;
5
/// Generate [`ScalarSet`], its bit-numbering enum, the bit-to-[`Scalar`]
/// lookup table, and `ScalarSet::singleton`, from an ordered list of
/// `Scalar` associated-constant names (each followed by a comma).
macro_rules! define_scalar_set {
    { $( $name:ident, )* } => {
        bitflags::bitflags! {
            /// A set of [`Scalar`] types, stored as a bitset.
            ///
            /// Naga IR's automatic conversion rules arrange the scalar types
            /// into a lattice. Bit positions are assigned so that whenever a
            /// type A is convertible to a type B, A's bit value is less than
            /// B's.
            #[derive(Copy, Clone, Debug)]
            pub(crate) struct ScalarSet: u16 {
                $( const $name = 1 << (ScalarSetBits::$name as u32); )*
            }
        }

        /// Gives each [`ScalarSet`] element a distinct bit number, taken from
        /// its position in the macro's input list. `Count` is one past the
        /// last element's number, i.e. the number of elements.
        #[expect(non_camel_case_types, clippy::upper_case_acronyms)]
        #[repr(u32)]
        enum ScalarSetBits {
            $( $name, )*
            Count,
        }

        /// Lookup table from bit number back to the [`Scalar`] it stands for.
        static SCALARS_FOR_BITS: [Scalar; ScalarSetBits::Count as usize] =
            [ $( Scalar::$name, )* ];

        impl ScalarSet {
            /// Return the set whose only member is `scalar`.
            ///
            /// A `scalar` that was not listed in the macro input yields the
            /// empty set.
            #[expect(dead_code)]
            pub const fn singleton(scalar: Scalar) -> Self {
                match scalar {
                    $( Scalar::$name => Self::$name, )*
                    _ => Self::empty(),
                }
            }
        }
    }
}
53
// Instantiate `ScalarSet`. Each identifier below names both a `Scalar`
// associated constant and the `ScalarSet` flag generated for it. Its
// position in this list becomes its `ScalarSetBits` discriminant, and
// therefore its bit number, so the order here is significant.
define_scalar_set! {
    // Scalar types must be listed here in an order such that, if A is
    // convertible to B, then A appears before B.
    //
    // In the concrete types, the 32-bit types *must* appear before
    // other sizes, since that is how we represent conversion rank.
    // `ScalarSet::most_general_scalar` returns the member with the
    // lowest bit number. So among members of the same kind, it picks
    // the 32-bit type (e.g. F32 over F16 and F64).
    ABSTRACT_INT, ABSTRACT_FLOAT,
    I32, I64,
    U32, U64,
    F32, F16, F64,
    BOOL,
}
66
67impl ScalarSet {
68    /// Return the set of scalars to which `scalar` can be automatically
69    /// converted.
70    pub fn convertible_from(scalar: Scalar) -> Self {
71        use Scalar as Sc;
72        match scalar {
73            Sc::I32 => Self::I32,
74            Sc::I64 => Self::I64,
75            Sc::U32 => Self::U32,
76            Sc::U64 => Self::U64,
77            Sc::F16 => Self::F16,
78            Sc::F32 => Self::F32,
79            Sc::F64 => Self::F64,
80            Sc::BOOL => Self::BOOL,
81            Sc::ABSTRACT_INT => Self::INTEGER | Self::FLOAT,
82            Sc::ABSTRACT_FLOAT => Self::FLOAT,
83            _ => Self::empty(),
84        }
85    }
86
87    /// Return the lowest-ranked member of `self` as a [`Scalar`].
88    ///
89    /// # Panics
90    ///
91    /// Panics if `self` is empty.
92    pub fn most_general_scalar(self) -> Scalar {
93        // If the set is empty, this returns the full bit-length of
94        // `self.bits()`, an index which is out of bounds for
95        // `SCALARS_FOR_BITS`.
96        let lowest = self.bits().trailing_zeros();
97        *SCALARS_FOR_BITS.get(lowest as usize).unwrap()
98    }
99
100    /// Return an iterator over this set's members.
101    ///
102    /// Members are produced as singleton, in order from most general to least.
103    pub fn members(self) -> impl Iterator<Item = ScalarSet> {
104        OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap())
105    }
106
107    pub const FLOAT: Self = Self::ABSTRACT_FLOAT
108        .union(Self::F16)
109        .union(Self::F32)
110        .union(Self::F64);
111
112    pub const INTEGER: Self = Self::ABSTRACT_INT
113        .union(Self::I32)
114        .union(Self::I64)
115        .union(Self::U32)
116        .union(Self::U64);
117
118    pub const NUMERIC: Self = Self::FLOAT.union(Self::INTEGER);
119    pub const ABSTRACT: Self = Self::ABSTRACT_INT.union(Self::ABSTRACT_FLOAT);
120    pub const CONCRETE: Self = Self::all().difference(Self::ABSTRACT);
121    pub const CONCRETE_INTEGER: Self = Self::INTEGER.intersection(Self::CONCRETE);
122    pub const CONCRETE_FLOAT: Self = Self::FLOAT.intersection(Self::CONCRETE);
123
124    /// Floating-point scalars, with the abstract floats omitted for
125    /// #7405.
126    pub const FLOAT_ABSTRACT_UNIMPLEMENTED: Self = Self::CONCRETE_FLOAT;
127}
128
129macro_rules! scalar_set {
130    ( $( $scalar:ident )|* ) => {
131        {
132            use $crate::proc::overloads::scalar_set::ScalarSet;
133            ScalarSet::empty()
134                $(
135                    .union(ScalarSet::$scalar)
136                )*
137        }
138    }
139}
140
141pub(in crate::proc::overloads) use scalar_set;