naga/proc/overloads/
regular.rs

/*! A representation for highly regular overload sets common in Naga IR.

Many Naga builtin functions' overload sets have a highly regular
structure. For example, many arithmetic functions can be applied to
any floating-point type, or any vector thereof. This module defines a
handful of types that represent such overload sets simply and
efficiently.

*/
10
11use crate::common::{DiagnosticDebug, ForDebugWithTypes};
12use crate::ir;
13use crate::proc::overloads::constructor_set::{ConstructorSet, ConstructorSize};
14use crate::proc::overloads::rule::{Conclusion, Rule};
15use crate::proc::overloads::scalar_set::ScalarSet;
16use crate::proc::overloads::OverloadSet;
17use crate::proc::{GlobalCtx, TypeResolution};
18use crate::UniqueArena;
19
20use alloc::vec::Vec;
21use core::fmt;
22
23/// Overload sets represented as sets of scalars and constructors.
24///
25/// This type represents an [`OverloadSet`] using a bitset of scalar
26/// types and a bitset of type constructors that might be applied to
27/// those scalars. The overload set contains a rule for every possible
28/// combination of scalars and constructors, essentially the cartesian
29/// product of the two sets.
30///
31/// For example, if the arity is 2, set of scalars is { AbstractFloat,
32/// `f32` }, and the set of constructors is { `vec2`, `vec3` }, then
33/// that represents the set of overloads:
34///
35/// - (`vec2<AbstractFloat>`, `vec2<AbstractFloat>`) -> `vec2<AbstractFloat>`
36/// - (`vec2<f32>`, `vec2<f32>`) -> `vec2<f32>`
37/// - (`vec3<AbstractFloat>`, `vec3<AbstractFloat>`) -> `vec3<AbstractFloat>`
38/// - (`vec3<f32>`, `vec3<f32>`) -> `vec3<f32>`
39///
40/// The `conclude` value says how to determine the return type from
41/// the argument type.
42///
43/// Restrictions:
44///
45/// - All overloads must take the same number of arguments.
46///
47/// - For any given overload, all its arguments must have the same
48///   type.
49#[derive(Clone)]
50pub(in crate::proc::overloads) struct Regular {
51    /// The number of arguments in the rules.
52    pub arity: usize,
53
54    /// The set of type constructors to apply.
55    pub constructors: ConstructorSet,
56
57    /// The set of scalars to apply them to.
58    pub scalars: ScalarSet,
59
60    /// How to determine a member rule's return type given the type of
61    /// its arguments.
62    pub conclude: ConclusionRule,
63}
64
65impl Regular {
66    pub(in crate::proc::overloads) const EMPTY: Regular = Regular {
67        arity: 0,
68        constructors: ConstructorSet::empty(),
69        scalars: ScalarSet::empty(),
70        conclude: ConclusionRule::ArgumentType,
71    };
72
73    /// Return an iterator over all the argument types allowed by `self`.
74    ///
75    /// Return an iterator that produces, for each overload in `self`, the
76    /// constructor and scalar of its argument types and return type.
77    ///
78    /// A [`Regular`] value can only represent overload sets where, in
79    /// each overload, all the arguments have the same type, and the
80    /// return type is always going to be a determined by the argument
81    /// types, so giving the constructor and scalar is sufficient to
82    /// characterize the entire rule.
83    fn members(&self) -> impl Iterator<Item = (ConstructorSize, ir::Scalar)> {
84        let scalars = self.scalars;
85        self.constructors.members().flat_map(move |constructor| {
86            let size = constructor.size();
87            // Technically, we don't need the "most general" `TypeInner` here,
88            // but since `ScalarSet::members` only produces singletons anyway,
89            // the effect is the same.
90            scalars
91                .members()
92                .map(move |singleton| (size, singleton.most_general_scalar()))
93        })
94    }
95
96    fn rules(&self) -> impl Iterator<Item = Rule> {
97        let arity = self.arity;
98        let conclude = self.conclude;
99        self.members()
100            .map(move |(size, scalar)| make_rule(arity, size, scalar, conclude))
101    }
102}
103
104impl OverloadSet for Regular {
105    fn is_empty(&self) -> bool {
106        self.constructors.is_empty() || self.scalars.is_empty()
107    }
108
109    fn min_arguments(&self) -> usize {
110        assert!(!self.is_empty());
111        self.arity
112    }
113
114    fn max_arguments(&self) -> usize {
115        assert!(!self.is_empty());
116        self.arity
117    }
118
119    fn arg(&self, i: usize, ty: &ir::TypeInner, types: &UniqueArena<ir::Type>) -> Self {
120        if i >= self.arity {
121            return Self::EMPTY;
122        }
123
124        let constructor = ConstructorSet::singleton(ty);
125
126        let scalars = match ty.scalar_for_conversions(types) {
127            Some(ty_scalar) => ScalarSet::convertible_from(ty_scalar),
128            None => ScalarSet::empty(),
129        };
130
131        Self {
132            arity: self.arity,
133
134            // Constrain all member rules' constructors to match `ty`'s.
135            constructors: self.constructors & constructor,
136
137            // Constrain all member rules' arguments to be something
138            // that `ty` can be converted to.
139            scalars: self.scalars & scalars,
140
141            conclude: self.conclude,
142        }
143    }
144
145    fn concrete_only(self, _types: &UniqueArena<ir::Type>) -> Self {
146        Self {
147            scalars: self.scalars & ScalarSet::CONCRETE,
148            ..self
149        }
150    }
151
152    fn most_preferred(&self) -> Rule {
153        assert!(!self.is_empty());
154
155        // If there is more than one constructor allowed, then we must
156        // not have had any arguments supplied at all. In any case, we
157        // don't have any unambiguously preferred candidate.
158        assert!(self.constructors.is_singleton());
159
160        let size = self.constructors.size();
161        let scalar = self.scalars.most_general_scalar();
162        make_rule(self.arity, size, scalar, self.conclude)
163    }
164
165    fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
166        self.rules().collect()
167    }
168
169    fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
170        if i >= self.arity {
171            return Vec::new();
172        }
173        self.members()
174            .map(|(size, scalar)| TypeResolution::Value(size.to_inner(scalar)))
175            .collect()
176    }
177
178    fn for_debug(&self, types: &UniqueArena<ir::Type>) -> impl fmt::Debug {
179        DiagnosticDebug((self, types))
180    }
181}
182
183/// Construct a [`Regular`] member [`Rule`] for the given arity and type.
184///
185/// [`Regular`] can only represent rules where all the argument types and the
186/// return type are the same, so just knowing `arity` and `inner` is sufficient.
187///
188/// [`Rule`]: crate::proc::overloads::Rule
189fn make_rule(
190    arity: usize,
191    size: ConstructorSize,
192    scalar: ir::Scalar,
193    conclusion_rule: ConclusionRule,
194) -> Rule {
195    let inner = size.to_inner(scalar);
196    let arg = TypeResolution::Value(inner.clone());
197    Rule {
198        arguments: core::iter::repeat_n(arg.clone(), arity).collect(),
199        conclusion: conclusion_rule.conclude(size, scalar),
200    }
201}
202
203/// Conclusion-computing rules.
204#[derive(Clone, Copy, Debug)]
205#[repr(u8)]
206pub(in crate::proc::overloads) enum ConclusionRule {
207    ArgumentType,
208    Scalar,
209    Frexp,
210    Modf,
211    U32,
212    I32,
213    Vec2F,
214    Vec4F,
215    Vec4I,
216    Vec4U,
217}
218
219impl ConclusionRule {
220    fn conclude(self, size: ConstructorSize, scalar: ir::Scalar) -> Conclusion {
221        match self {
222            Self::ArgumentType => Conclusion::Value(size.to_inner(scalar)),
223            Self::Scalar => Conclusion::Value(ir::TypeInner::Scalar(scalar)),
224            Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
225            Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
226            Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
227            Self::I32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::I32)),
228            Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
229                size: ir::VectorSize::Bi,
230                scalar: ir::Scalar::F32,
231            }),
232            Self::Vec4F => Conclusion::Value(ir::TypeInner::Vector {
233                size: ir::VectorSize::Quad,
234                scalar: ir::Scalar::F32,
235            }),
236            Self::Vec4I => Conclusion::Value(ir::TypeInner::Vector {
237                size: ir::VectorSize::Quad,
238                scalar: ir::Scalar::I32,
239            }),
240            Self::Vec4U => Conclusion::Value(ir::TypeInner::Vector {
241                size: ir::VectorSize::Quad,
242                scalar: ir::Scalar::U32,
243            }),
244        }
245    }
246}
247
248impl fmt::Debug for DiagnosticDebug<(&Regular, &UniqueArena<ir::Type>)> {
249    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250        let (regular, types) = self.0;
251        let rules: Vec<Rule> = regular.rules().collect();
252        f.debug_struct("List")
253            .field("rules", &rules.for_debug(types))
254            .field("conclude", &regular.conclude)
255            .finish()
256    }
257}
258
259impl ForDebugWithTypes for &Regular {}
260
261impl fmt::Debug for DiagnosticDebug<(&[Rule], &UniqueArena<ir::Type>)> {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        let (rules, types) = self.0;
264        f.debug_list()
265            .entries(rules.iter().map(|rule| rule.for_debug(types)))
266            .finish()
267    }
268}
269
270impl ForDebugWithTypes for &[Rule] {}
271
272/// Construct a [`Regular`] [`OverloadSet`].
273///
274/// Examples:
275///
276/// - `regular!(2, SCALAR|VECN of FLOAT)`: An overload set whose rules take two
277///   arguments of the same type: a floating-point scalar (possibly abstract) or
278///   a vector of such. The return type is the same as the argument type.
279///
280/// - `regular!(1, VECN of FLOAT -> Scalar)`: An overload set whose rules take
281///   one argument that is a vector of floats, and whose return type is the leaf
282///   scalar type of the argument type.
283///
284/// The constructor values (before the `<` angle brackets `>`) are
285/// constants from [`ConstructorSet`].
286///
287/// The scalar values (inside the `<` angle brackets `>`) are
288/// constants from [`ScalarSet`].
289///
290/// When a return type identifier is given, it is treated as a variant
291/// of the the [`ConclusionRule`] enum.
292macro_rules! regular {
293    // regular!(ARITY, CONSTRUCTOR of SCALAR)
294    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|*) => {
295        {
296            use $crate::proc::overloads;
297            use overloads::constructor_set::constructor_set;
298            use overloads::regular::{Regular, ConclusionRule};
299            use overloads::scalar_set::scalar_set;
300            Regular {
301                arity: $arity,
302                constructors: constructor_set!( $( $constr )|* ),
303                scalars: scalar_set!( $( $scalar )|* ),
304                conclude: ConclusionRule::ArgumentType,
305            }
306        }
307    };
308
309    // regular!(ARITY, CONSTRUCTOR of SCALAR -> CONCLUSION_RULE)
310    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|* -> $conclude:ident) => {
311        {
312            use $crate::proc::overloads;
313            use overloads::constructor_set::constructor_set;
314            use overloads::regular::{Regular, ConclusionRule};
315            use overloads::scalar_set::scalar_set;
316            Regular {
317                arity: $arity,
318                constructors:constructor_set!( $( $constr )|* ),
319                scalars: scalar_set!( $( $scalar )|* ),
320                conclude: ConclusionRule::$conclude,
321            }
322        }
323    };
324}
325
326pub(in crate::proc::overloads) use regular;
327
#[cfg(test)]
mod test {
    use super::*;
    use crate::ir;

