use crate::proc::overloads::any_overload_set::AnyOverloadSet;
use crate::proc::overloads::list::List;
use crate::proc::overloads::regular::regular;
use crate::proc::overloads::utils::{
concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
scalar_or_vecn, triples, vector_sizes,
};
use crate::proc::overloads::OverloadSet;
use crate::ir;
impl ir::MathFunction {
pub fn overloads(self) -> impl OverloadSet {
use ir::MathFunction as Mf;
let set: AnyOverloadSet = match self {
Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
Mf::Sin
| Mf::Cos
| Mf::Tan
| Mf::Asin
| Mf::Acos
| Mf::Atan
| Mf::Sinh
| Mf::Cosh
| Mf::Tanh
| Mf::Asinh
| Mf::Acosh
| Mf::Atanh
| Mf::Saturate
| Mf::Radians
| Mf::Degrees
| Mf::Ceil
| Mf::Floor
| Mf::Round
| Mf::Fract
| Mf::Trunc
| Mf::Exp
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Sqrt
| Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
Mf::CountTrailingZeros
| Mf::CountLeadingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::FirstTrailingBit
| Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
regular!(1, VEC2 of F32 -> U32).into()
}
Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
regular!(1, SCALAR of U32 -> Vec2F).into()
}
Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
Mf::Ldexp => ldexp().into(),
Mf::Outer => outer().into(),
Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
Mf::Distance => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::Refract => refract().into(),
Mf::Mix => mix().into(),
Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
Mf::Transpose => transpose().into(),
Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
Mf::ExtractBits => extract_bits().into(),
Mf::InsertBits => insert_bits().into(),
};
set
}
}
fn ldexp() -> List {
fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
match mantissa.kind {
ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
ir::ScalarKind::Float => ir::Scalar::I32,
_ => unreachable!("not a float scalar"),
}
}
list(
float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
scalar_or_vecn(mantissa_scalar)
.zip(scalar_or_vecn(exponent_scalar))
.map(move |(mantissa, exponent)| {
let result = mantissa.clone();
rule([mantissa, exponent], result)
})
}),
)
}
fn outer() -> List {
list(
triples(
vector_sizes(),
vector_sizes(),
float_scalars_unimplemented_abstract(),
)
.map(|(cols, rows, scalar)| {
let left = ir::TypeInner::Vector { size: cols, scalar };
let right = ir::TypeInner::Vector { size: rows, scalar };
let result = ir::TypeInner::Matrix {
columns: cols,
rows,
scalar,
};
rule([left, right], result)
}),
)
}
fn refract() -> List {
list(
pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
let incident = ir::TypeInner::Vector { size, scalar };
let normal = incident.clone();
let ratio = ir::TypeInner::Scalar(scalar);
let result = incident.clone();
rule([incident, normal, ratio], result)
}),
)
}
fn transpose() -> List {
list(
triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
let input = ir::TypeInner::Matrix {
columns: a,
rows: b,
scalar,
};
let output = ir::TypeInner::Matrix {
columns: b,
rows: a,
scalar,
};
rule([input], output)
}),
)
}
fn extract_bits() -> List {
list(concrete_int_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).map(|input| {
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
let output = input.clone();
rule([input, offset, count], output)
})
}))
}
fn insert_bits() -> List {
list(concrete_int_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).map(|input| {
let newbits = input.clone();
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
let output = input.clone();
rule([input, newbits, offset, count], output)
})
}))
}
fn mix() -> List {
list(float_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).flat_map(move |input| {
let scalar_ratio = ir::TypeInner::Scalar(scalar);
[
rule([input.clone(), input.clone(), input.clone()], input.clone()),
rule([input.clone(), input.clone(), scalar_ratio], input),
]
})
}))
}