use crate::common::{DiagnosticDebug, ForDebugWithTypes};
use crate::ir;
use crate::proc::overloads::constructor_set::{ConstructorSet, ConstructorSize};
use crate::proc::overloads::rule::{Conclusion, Rule};
use crate::proc::overloads::scalar_set::ScalarSet;
use crate::proc::overloads::OverloadSet;
use crate::proc::{GlobalCtx, TypeResolution};
use crate::UniqueArena;
use alloc::vec::Vec;
use core::fmt;
/// An overload set in which every rule takes `arity` arguments that all share
/// one type.
///
/// The candidate argument types are the cross product of `constructors`
/// (the shape, e.g. scalar or `vecN`) and `scalars` (the component type).
/// Each rule's result type is derived by `conclude`.
#[derive(Clone)]
pub(in crate::proc::overloads) struct Regular {
    /// Number of arguments every rule in this set takes.
    pub arity: usize,

    /// Type constructors (shapes) the arguments may use.
    pub constructors: ConstructorSet,

    /// Scalar types the arguments may use.
    pub scalars: ScalarSet,

    /// How each rule's result type is computed from its argument type.
    pub conclude: ConclusionRule,
}
impl Regular {
    /// An overload set with no rules at all.
    ///
    /// Both the constructor set and the scalar set are empty, so `is_empty`
    /// reports `true` for it.
    pub(in crate::proc::overloads) const EMPTY: Self = Self {
        arity: 0,
        constructors: ConstructorSet::empty(),
        scalars: ScalarSet::empty(),
        conclude: ConclusionRule::ArgumentType,
    };

    /// Enumerate every `(constructor size, scalar)` pair in this set.
    ///
    /// Pairs come from the cross product of the member constructors and the
    /// member scalars. Each scalar singleton is reduced to its
    /// `most_general_scalar`.
    fn members(&self) -> impl Iterator<Item = (ConstructorSize, ir::Scalar)> {
        // Copy the scalar set out of `self` so the returned iterator does not
        // borrow `self`.
        let scalar_set = self.scalars;
        self.constructors.members().flat_map(move |singleton_constructor| {
            let size = singleton_constructor.size();
            scalar_set
                .members()
                .map(|singleton_scalar| singleton_scalar.most_general_scalar())
                .map(move |scalar| (size, scalar))
        })
    }

    /// Expand this set into one `Rule` per `(size, scalar)` member.
    fn rules(&self) -> impl Iterator<Item = Rule> {
        // Copy the needed fields so the iterator captures no borrow of `self`.
        let Self {
            arity, conclude, ..
        } = *self;
        self.members()
            .map(move |(size, scalar)| make_rule(arity, size, scalar, conclude))
    }
}
impl OverloadSet for Regular {
    /// A rule needs both a constructor and a scalar, so the set is empty as
    /// soon as either component set is.
    fn is_empty(&self) -> bool {
        self.scalars.is_empty() || self.constructors.is_empty()
    }

    /// Every rule takes exactly `arity` arguments.
    ///
    /// Panics if the set is empty.
    fn min_arguments(&self) -> usize {
        assert!(!self.is_empty());
        self.arity
    }

    /// Same as `min_arguments`: regular overloads have a fixed arity.
    ///
    /// Panics if the set is empty.
    fn max_arguments(&self) -> usize {
        self.min_arguments()
    }

    /// Narrow the set to the rules whose argument `i` can accept `ty`.
    ///
    /// The constructor set is intersected with `ty`'s own constructor. The
    /// scalar set is intersected with the scalars that
    /// `ScalarSet::convertible_from` reports for `ty`'s conversion scalar. If
    /// `ty` has no such scalar, or `i` is past the arity, the result is
    /// empty.
    fn arg(&self, i: usize, ty: &ir::TypeInner, types: &UniqueArena<ir::Type>) -> Self {
        if i < self.arity {
            let allowed_scalars = ty
                .scalar_for_conversions(types)
                .map_or_else(ScalarSet::empty, ScalarSet::convertible_from);
            Self {
                constructors: self.constructors & ConstructorSet::singleton(ty),
                scalars: self.scalars & allowed_scalars,
                ..*self
            }
        } else {
            Self::EMPTY
        }
    }