    /// A scalar type with the given leaf scalar.
    const fn scalar(scalar: ir::Scalar) -> ir::TypeInner {
        ir::TypeInner::Scalar(scalar)
    }

    /// A vector type of the given size and leaf scalar.
    const fn vector(size: ir::VectorSize, scalar: ir::Scalar) -> ir::TypeInner {
        ir::TypeInner::Vector { size, scalar }
    }

    /// A two-component vector type.
    const fn vec2(scalar: ir::Scalar) -> ir::TypeInner {
        vector(ir::VectorSize::Bi, scalar)
    }

    /// A three-component vector type.
    const fn vec3(scalar: ir::Scalar) -> ir::TypeInner {
        vector(ir::VectorSize::Tri, scalar)
    }

    /// Assert that `set` is non-empty and that the conclusion of its most
    /// preferred rule resolves to a type equivalent to `expected`.
    #[track_caller]
    fn check_return_type(set: &Regular, expected: &ir::TypeInner, arena: &UniqueArena<ir::Type>) {
        assert!(!set.is_empty());

        let special_types = ir::SpecialTypes::default();
        let resolution = set
            .most_preferred()
            .conclusion
            .into_resolution(&special_types)
            .expect("special types should have been pre-registered");
        let actual = resolution.inner_with(arena);

        assert!(
            actual.non_struct_equivalent(expected, arena),
            "Expected {:?}, got {:?}",
            expected.for_debug(arena),
            actual.for_debug(arena),
        );
    }

