1use crate::proc::overloads::any_overload_set::AnyOverloadSet;
4use crate::proc::overloads::list::List;
5use crate::proc::overloads::regular::regular;
6use crate::proc::overloads::utils::{
7 concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
8 scalar_or_vecn, triples, vector_sizes,
9};
10use crate::proc::overloads::OverloadSet;
11
12use crate::ir;
13
14impl ir::MathFunction {
15 pub fn overloads(self) -> impl OverloadSet {
16 use ir::MathFunction as Mf;
17
18 let set: AnyOverloadSet = match self {
19 Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
21
22 Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
24
25 Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
27
28 Mf::Sin
30 | Mf::Cos
31 | Mf::Tan
32 | Mf::Asin
33 | Mf::Acos
34 | Mf::Atan
35 | Mf::Sinh
36 | Mf::Cosh
37 | Mf::Tanh
38 | Mf::Asinh
39 | Mf::Acosh
40 | Mf::Atanh
41 | Mf::Saturate
42 | Mf::Radians
43 | Mf::Degrees
44 | Mf::Ceil
45 | Mf::Floor
46 | Mf::Round
47 | Mf::Fract
48 | Mf::Trunc
49 | Mf::Exp
50 | Mf::Exp2
51 | Mf::Log
52 | Mf::Log2
53 | Mf::Sqrt
54 | Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
55
56 Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
58
59 Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
61
62 Mf::CountTrailingZeros
64 | Mf::CountLeadingZeros
65 | Mf::CountOneBits
66 | Mf::ReverseBits
67 | Mf::FirstTrailingBit
68 | Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
69
70 Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
72 Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
73 regular!(1, VEC2 of F32 -> U32).into()
74 }
75 Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
76 Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
77 Mf::Pack4xI8Clamp => regular!(1, VEC4 of I32 -> U32).into(),
78 Mf::Pack4xU8Clamp => regular!(1, VEC4 of U32 -> U32).into(),
79
80 Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
82 Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
83 regular!(1, SCALAR of U32 -> Vec2F).into()
84 }
85 Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
86 Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
87 Mf::Dot4I8Packed => regular!(2, SCALAR of U32 -> I32).into(),
88 Mf::Dot4U8Packed => regular!(2, SCALAR of U32 -> U32).into(),
89
90 Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
92 Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
93 Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
94 Mf::Ldexp => ldexp().into(),
95 Mf::Outer => outer().into(),
96 Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
97 Mf::Distance => {
98 regular!(2, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into()
99 }
100 Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
101 Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
102 Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
103 Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
104 Mf::Refract => refract().into(),
105 Mf::Mix => mix().into(),
106 Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
107 Mf::Transpose => transpose().into(),
108 Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
109 Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
110 Mf::ExtractBits => extract_bits().into(),
111 Mf::InsertBits => insert_bits().into(),
112 };
113
114 set
115 }
116}
117
118fn ldexp() -> List {
119 fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
121 match mantissa.kind {
122 ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
123 ir::ScalarKind::Float => ir::Scalar::I32,
124 _ => unreachable!("not a float scalar"),
125 }
126 }
127
128 list(
129 float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
131 let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
133 scalar_or_vecn(mantissa_scalar)
135 .zip(scalar_or_vecn(exponent_scalar))
136 .map(move |(mantissa, exponent)| {
137 let result = mantissa.clone();
138 rule([mantissa, exponent], result)
139 })
140 }),
141 )
142}
143
144fn outer() -> List {
145 list(
146 triples(
147 vector_sizes(),
148 vector_sizes(),
149 float_scalars_unimplemented_abstract(),
150 )
151 .map(|(cols, rows, scalar)| {
152 let left = ir::TypeInner::Vector { size: cols, scalar };
153 let right = ir::TypeInner::Vector { size: rows, scalar };
154 let result = ir::TypeInner::Matrix {
155 columns: cols,
156 rows,
157 scalar,
158 };
159 rule([left, right], result)
160 }),
161 )
162}
163
164fn refract() -> List {
165 list(
166 pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
167 let incident = ir::TypeInner::Vector { size, scalar };
168 let normal = incident.clone();
169 let ratio = ir::TypeInner::Scalar(scalar);
170 let result = incident.clone();
171 rule([incident, normal, ratio], result)
172 }),
173 )
174}
175
176fn transpose() -> List {
177 list(
178 triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
179 let input = ir::TypeInner::Matrix {
180 columns: a,
181 rows: b,
182 scalar,
183 };
184 let output = ir::TypeInner::Matrix {
185 columns: b,
186 rows: a,
187 scalar,
188 };
189 rule([input], output)
190 }),
191 )
192}
193
194fn extract_bits() -> List {
195 list(concrete_int_scalars().flat_map(|scalar| {
196 scalar_or_vecn(scalar).map(|input| {
197 let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
198 let count = ir::TypeInner::Scalar(ir::Scalar::U32);
199 let output = input.clone();
200 rule([input, offset, count], output)
201 })
202 }))
203}
204
205fn insert_bits() -> List {
206 list(concrete_int_scalars().flat_map(|scalar| {
207 scalar_or_vecn(scalar).map(|input| {
208 let newbits = input.clone();
209 let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
210 let count = ir::TypeInner::Scalar(ir::Scalar::U32);
211 let output = input.clone();
212 rule([input, newbits, offset, count], output)
213 })
214 }))
215}
216
217fn mix() -> List {
218 list(float_scalars().flat_map(|scalar| {
219 scalar_or_vecn(scalar).flat_map(move |input| {
220 let scalar_ratio = ir::TypeInner::Scalar(scalar);
221 [
222 rule([input.clone(), input.clone(), input.clone()], input.clone()),
223 rule([input.clone(), input.clone(), scalar_ratio], input),
224 ]
225 })
226 }))
227}