    /// Restrict the scalar set to `ScalarSet::CONCRETE`.
    fn concrete_only(self, _types: &UniqueArena<ir::Type>) -> Self {
        let scalars = self.scalars & ScalarSet::CONCRETE;
        Self { scalars, ..self }
    }

    /// Build the preferred rule: the single remaining constructor paired with
    /// the scalar set's `most_general_scalar`.
    ///
    /// Panics if the set is empty or still has more than one constructor.
    fn most_preferred(&self) -> Rule {
        assert!(!self.is_empty());
        assert!(self.constructors.is_singleton());
        make_rule(
            self.arity,
            self.constructors.size(),
            self.scalars.most_general_scalar(),
            self.conclude,
        )
    }

    /// Every rule in the set, one per `(size, scalar)` member.
    fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
        self.rules().collect::<Vec<_>>()
    }

    /// All argument types accepted at position `i`.
    ///
    /// Every position accepts the same types, so `i` only matters for the
    /// arity bound.
    fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
        if i < self.arity {
            self.members()
                .map(|(size, scalar)| size.to_inner(scalar))
                .map(TypeResolution::Value)
                .collect()
        } else {
            Vec::new()
        }
    }

    fn for_debug(&self, types: &UniqueArena<ir::Type>) -> impl fmt::Debug {
        DiagnosticDebug((self, types))
    }
}
/// Build the rule taking `arity` arguments, each of type
/// `size.to_inner(scalar)`, whose result is computed by `conclusion_rule`.
fn make_rule(
    arity: usize,
    size: ConstructorSize,
    scalar: ir::Scalar,
    conclusion_rule: ConclusionRule,
) -> Rule {
    // All arguments of a regular overload share one type. The value is moved
    // into `repeat`, which clones it once per argument. No extra up-front
    // clones are needed, because neither the inner type nor `arg` is used
    // afterwards.
    let arg = TypeResolution::Value(size.to_inner(scalar));
    Rule {
        arguments: core::iter::repeat(arg).take(arity).collect(),
        conclusion: conclusion_rule.conclude(size, scalar),
    }
}
/// How a `Regular` overload set derives each rule's result type.
///
/// The derivation uses the rule's argument constructor size and scalar; see
/// `ConclusionRule::conclude` for the exact mapping.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub(in crate::proc::overloads) enum ConclusionRule {
    /// The result has the same type as the arguments.
    ArgumentType,
    /// The result is a scalar of the arguments' scalar type
    /// (e.g. `vec3<f32>` arguments give `f32`).
    Scalar,
    /// The result is whatever `Conclusion::for_frexp_modf` produces for
    /// `MathFunction::Frexp`.
    Frexp,
    /// The result is whatever `Conclusion::for_frexp_modf` produces for
    /// `MathFunction::Modf`.
    Modf,
    /// The result is always `u32`.
    U32,
    /// The result is always `vec2<f32>`.
    Vec2F,
    /// The result is always `vec4<f32>`.
    Vec4F,
    /// The result is always `vec4<i32>`.
    Vec4I,
    /// The result is always `vec4<u32>`.
    Vec4U,
}
impl ConclusionRule {
    /// Compute the result for a rule whose arguments have type
    /// `size.to_inner(scalar)`.
    fn conclude(self, size: ConstructorSize, scalar: ir::Scalar) -> Conclusion {
        // Result types that are a fixed vector, independent of the arguments.
        fn fixed_vector(size: ir::VectorSize, scalar: ir::Scalar) -> Conclusion {
            Conclusion::Value(ir::TypeInner::Vector { size, scalar })
        }

        match self {
            // Results derived from the argument type.
            Self::ArgumentType => Conclusion::Value(size.to_inner(scalar)),
            Self::Scalar => Conclusion::Value(ir::TypeInner::Scalar(scalar)),
            Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
            Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),

            // Results with a fixed type.
            Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
            Self::Vec2F => fixed_vector(ir::VectorSize::Bi, ir::Scalar::F32),
            Self::Vec4F => fixed_vector(ir::VectorSize::Quad, ir::Scalar::F32),
            Self::Vec4I => fixed_vector(ir::VectorSize::Quad, ir::Scalar::I32),
            Self::Vec4U => fixed_vector(ir::VectorSize::Quad, ir::Scalar::U32),
        }
    }
}
/// Debug-print a `Regular` overload set.
///
/// The set is expanded into its full list of rules, with type handles
/// resolved through `types`, followed by its conclusion rule.
impl fmt::Debug for DiagnosticDebug<(&Regular, &UniqueArena<ir::Type>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (regular, types) = self.0;
        let rules: Vec<Rule> = regular.rules().collect();
        // Name the struct after the type actually being printed.
        f.debug_struct("Regular")
            .field("rules", &rules.for_debug(types))
            .field("conclude", &regular.conclude)
            .finish()
    }
}