    /// Supply `first` and `second` as arguments 0 and 1 of `builtin`.
    fn apply2(
        builtin: &Regular,
        first: &ir::TypeInner,
        second: &ir::TypeInner,
        arena: &UniqueArena<ir::Type>,
    ) -> Regular {
        builtin.arg(0, first, arena).arg(1, second, arena)
    }

    /// Assert that `builtin` accepts `(first, second)`, returning `expected`.
    #[track_caller]
    fn check_accepts(
        builtin: &Regular,
        first: &ir::TypeInner,
        second: &ir::TypeInner,
        expected: &ir::TypeInner,
        arena: &UniqueArena<ir::Type>,
    ) {
        check_return_type(&apply2(builtin, first, second, arena), expected, arena);
    }

    /// Assert that `builtin` has no overload accepting `(first, second)`.
    #[track_caller]
    fn check_rejects(
        builtin: &Regular,
        first: &ir::TypeInner,
        second: &ir::TypeInner,
        arena: &UniqueArena<ir::Type>,
    ) {
        assert!(apply2(builtin, first, second, arena).is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, SCALAR of NUMERIC);

        let accepted = builtin.arg(0, &scalar(ir::Scalar::U32), &arena);
        check_return_type(&accepted, &scalar(ir::Scalar::U32), &arena);

        assert!(builtin.arg(0, &scalar(ir::Scalar::BOOL), &arena).is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, VECN|SCALAR of NUMERIC);

        let accepted = builtin.arg(0, &vec3(ir::Scalar::F64), &arena);
        check_return_type(&accepted, &vec3(ir::Scalar::F64), &arena);

        assert!(builtin.arg(0, &vec3(ir::Scalar::BOOL), &arena).is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_matrix() {
        let arena = UniqueArena::default();
        let builtin = regular!(1, VECN|SCALAR of NUMERIC);

        let mat3x3 = ir::TypeInner::Matrix {
            columns: ir::VectorSize::Tri,
            rows: ir::VectorSize::Tri,
            scalar: ir::Scalar::F32,
        };
        assert!(builtin.arg(0, &mat3x3, &arena).is_empty());
    }

    #[test]
    fn binary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let f32_scalar = scalar(ir::Scalar::F32);
        let f64_scalar = scalar(ir::Scalar::F64);
        let u32_scalar = scalar(ir::Scalar::U32);
        let bool_scalar = scalar(ir::Scalar::BOOL);
        let abstract_int = scalar(ir::Scalar::ABSTRACT_INT);
        let abstract_float = scalar(ir::Scalar::ABSTRACT_FLOAT);

        check_accepts(&builtin, &f32_scalar, &f32_scalar, &f32_scalar, &arena);
        check_accepts(&builtin, &abstract_float, &f32_scalar, &f32_scalar, &arena);
        check_accepts(&builtin, &f32_scalar, &abstract_int, &f32_scalar, &arena);
        check_accepts(&builtin, &u32_scalar, &u32_scalar, &u32_scalar, &arena);
        check_accepts(&builtin, &u32_scalar, &abstract_int, &u32_scalar, &arena);
        check_accepts(&builtin, &abstract_int, &u32_scalar, &u32_scalar, &arena);

        // Not numeric.
        check_rejects(&builtin, &bool_scalar, &bool_scalar, &arena);

        // Different floating-point types.
        check_rejects(&builtin, &f32_scalar, &f64_scalar, &arena);

        // Different constructor.
        check_rejects(&builtin, &f32_scalar, &vec2(ir::Scalar::F32), &arena);

        // Different vector size.
        check_rejects(&builtin, &vec2(ir::Scalar::F32), &vec3(ir::Scalar::F32), &arena);
    }

    #[test]
    fn binary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let vec3_f32 = vec3(ir::Scalar::F32);

        check_accepts(&builtin, &vec3_f32, &vec3_f32, &vec3_f32, &arena);

        // Different vector sizes.
        check_rejects(&builtin, &vec2(ir::Scalar::F32), &vec3_f32, &arena);

        // Different vector scalars.
        check_rejects(&builtin, &vec3_f32, &vec3(ir::Scalar::F64), &arena);

        // Mix of vectors and scalars.
        check_rejects(&builtin, &scalar(ir::Scalar::F32), &vec3_f32, &arena);
    }

    #[test]
    fn binary_vec_or_scalar_numeric_vector_abstract() {
        let arena = UniqueArena::default();
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let vec2_u32 = vec2(ir::Scalar::U32);
        let vec3_f32 = vec3(ir::Scalar::F32);
        let f32_scalar = scalar(ir::Scalar::F32);
        let u32_scalar = scalar(ir::Scalar::U32);
        let abstract_float = scalar(ir::Scalar::ABSTRACT_FLOAT);

        check_accepts(&builtin, &vec2(ir::Scalar::ABSTRACT_INT), &vec2_u32, &vec2_u32, &arena);
        check_accepts(&builtin, &vec3(ir::Scalar::ABSTRACT_INT), &vec3_f32, &vec3_f32, &arena);
        check_accepts(&builtin, &abstract_float, &f32_scalar, &f32_scalar, &arena);

        check_rejects(&builtin, &abstract_float, &u32_scalar, &arena);
        check_rejects(&builtin, &scalar(ir::Scalar::I32), &u32_scalar, &arena);
    }
}