naga/front/wgsl/lower/construction.rs

1use alloc::{
2    boxed::Box,
3    format,
4    string::{String, ToString},
5    vec,
6    vec::Vec,
7};
8use core::num::NonZeroU32;
9
10use crate::common::wgsl::TypeContext;
11use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
12use crate::front::wgsl::parse::ast;
13use crate::front::wgsl::{Error, Result};
14use crate::{Handle, Span};
15
16/// A cooked form of `ast::ConstructorType` that uses Naga types whenever
17/// possible.
18enum Constructor<T> {
19    /// A vector construction whose component type is inferred from the
20    /// argument: `vec3(1.0)`.
21    PartialVector { size: crate::VectorSize },
22
23    /// A matrix construction whose component type is inferred from the
24    /// argument: `mat2x2(1,2,3,4)`.
25    PartialMatrix {
26        columns: crate::VectorSize,
27        rows: crate::VectorSize,
28    },
29
30    /// An array whose component type and size are inferred from the arguments:
31    /// `array(3,4,5)`.
32    PartialArray,
33
34    /// A known Naga type.
35    ///
36    /// When we match on this type, we need to see the `TypeInner` here, but at
37    /// the point that we build this value we'll still need mutable access to
38    /// the module later. To avoid borrowing from the module, the type parameter
39    /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
40    /// version holding a tuple `(Handle<Type>, &TypeInner)`.
41    Type(T),
42}
43
44impl Constructor<Handle<crate::Type>> {
45    /// Return an equivalent `Constructor` value that includes borrowed
46    /// `TypeInner` values alongside any type handles.
47    ///
48    /// The returned form is more convenient to match on, since the patterns
49    /// can actually see what the handle refers to.
50    fn borrow_inner(
51        self,
52        module: &crate::Module,
53    ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
54        match self {
55            Constructor::PartialVector { size } => Constructor::PartialVector { size },
56            Constructor::PartialMatrix { columns, rows } => {
57                Constructor::PartialMatrix { columns, rows }
58            }
59            Constructor::PartialArray => Constructor::PartialArray,
60            Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
61        }
62    }
63}
64
65impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
66    fn to_error_string(&self, ctx: &ExpressionContext) -> String {
67        match *self {
68            Self::PartialVector { size } => {
69                format!("vec{}<?>", size as u32,)
70            }
71            Self::PartialMatrix { columns, rows } => {
72                format!("mat{}x{}<?>", columns as u32, rows as u32,)
73            }
74            Self::PartialArray => "array<?, ?>".to_string(),
75            Self::Type((handle, _inner)) => ctx.type_to_string(handle),
76        }
77    }
78}
79
80enum Components<'a> {
81    None,
82    One {
83        component: Handle<crate::Expression>,
84        span: Span,
85        ty_inner: &'a crate::TypeInner,
86    },
87    Many {
88        components: Vec<Handle<crate::Expression>>,
89        spans: Vec<Span>,
90    },
91}
92
93impl Components<'_> {
94    fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
95        match self {
96            Self::None => vec![],
97            Self::One { component, .. } => vec![component],
98            Self::Many { components, .. } => components,
99        }
100    }
101}
102
103impl<'source> Lowerer<'source, '_> {
104    /// Generate Naga IR for a type constructor expression.
105    ///
106    /// The `constructor` value represents the head of the constructor
107    /// expression, which is at least a hint of which type is being built; if
108    /// it's one of the `Partial` variants, we need to consider the argument
109    /// types as well.
110    ///
111    /// This is used for [`Construct`] expressions, but also for [`Call`]
112    /// expressions, once we've determined that the "callable" (in WGSL spec
113    /// terms) is actually a type.
114    ///
115    /// [`Construct`]: ast::Expression::Construct
116    /// [`Call`]: ast::Expression::Call
117    pub fn construct(
118        &mut self,
119        span: Span,
120        constructor: &ast::ConstructorType<'source>,
121        ty_span: Span,
122        components: &[Handle<ast::Expression<'source>>],
123        ctx: &mut ExpressionContext<'source, '_, '_>,
124    ) -> Result<'source, Handle<crate::Expression>> {
125        use crate::proc::TypeResolution as Tr;
126
127        let constructor_h = self.constructor(constructor, ctx)?;
128
129        let components = match *components {
130            [] => Components::None,
131            [component] => {
132                let span = ctx.ast_expressions.get_span(component);
133                let component = self.expression_for_abstract(component, ctx)?;
134                let ty_inner = super::resolve_inner!(ctx, component);
135
136                Components::One {
137                    component,
138                    span,
139                    ty_inner,
140                }
141            }
142            ref ast_components @ [_, _, ..] => {
143                let components = ast_components
144                    .iter()
145                    .map(|&expr| self.expression_for_abstract(expr, ctx))
146                    .collect::<Result<_>>()?;
147                let spans = ast_components
148                    .iter()
149                    .map(|&expr| ctx.ast_expressions.get_span(expr))
150                    .collect();
151
152                for &component in &components {
153                    ctx.grow_types(component)?;
154                }
155
156                Components::Many { components, spans }
157            }
158        };
159
160        // Even though we computed `constructor` above, wait until now to borrow
161        // a reference to the `TypeInner`, so that the component-handling code
162        // above can have mutable access to the type arena.
163        let constructor = constructor_h.borrow_inner(ctx.module);
164
165        let expr;
166        match (components, constructor) {
167            // Empty constructor
168            (Components::None, dst_ty) => match dst_ty {
169                Constructor::Type((result_ty, _)) => {
170                    expr = crate::Expression::ZeroValue(result_ty);
171                }
172                Constructor::PartialVector { size } => {
173                    // vec2(), vec3(), vec4() return vectors of abstractInts; the same
174                    // is not true of the similar constructors for matrices or arrays.
175                    // See https://www.w3.org/TR/WGSL/#vec2-builtin et seq.
176                    let result_ty = ctx.module.types.insert(
177                        crate::Type {
178                            name: None,
179                            inner: crate::TypeInner::Vector {
180                                size,
181                                scalar: crate::Scalar::ABSTRACT_INT,
182                            },
183                        },
184                        span,
185                    );
186                    expr = crate::Expression::ZeroValue(result_ty);
187                }
188                Constructor::PartialMatrix { .. } | Constructor::PartialArray => {
189                    // We have no arguments from which to infer the result type, so
190                    // partial constructors aren't acceptable here.
191                    return Err(Box::new(Error::TypeNotInferable(ty_span)));
192                }
193            },
194
195            // Scalar constructor & conversion (scalar -> scalar)
196            (
197                Components::One {
198                    component,
199                    ty_inner: &crate::TypeInner::Scalar { .. },
200                    ..
201                },
202                Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
203            ) => {
204                expr = crate::Expression::As {
205                    expr: component,
206                    kind: scalar.kind,
207                    convert: Some(scalar.width),
208                };
209            }
210
211            // Vector conversion (vector -> vector)
212            (
213                Components::One {
214                    component,
215                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
216                    ..
217                },
218                Constructor::Type((
219                    _,
220                    &crate::TypeInner::Vector {
221                        size: dst_size,
222                        scalar: dst_scalar,
223                    },
224                )),
225            ) if dst_size == src_size => {
226                expr = crate::Expression::As {
227                    expr: component,
228                    kind: dst_scalar.kind,
229                    convert: Some(dst_scalar.width),
230                };
231            }
232
233            // Vector conversion (vector -> vector) - partial
234            (
235                Components::One {
236                    component,
237                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
238                    ..
239                },
240                Constructor::PartialVector { size: dst_size },
241            ) if dst_size == src_size => {
242                // This is a trivial conversion: the sizes match, and a Partial
243                // constructor doesn't specify a scalar type, so nothing can
244                // possibly happen.
245                return Ok(component);
246            }
247
248            // Matrix conversion (matrix -> matrix)
249            (
250                Components::One {
251                    component,
252                    ty_inner:
253                        &crate::TypeInner::Matrix {
254                            columns: src_columns,
255                            rows: src_rows,
256                            ..
257                        },
258                    ..
259                },
260                Constructor::Type((
261                    _,
262                    &crate::TypeInner::Matrix {
263                        columns: dst_columns,
264                        rows: dst_rows,
265                        scalar: dst_scalar,
266                    },
267                )),
268            ) if dst_columns == src_columns && dst_rows == src_rows => {
269                expr = crate::Expression::As {
270                    expr: component,
271                    kind: dst_scalar.kind,
272                    convert: Some(dst_scalar.width),
273                };
274            }
275
276            // Matrix conversion (matrix -> matrix) - partial
277            (
278                Components::One {
279                    component,
280                    ty_inner:
281                        &crate::TypeInner::Matrix {
282                            columns: src_columns,
283                            rows: src_rows,
284                            ..
285                        },
286                    ..
287                },
288                Constructor::PartialMatrix {
289                    columns: dst_columns,
290                    rows: dst_rows,
291                },
292            ) if dst_columns == src_columns && dst_rows == src_rows => {
293                // This is a trivial conversion: the sizes match, and a Partial
294                // constructor doesn't specify a scalar type, so nothing can
295                // possibly happen.
296                return Ok(component);
297            }
298
299            // Vector constructor (splat) - infer type
300            (
301                Components::One {
302                    component,
303                    ty_inner: &crate::TypeInner::Scalar { .. },
304                    ..
305                },
306                Constructor::PartialVector { size },
307            ) => {
308                expr = crate::Expression::Splat {
309                    size,
310                    value: component,
311                };
312            }
313
314            // Vector constructor (splat)
315            (
316                Components::One {
317                    mut component,
318                    ty_inner: &crate::TypeInner::Scalar(_),
319                    ..
320                },
321                Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })),
322            ) => {
323                ctx.convert_slice_to_common_leaf_scalar(
324                    core::slice::from_mut(&mut component),
325                    scalar,
326                )?;
327                expr = crate::Expression::Splat {
328                    size,
329                    value: component,
330                };
331            }
332
333            // Vector constructor (by elements), partial
334            (
335                Components::Many {
336                    mut components,
337                    spans,
338                },
339                Constructor::PartialVector { size },
340            ) => {
341                let consensus_scalar =
342                    ctx.automatic_conversion_consensus(&components)
343                        .map_err(|index| {
344                            Error::InvalidConstructorComponentType(spans[index], index as i32)
345                        })?;
346                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
347                let inner = consensus_scalar.to_inner_vector(size);
348                let ty = ctx.ensure_type_exists(inner);
349                expr = crate::Expression::Compose { ty, components };
350            }
351
352            // Vector constructor (by elements), full type given
353            (
354                Components::Many { mut components, .. },
355                Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
356            ) => {
357                ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
358                expr = crate::Expression::Compose { ty, components };
359            }
360
361            // Matrix constructor (by elements), partial
362            (
363                Components::Many {
364                    mut components,
365                    spans,
366                },
367                Constructor::PartialMatrix { columns, rows },
368            ) if components.len() == columns as usize * rows as usize => {
369                let consensus_scalar =
370                    ctx.automatic_conversion_consensus(&components)
371                        .map_err(|index| {
372                            Error::InvalidConstructorComponentType(spans[index], index as i32)
373                        })?;
374                // We actually only accept floating-point elements.
375                let consensus_scalar = consensus_scalar
376                    .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
377                    .unwrap_or(consensus_scalar);
378                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
379                let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));
380
381                let components = components
382                    .chunks(rows as usize)
383                    .map(|vec_components| {
384                        ctx.append_expression(
385                            crate::Expression::Compose {
386                                ty: vec_ty,
387                                components: Vec::from(vec_components),
388                            },
389                            Default::default(),
390                        )
391                    })
392                    .collect::<Result<Vec<_>>>()?;
393
394                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
395                    columns,
396                    rows,
397                    scalar: consensus_scalar,
398                });
399                expr = crate::Expression::Compose { ty, components };
400            }
401
402            // Matrix constructor (by elements), type given
403            (
404                Components::Many { mut components, .. },
405                Constructor::Type((
406                    _,
407                    &crate::TypeInner::Matrix {
408                        columns,
409                        rows,
410                        scalar,
411                    },
412                )),
413            ) if components.len() == columns as usize * rows as usize => {
414                let element = Tr::Value(crate::TypeInner::Scalar(scalar));
415                ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
416                let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));
417
418                let components = components
419                    .chunks(rows as usize)
420                    .map(|vec_components| {
421                        ctx.append_expression(
422                            crate::Expression::Compose {
423                                ty: vec_ty,
424                                components: Vec::from(vec_components),
425                            },
426                            Default::default(),
427                        )
428                    })
429                    .collect::<Result<Vec<_>>>()?;
430
431                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
432                    columns,
433                    rows,
434                    scalar,
435                });
436                expr = crate::Expression::Compose { ty, components };
437            }
438
439            // Matrix constructor (by columns), partial
440            (
441                Components::Many {
442                    mut components,
443                    spans,
444                },
445                Constructor::PartialMatrix { columns, rows },
446            ) => {
447                let consensus_scalar =
448                    ctx.automatic_conversion_consensus(&components)
449                        .map_err(|index| {
450                            Error::InvalidConstructorComponentType(spans[index], index as i32)
451                        })?;
452                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
453                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
454                    columns,
455                    rows,
456                    scalar: consensus_scalar,
457                });
458                expr = crate::Expression::Compose { ty, components };
459            }
460
461            // Matrix constructor (by columns), type given
462            (
463                Components::Many { mut components, .. },
464                Constructor::Type((
465                    ty,
466                    &crate::TypeInner::Matrix {
467                        columns: _,
468                        rows,
469                        scalar,
470                    },
471                )),
472            ) => {
473                let component_ty = crate::TypeInner::Vector { size: rows, scalar };
474                ctx.try_automatic_conversions_slice(
475                    &mut components,
476                    &Tr::Value(component_ty),
477                    ty_span,
478                )?;
479                expr = crate::Expression::Compose { ty, components };
480            }
481
482            // Array constructor - infer type
483            (components, Constructor::PartialArray) => {
484                let mut components = components.into_components_vec();
485                if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
486                    // Note that this will *not* necessarily convert all the
487                    // components to the same type! The `automatic_conversion_consensus`
488                    // method only considers the parameters' leaf scalar
489                    // types; the parameters themselves could be any mix of
490                    // vectors, matrices, and scalars.
491                    //
492                    // But *if* it is possible for this array construction
493                    // expression to be well-typed at all, then all the
494                    // parameters must have the same type constructors (vec,
495                    // matrix, scalar) applied to their leaf scalars, so
496                    // reconciling their scalars is always the right thing to
497                    // do. And if this array construction is not well-typed,
498                    // these conversions will not make it so, and we can let
499                    // validation catch the error.
500                    ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
501                } else {
502                    // There's no consensus scalar. Emit the `Compose`
503                    // expression anyway, and let validation catch the problem.
504                }
505
506                let base = ctx.register_type(components[0])?;
507
508                let inner = crate::TypeInner::Array {
509                    base,
510                    size: crate::ArraySize::Constant(
511                        NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
512                    ),
513                    stride: {
514                        ctx.layouter.update(ctx.module.to_ctx()).unwrap();
515                        ctx.layouter[base].to_stride()
516                    },
517                };
518                let ty = ctx.ensure_type_exists(inner);
519
520                expr = crate::Expression::Compose { ty, components };
521            }
522
523            // Array constructor, explicit type
524            (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
525                let mut components = components.into_components_vec();
526                ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
527                expr = crate::Expression::Compose { ty, components };
528            }
529
530            // Struct constructor
531            (
532                components,
533                Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
534            ) => {
535                let mut components = components.into_components_vec();
536                let struct_ty_span = ctx.module.types.get_span(ty);
537
538                // Make a vector of the members' type handles in advance, to
539                // avoid borrowing `members` from `ctx` while we generate
540                // new code.
541                let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();
542
543                for (component, &ty) in components.iter_mut().zip(&members) {
544                    *component =
545                        ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
546                }
547                expr = crate::Expression::Compose { ty, components };
548            }
549
550            // ERRORS
551
552            // Bad conversion (type cast)
553            (
554                Components::One {
555                    span, component, ..
556                },
557                constructor,
558            ) => {
559                let component_ty = &ctx.typifier()[component];
560                let from_type = ctx.type_resolution_to_string(component_ty);
561                return Err(Box::new(Error::BadTypeCast {
562                    span,
563                    from_type,
564                    to_type: constructor.to_error_string(ctx),
565                }));
566            }
567
568            // Too many parameters for scalar constructor
569            (
570                Components::Many { spans, .. },
571                Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
572            ) => {
573                let span = spans[1].until(spans.last().unwrap());
574                return Err(Box::new(Error::UnexpectedComponents(span)));
575            }
576
577            // Other types can't be constructed
578            _ => return Err(Box::new(Error::TypeNotConstructible(ty_span))),
579        }
580
581        let expr = ctx.append_expression(expr, span)?;
582        Ok(expr)
583    }
584
585    /// Build a [`Constructor`] for a WGSL construction expression.
586    ///
587    /// If `constructor` conveys enough information to determine which Naga [`Type`]
588    /// we're actually building (i.e., it's not a partial constructor), then
589    /// ensure the `Type` exists in [`ctx.module`], and return
590    /// [`Constructor::Type`].
591    ///
592    /// Otherwise, return the [`Constructor`] partial variant corresponding to
593    /// `constructor`.
594    ///
595    /// [`Type`]: crate::Type
596    /// [`ctx.module`]: ExpressionContext::module
597    fn constructor<'out>(
598        &mut self,
599        constructor: &ast::ConstructorType<'source>,
600        ctx: &mut ExpressionContext<'source, '_, 'out>,
601    ) -> Result<'source, Constructor<Handle<crate::Type>>> {
602        let handle = match *constructor {
603            ast::ConstructorType::Scalar(scalar) => {
604                let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
605                Constructor::Type(ty)
606            }
607            ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
608            ast::ConstructorType::Vector { size, ty, ty_span } => {
609                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
610                let scalar = match ctx.module.types[ty].inner {
611                    crate::TypeInner::Scalar(sc) => sc,
612                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
613                };
614                let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, scalar });
615                Constructor::Type(ty)
616            }
617            ast::ConstructorType::PartialMatrix { columns, rows } => {
618                Constructor::PartialMatrix { columns, rows }
619            }
620            ast::ConstructorType::Matrix {
621                rows,
622                columns,
623                ty,
624                ty_span,
625            } => {
626                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
627                let scalar = match ctx.module.types[ty].inner {
628                    crate::TypeInner::Scalar(sc) => sc,
629                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
630                };
631                let ty = match scalar.kind {
632                    crate::ScalarKind::Float => ctx.ensure_type_exists(crate::TypeInner::Matrix {
633                        columns,
634                        rows,
635                        scalar,
636                    }),
637                    _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))),
638                };
639                Constructor::Type(ty)
640            }
641            ast::ConstructorType::PartialArray => Constructor::PartialArray,
642            ast::ConstructorType::Array { base, size } => {
643                let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
644                let size = self.array_size(size, &mut ctx.as_const())?;
645
646                ctx.layouter.update(ctx.module.to_ctx()).unwrap();
647                let stride = ctx.layouter[base].to_stride();
648
649                let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
650                Constructor::Type(ty)
651            }
652            ast::ConstructorType::Type(ty) => Constructor::Type(ty),
653        };
654
655        Ok(handle)
656    }
657}