impl ForDebugWithTypes for &Regular {}
/// Debug-print a slice of rules as a list.
///
/// Each rule is rendered through its own `for_debug` view so that type
/// handles are resolved against `types`.
impl fmt::Debug for DiagnosticDebug<(&[Rule], &UniqueArena<ir::Type>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (rules, types) = self.0;
        let mut list = f.debug_list();
        for rule in rules {
            list.entry(&rule.for_debug(types));
        }
        list.finish()
    }
}

impl ForDebugWithTypes for &[Rule] {}
/// Construct a `Regular` overload set.
///
/// - `regular!(ARITY, CONSTRUCTORS of SCALARS)` builds a set whose rules
///   return their argument type (`ConclusionRule::ArgumentType`).
/// - `regular!(ARITY, CONSTRUCTORS of SCALARS -> CONCLUSION)` uses
///   `ConclusionRule::CONCLUSION` instead.
///
/// `CONSTRUCTORS` and `SCALARS` are `|`-separated name lists, forwarded to
/// `constructor_set!` and `scalar_set!` respectively. For example:
/// `regular!(2, VECN|SCALAR of NUMERIC)`.
macro_rules! regular {
    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|*) => {
        {
            use $crate::proc::overloads::{
                constructor_set::constructor_set,
                regular::{ConclusionRule, Regular},
                scalar_set::scalar_set,
            };
            Regular {
                arity: $arity,
                constructors: constructor_set!( $( $constr )|* ),
                scalars: scalar_set!( $( $scalar )|* ),
                conclude: ConclusionRule::ArgumentType,
            }
        }
    };
    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|* -> $conclude:ident) => {
        {
            use $crate::proc::overloads::{
                constructor_set::constructor_set,
                regular::{ConclusionRule, Regular},
                scalar_set::scalar_set,
            };
            Regular {
                arity: $arity,
                constructors: constructor_set!( $( $constr )|* ),
                scalars: scalar_set!( $( $scalar )|* ),
                conclude: ConclusionRule::$conclude,
            }
        }
    };
}

pub(in crate::proc::overloads) use regular;
#[cfg(test)]
mod test {
    use super::*;
    use crate::ir::{self, Scalar};

    /// The scalar type `scalar`.
    const fn scalar(scalar: Scalar) -> ir::TypeInner {
        ir::TypeInner::Scalar(scalar)
    }

    /// A vector of `size` components of type `scalar`.
    const fn vector(size: ir::VectorSize, scalar: Scalar) -> ir::TypeInner {
        ir::TypeInner::Vector { size, scalar }
    }

    const fn vec2(scalar: Scalar) -> ir::TypeInner {
        vector(ir::VectorSize::Bi, scalar)
    }

    const fn vec3(scalar: Scalar) -> ir::TypeInner {
        vector(ir::VectorSize::Tri, scalar)
    }

    /// Feed `args` to `builtin` as arguments 0, 1, ... in order and return
    /// the narrowed overload set.
    fn apply(
        builtin: &Regular,
        arena: &UniqueArena<ir::Type>,
        args: &[ir::TypeInner],
    ) -> Regular {
        args.iter()
            .enumerate()
            .fold(builtin.clone(), |set, (index, ty)| set.arg(index, ty, arena))
    }

