naga/proc/overloads/mathfunction.rs

1//! Overload sets for [`ir::MathFunction`].
2
3use crate::proc::overloads::any_overload_set::AnyOverloadSet;
4use crate::proc::overloads::list::List;
5use crate::proc::overloads::regular::regular;
6use crate::proc::overloads::utils::{
7    concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
8    scalar_or_vecn, triples, vector_sizes,
9};
10use crate::proc::overloads::OverloadSet;
11
12use crate::ir;
13
14impl ir::MathFunction {
15    pub fn overloads(self) -> impl OverloadSet {
16        use ir::MathFunction as Mf;
17
18        let set: AnyOverloadSet = match self {
19            // Component-wise unary numeric operations
20            Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
21
22            // Component-wise binary numeric operations
23            Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
24
25            // Component-wise ternary numeric operations
26            Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
27
28            // Component-wise unary floating-point operations
29            Mf::Sin
30            | Mf::Cos
31            | Mf::Tan
32            | Mf::Asin
33            | Mf::Acos
34            | Mf::Atan
35            | Mf::Sinh
36            | Mf::Cosh
37            | Mf::Tanh
38            | Mf::Asinh
39            | Mf::Acosh
40            | Mf::Atanh
41            | Mf::Saturate
42            | Mf::Radians
43            | Mf::Degrees
44            | Mf::Ceil
45            | Mf::Floor
46            | Mf::Round
47            | Mf::Fract
48            | Mf::Trunc
49            | Mf::Exp
50            | Mf::Exp2
51            | Mf::Log
52            | Mf::Log2
53            | Mf::Sqrt
54            | Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
55
56            // Component-wise binary floating-point operations
57            Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
58
59            // Component-wise ternary floating-point operations
60            Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
61
62            // Component-wise unary concrete integer operations
63            Mf::CountTrailingZeros
64            | Mf::CountLeadingZeros
65            | Mf::CountOneBits
66            | Mf::ReverseBits
67            | Mf::FirstTrailingBit
68            | Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
69
70            // Packing functions
71            Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
72            Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
73                regular!(1, VEC2 of F32 -> U32).into()
74            }
75            Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
76            Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
77            Mf::Pack4xI8Clamp => regular!(1, VEC4 of I32 -> U32).into(),
78            Mf::Pack4xU8Clamp => regular!(1, VEC4 of U32 -> U32).into(),
79
80            // Unpacking functions
81            Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
82            Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
83                regular!(1, SCALAR of U32 -> Vec2F).into()
84            }
85            Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
86            Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
87            Mf::Dot4I8Packed => regular!(2, SCALAR of U32 -> I32).into(),
88            Mf::Dot4U8Packed => regular!(2, SCALAR of U32 -> U32).into(),
89
90            // One-off operations
91            Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
92            Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
93            Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
94            Mf::Ldexp => ldexp().into(),
95            Mf::Outer => outer().into(),
96            Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
97            Mf::Distance => {
98                regular!(2, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into()
99            }
100            Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
101            Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
102            Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
103            Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
104            Mf::Refract => refract().into(),
105            Mf::Mix => mix().into(),
106            Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
107            Mf::Transpose => transpose().into(),
108            Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
109            Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
110            Mf::ExtractBits => extract_bits().into(),
111            Mf::InsertBits => insert_bits().into(),
112        };
113
114        set
115    }
116}
117
118fn ldexp() -> List {
119    /// Construct the exponent scalar given the mantissa's inner.
120    fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
121        match mantissa.kind {
122            ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
123            ir::ScalarKind::Float => ir::Scalar::I32,
124            _ => unreachable!("not a float scalar"),
125        }
126    }
127
128    list(
129        // The ldexp mantissa argument can be any floating-point type.
130        float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
131            // The exponent type is the integer counterpart of the mantissa type.
132            let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
133            // There are scalar and vector component-wise overloads.
134            scalar_or_vecn(mantissa_scalar)
135                .zip(scalar_or_vecn(exponent_scalar))
136                .map(move |(mantissa, exponent)| {
137                    let result = mantissa.clone();
138                    rule([mantissa, exponent], result)
139                })
140        }),
141    )
142}
143
144fn outer() -> List {
145    list(
146        triples(
147            vector_sizes(),
148            vector_sizes(),
149            float_scalars_unimplemented_abstract(),
150        )
151        .map(|(cols, rows, scalar)| {
152            let left = ir::TypeInner::Vector { size: cols, scalar };
153            let right = ir::TypeInner::Vector { size: rows, scalar };
154            let result = ir::TypeInner::Matrix {
155                columns: cols,
156                rows,
157                scalar,
158            };
159            rule([left, right], result)
160        }),
161    )
162}
163
164fn refract() -> List {
165    list(
166        pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
167            let incident = ir::TypeInner::Vector { size, scalar };
168            let normal = incident.clone();
169            let ratio = ir::TypeInner::Scalar(scalar);
170            let result = incident.clone();
171            rule([incident, normal, ratio], result)
172        }),
173    )
174}
175
176fn transpose() -> List {
177    list(
178        triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
179            let input = ir::TypeInner::Matrix {
180                columns: a,
181                rows: b,
182                scalar,
183            };
184            let output = ir::TypeInner::Matrix {
185                columns: b,
186                rows: a,
187                scalar,
188            };
189            rule([input], output)
190        }),
191    )
192}
193
194fn extract_bits() -> List {
195    list(concrete_int_scalars().flat_map(|scalar| {
196        scalar_or_vecn(scalar).map(|input| {
197            let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
198            let count = ir::TypeInner::Scalar(ir::Scalar::U32);
199            let output = input.clone();
200            rule([input, offset, count], output)
201        })
202    }))
203}
204
205fn insert_bits() -> List {
206    list(concrete_int_scalars().flat_map(|scalar| {
207        scalar_or_vecn(scalar).map(|input| {
208            let newbits = input.clone();
209            let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
210            let count = ir::TypeInner::Scalar(ir::Scalar::U32);
211            let output = input.clone();
212            rule([input, newbits, offset, count], output)
213        })
214    }))
215}
216
217fn mix() -> List {
218    list(float_scalars().flat_map(|scalar| {
219        scalar_or_vecn(scalar).flat_map(move |input| {
220            let scalar_ratio = ir::TypeInner::Scalar(scalar);
221            [
222                rule([input.clone(), input.clone(), input.clone()], input.clone()),
223                rule([input.clone(), input.clone(), scalar_ratio], input),
224            ]
225        })
226    }))
227}