    /// Assert that `set` is non-empty and that its preferred rule returns
    /// `expected`.
    #[track_caller]
    fn check_return_type(set: &Regular, expected: &ir::TypeInner, arena: &UniqueArena<ir::Type>) {
        assert!(!set.is_empty());
        let special_types = ir::SpecialTypes::default();
        let resolution = set
            .most_preferred()
            .conclusion
            .into_resolution(&special_types)
            .expect("special types should have been pre-registered");
        let actual = resolution.inner_with(arena);
        assert!(
            actual.equivalent(expected, arena),
            "Expected {:?}, got {:?}",
            expected.for_debug(arena),
            actual.for_debug(arena),
        );
    }

    /// Assert that `args` select an overload of `builtin` returning `expected`.
    #[track_caller]
    fn expect_ok(
        builtin: &Regular,
        arena: &UniqueArena<ir::Type>,
        args: &[ir::TypeInner],
        expected: ir::TypeInner,
    ) {
        check_return_type(&apply(builtin, arena, args), &expected, arena);
    }

    /// Assert that no overload of `builtin` accepts `args`.
    #[track_caller]
    fn expect_err(builtin: &Regular, arena: &UniqueArena<ir::Type>, args: &[ir::TypeInner]) {
        assert!(apply(builtin, arena, args).is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, SCALAR of NUMERIC);
        expect_ok(&builtin, &arena, &[scalar(Scalar::U32)], scalar(Scalar::U32));
        expect_err(&builtin, &arena, &[scalar(Scalar::BOOL)]);
    }

    #[test]
    fn unary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, VECN|SCALAR of NUMERIC);
        expect_ok(&builtin, &arena, &[vec3(Scalar::F64)], vec3(Scalar::F64));
        expect_err(&builtin, &arena, &[vec3(Scalar::BOOL)]);
    }

    #[test]
    fn unary_vec_or_scalar_numeric_matrix() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, VECN|SCALAR of NUMERIC);
        let mat3x3f = ir::TypeInner::Matrix {
            columns: ir::VectorSize::Tri,
            rows: ir::VectorSize::Tri,
            scalar: Scalar::F32,
        };
        expect_err(&builtin, &arena, &[mat3x3f]);
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        expect_ok(&builtin, &arena, &[scalar(Scalar::F32), scalar(Scalar::F32)], scalar(Scalar::F32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::ABSTRACT_FLOAT), scalar(Scalar::F32)], scalar(Scalar::F32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::F32), scalar(Scalar::ABSTRACT_INT)], scalar(Scalar::F32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::U32), scalar(Scalar::U32)], scalar(Scalar::U32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::U32), scalar(Scalar::ABSTRACT_INT)], scalar(Scalar::U32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::ABSTRACT_INT), scalar(Scalar::U32)], scalar(Scalar::U32));

        expect_err(&builtin, &arena, &[scalar(Scalar::BOOL), scalar(Scalar::BOOL)]);
        expect_err(&builtin, &arena, &[scalar(Scalar::F32), scalar(Scalar::F64)]);
        expect_err(&builtin, &arena, &[scalar(Scalar::F32), vec2(Scalar::F32)]);
        expect_err(&builtin, &arena, &[vec2(Scalar::F32), vec3(Scalar::F32)]);
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        expect_ok(&builtin, &arena, &[vec3(Scalar::F32), vec3(Scalar::F32)], vec3(Scalar::F32));

        expect_err(&builtin, &arena, &[vec2(Scalar::F32), vec3(Scalar::F32)]);
        expect_err(&builtin, &arena, &[vec3(Scalar::F32), vec3(Scalar::F64)]);
        expect_err(&builtin, &arena, &[scalar(Scalar::F32), vec3(Scalar::F32)]);
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_vector_abstract() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        expect_ok(&builtin, &arena, &[vec2(Scalar::ABSTRACT_INT), vec2(Scalar::U32)], vec2(Scalar::U32));
        expect_ok(&builtin, &arena, &[vec3(Scalar::ABSTRACT_INT), vec3(Scalar::F32)], vec3(Scalar::F32));
        expect_ok(&builtin, &arena, &[scalar(Scalar::ABSTRACT_FLOAT), scalar(Scalar::F32)], scalar(Scalar::F32));

        expect_err(&builtin, &arena, &[scalar(Scalar::ABSTRACT_FLOAT), scalar(Scalar::U32)]);
        expect_err(&builtin, &arena, &[scalar(Scalar::I32), scalar(Scalar::U32)]);
    }
}