// naga/front/wgsl/lower/construction.rs

use alloc::{
    boxed::Box,
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
use core::num::NonZeroU32;

use crate::common::wgsl::{TryToWgsl, TypeContext};
use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
use crate::front::wgsl::parse::ast;
use crate::front::wgsl::{Error, Result};
use crate::{Handle, Span};

16/// A cooked form of `ast::ConstructorType` that uses Naga types whenever
17/// possible.
18enum Constructor<T> {
19    /// A vector construction whose component type is inferred from the
20    /// argument: `vec3(1.0)`.
21    PartialVector { size: crate::VectorSize },
22
23    /// A matrix construction whose component type is inferred from the
24    /// argument: `mat2x2(1,2,3,4)`.
25    PartialMatrix {
26        columns: crate::VectorSize,
27        rows: crate::VectorSize,
28    },
29
30    /// An array whose component type and size are inferred from the arguments:
31    /// `array(3,4,5)`.
32    PartialArray,
33
34    /// A known Naga type.
35    ///
36    /// When we match on this type, we need to see the `TypeInner` here, but at
37    /// the point that we build this value we'll still need mutable access to
38    /// the module later. To avoid borrowing from the module, the type parameter
39    /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
40    /// version holding a tuple `(Handle<Type>, &TypeInner)`.
41    Type(T),
42}
43
44impl Constructor<Handle<crate::Type>> {
45    /// Return an equivalent `Constructor` value that includes borrowed
46    /// `TypeInner` values alongside any type handles.
47    ///
48    /// The returned form is more convenient to match on, since the patterns
49    /// can actually see what the handle refers to.
50    fn borrow_inner(
51        self,
52        module: &crate::Module,
53    ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
54        match self {
55            Constructor::PartialVector { size } => Constructor::PartialVector { size },
56            Constructor::PartialMatrix { columns, rows } => {
57                Constructor::PartialMatrix { columns, rows }
58            }
59            Constructor::PartialArray => Constructor::PartialArray,
60            Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
61        }
62    }
63}
64
65impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
66    fn to_error_string(&self, ctx: &ExpressionContext) -> String {
67        match *self {
68            Self::PartialVector { size } => {
69                format!("vec{}<?>", size as u32,)
70            }
71            Self::PartialMatrix { columns, rows } => {
72                format!("mat{}x{}<?>", columns as u32, rows as u32,)
73            }
74            Self::PartialArray => "array<?, ?>".to_string(),
75            Self::Type((handle, _inner)) => ctx.type_to_string(handle),
76        }
77    }
78}
79
80enum Components<'a> {
81    None,
82    One {
83        component: Handle<crate::Expression>,
84        span: Span,
85        ty_inner: &'a crate::TypeInner,
86    },
87    Many {
88        components: Vec<Handle<crate::Expression>>,
89        spans: Vec<Span>,
90    },
91}
92
93impl Components<'_> {
94    fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
95        match self {
96            Self::None => vec![],
97            Self::One { component, .. } => vec![component],
98            Self::Many { components, .. } => components,
99        }
100    }
101}
102
103impl<'source> Lowerer<'source, '_> {
104    /// Generate Naga IR for a type constructor expression.
105    ///
106    /// The `constructor` value represents the head of the constructor
107    /// expression, which is at least a hint of which type is being built; if
108    /// it's one of the `Partial` variants, we need to consider the argument
109    /// types as well.
110    ///
111    /// This is used for [`Construct`] expressions, but also for [`Call`]
112    /// expressions, once we've determined that the "callable" (in WGSL spec
113    /// terms) is actually a type.
114    ///
115    /// [`Construct`]: ast::Expression::Construct
116    /// [`Call`]: ast::Expression::Call
117    pub fn construct(
118        &mut self,
119        span: Span,
120        constructor: &ast::ConstructorType<'source>,
121        ty_span: Span,
122        components: &[Handle<ast::Expression<'source>>],
123        ctx: &mut ExpressionContext<'source, '_, '_>,
124    ) -> Result<'source, Handle<crate::Expression>> {
125        use crate::proc::TypeResolution as Tr;
126
127        let constructor_h = self.constructor(constructor, ctx)?;
128
129        let components = match *components {
130            [] => Components::None,
131            [component] => {
132                let span = ctx.ast_expressions.get_span(component);
133                let component = self.expression_for_abstract(component, ctx)?;
134                let ty_inner = super::resolve_inner!(ctx, component);
135
136                Components::One {
137                    component,
138                    span,
139                    ty_inner,
140                }
141            }
142            ref ast_components @ [_, _, ..] => {
143                let components = ast_components
144                    .iter()
145                    .map(|&expr| self.expression_for_abstract(expr, ctx))
146                    .collect::<Result<_>>()?;
147                let spans = ast_components
148                    .iter()
149                    .map(|&expr| ctx.ast_expressions.get_span(expr))
150                    .collect();
151
152                for &component in &components {
153                    ctx.grow_types(component)?;
154                }
155
156                Components::Many { components, spans }
157            }
158        };
159
160        // Even though we computed `constructor` above, wait until now to borrow
161        // a reference to the `TypeInner`, so that the component-handling code
162        // above can have mutable access to the type arena.
163        let constructor = constructor_h.borrow_inner(ctx.module);
164
165        let expr;
166        match (components, constructor) {
167            // Empty constructor
168            (Components::None, dst_ty) => match dst_ty {
169                Constructor::Type((result_ty, _)) => {
170                    expr = crate::Expression::ZeroValue(result_ty);
171                }
172                Constructor::PartialVector { size } => {
173                    // vec2(), vec3(), vec4() return vectors of abstractInts; the same
174                    // is not true of the similar constructors for matrices or arrays.
175                    // See https://www.w3.org/TR/WGSL/#vec2-builtin et seq.
176                    let result_ty = ctx.module.types.insert(
177                        crate::Type {
178                            name: None,
179                            inner: crate::TypeInner::Vector {
180                                size,
181                                scalar: crate::Scalar::ABSTRACT_INT,
182                            },
183                        },
184                        span,
185                    );
186                    expr = crate::Expression::ZeroValue(result_ty);
187                }
188                Constructor::PartialMatrix { .. } | Constructor::PartialArray => {
189                    // We have no arguments from which to infer the result type, so
190                    // partial constructors aren't acceptable here.
191                    return Err(Box::new(Error::TypeNotInferable(ty_span)));
192                }
193            },
194
195            // Scalar constructor & conversion (scalar -> scalar)
196            (
197                Components::One {
198                    component,
199                    ty_inner: &crate::TypeInner::Scalar { .. },
200                    ..
201                },
202                Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
203            ) => {
204                expr = crate::Expression::As {
205                    expr: component,
206                    kind: scalar.kind,
207                    convert: Some(scalar.width),
208                };
209            }
210
211            // Vector conversion (vector -> vector)
212            (
213                Components::One {
214                    component,
215                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
216                    ..
217                },
218                Constructor::Type((
219                    _,
220                    &crate::TypeInner::Vector {
221                        size: dst_size,
222                        scalar: dst_scalar,
223                    },
224                )),
225            ) if dst_size == src_size => {
226                expr = crate::Expression::As {
227                    expr: component,
228                    kind: dst_scalar.kind,
229                    convert: Some(dst_scalar.width),
230                };
231            }
232
233            // Vector conversion (vector -> vector) - partial
234            (
235                Components::One {
236                    component,
237                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
238                    ..
239                },
240                Constructor::PartialVector { size: dst_size },
241            ) if dst_size == src_size => {
242                // This is a trivial conversion: the sizes match, and a Partial
243                // constructor doesn't specify a scalar type, so nothing can
244                // possibly happen.
245                return Ok(component);
246            }
247
248            // Matrix conversion (matrix -> matrix)
249            (
250                Components::One {
251                    component,
252                    ty_inner:
253                        &crate::TypeInner::Matrix {
254                            columns: src_columns,
255                            rows: src_rows,
256                            ..
257                        },
258                    ..
259                },
260                Constructor::Type((
261                    _,
262                    &crate::TypeInner::Matrix {
263                        columns: dst_columns,
264                        rows: dst_rows,
265                        scalar: dst_scalar,
266                    },
267                )),
268            ) if dst_columns == src_columns && dst_rows == src_rows => {
269                expr = crate::Expression::As {
270                    expr: component,
271                    kind: dst_scalar.kind,
272                    convert: Some(dst_scalar.width),
273                };
274            }
275
276            // Matrix conversion (matrix -> matrix) - partial
277            (
278                Components::One {
279                    component,
280                    ty_inner:
281                        &crate::TypeInner::Matrix {
282                            columns: src_columns,
283                            rows: src_rows,
284                            ..
285                        },
286                    ..
287                },
288                Constructor::PartialMatrix {
289                    columns: dst_columns,
290                    rows: dst_rows,
291                },
292            ) if dst_columns == src_columns && dst_rows == src_rows => {
293                // This is a trivial conversion: the sizes match, and a Partial
294                // constructor doesn't specify a scalar type, so nothing can
295                // possibly happen.
296                return Ok(component);
297            }
298
299            // Vector constructor (splat) - infer type
300            (
301                Components::One {
302                    component,
303                    ty_inner: &crate::TypeInner::Scalar { .. },
304                    ..
305                },
306                Constructor::PartialVector { size },
307            ) => {
308                expr = crate::Expression::Splat {
309                    size,
310                    value: component,
311                };
312            }
313
314            // Vector constructor (splat)
315            (
316                Components::One {
317                    mut component,
318                    ty_inner: &crate::TypeInner::Scalar(component_scalar),
319                    span,
320                },
321                Constructor::Type((
322                    type_handle,
323                    &crate::TypeInner::Vector {
324                        size,
325                        scalar: vec_scalar,
326                    },
327                )),
328            ) => {
329                // Splat only allows automatic conversions of the component's scalar.
330                if !component_scalar.automatically_converts_to(vec_scalar) {
331                    let component_ty = &ctx.typifier()[component];
332                    let arg_ty = ctx.type_resolution_to_string(component_ty);
333                    return Err(Box::new(Error::WrongArgumentType {
334                        function: ctx.type_to_string(type_handle),
335                        call_span: ty_span,
336                        arg_span: span,
337                        arg_index: 0,
338                        arg_ty,
339                        allowed: vec![vec_scalar.to_wgsl_for_diagnostics()],
340                    }));
341                }
342                ctx.convert_slice_to_common_leaf_scalar(
343                    core::slice::from_mut(&mut component),
344                    vec_scalar,
345                )?;
346                expr = crate::Expression::Splat {
347                    size,
348                    value: component,
349                };
350            }
351
352            // Vector constructor (by elements), partial
353            (
354                Components::Many {
355                    mut components,
356                    spans,
357                },
358                Constructor::PartialVector { size },
359            ) => {
360                let consensus_scalar =
361                    ctx.automatic_conversion_consensus(&components)
362                        .map_err(|index| {
363                            Error::InvalidConstructorComponentType(spans[index], index as i32)
364                        })?;
365                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
366                let inner = consensus_scalar.to_inner_vector(size);
367                let ty = ctx.ensure_type_exists(inner);
368                expr = crate::Expression::Compose { ty, components };
369            }
370
371            // Vector constructor (by elements), full type given
372            (
373                Components::Many { mut components, .. },
374                Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
375            ) => {
376                ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
377                expr = crate::Expression::Compose { ty, components };
378            }
379
380            // Matrix constructor (by elements), partial
381            (
382                Components::Many {
383                    mut components,
384                    spans,
385                },
386                Constructor::PartialMatrix { columns, rows },
387            ) if components.len() == columns as usize * rows as usize => {
388                let consensus_scalar =
389                    ctx.automatic_conversion_consensus(&components)
390                        .map_err(|index| {
391                            Error::InvalidConstructorComponentType(spans[index], index as i32)
392                        })?;
393                // We actually only accept floating-point elements.
394                let consensus_scalar = consensus_scalar
395                    .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
396                    .unwrap_or(consensus_scalar);
397                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
398                let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));
399
400                let components = components
401                    .chunks(rows as usize)
402                    .map(|vec_components| {
403                        ctx.append_expression(
404                            crate::Expression::Compose {
405                                ty: vec_ty,
406                                components: Vec::from(vec_components),
407                            },
408                            Default::default(),
409                        )
410                    })
411                    .collect::<Result<Vec<_>>>()?;
412
413                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
414                    columns,
415                    rows,
416                    scalar: consensus_scalar,
417                });
418                expr = crate::Expression::Compose { ty, components };
419            }
420
421            // Matrix constructor (by elements), type given
422            (
423                Components::Many { mut components, .. },
424                Constructor::Type((
425                    _,
426                    &crate::TypeInner::Matrix {
427                        columns,
428                        rows,
429                        scalar,
430                    },
431                )),
432            ) if components.len() == columns as usize * rows as usize => {
433                let element = Tr::Value(crate::TypeInner::Scalar(scalar));
434                ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
435                let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));
436
437                let components = components
438                    .chunks(rows as usize)
439                    .map(|vec_components| {
440                        ctx.append_expression(
441                            crate::Expression::Compose {
442                                ty: vec_ty,
443                                components: Vec::from(vec_components),
444                            },
445                            Default::default(),
446                        )
447                    })
448                    .collect::<Result<Vec<_>>>()?;
449
450                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
451                    columns,
452                    rows,
453                    scalar,
454                });
455                expr = crate::Expression::Compose { ty, components };
456            }
457
458            // Matrix constructor (by columns), partial
459            (
460                Components::Many {
461                    mut components,
462                    spans,
463                },
464                Constructor::PartialMatrix { columns, rows },
465            ) => {
466                let consensus_scalar =
467                    ctx.automatic_conversion_consensus(&components)
468                        .map_err(|index| {
469                            Error::InvalidConstructorComponentType(spans[index], index as i32)
470                        })?;
471                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
472                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
473                    columns,
474                    rows,
475                    scalar: consensus_scalar,
476                });
477                expr = crate::Expression::Compose { ty, components };
478            }
479
480            // Matrix constructor (by columns), type given
481            (
482                Components::Many { mut components, .. },
483                Constructor::Type((
484                    ty,
485                    &crate::TypeInner::Matrix {
486                        columns: _,
487                        rows,
488                        scalar,
489                    },
490                )),
491            ) => {
492                let component_ty = crate::TypeInner::Vector { size: rows, scalar };
493                ctx.try_automatic_conversions_slice(
494                    &mut components,
495                    &Tr::Value(component_ty),
496                    ty_span,
497                )?;
498                expr = crate::Expression::Compose { ty, components };
499            }
500
501            // Array constructor - infer type
502            (components, Constructor::PartialArray) => {
503                let mut components = components.into_components_vec();
504                if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
505                    // Note that this will *not* necessarily convert all the
506                    // components to the same type! The `automatic_conversion_consensus`
507                    // method only considers the parameters' leaf scalar
508                    // types; the parameters themselves could be any mix of
509                    // vectors, matrices, and scalars.
510                    //
511                    // But *if* it is possible for this array construction
512                    // expression to be well-typed at all, then all the
513                    // parameters must have the same type constructors (vec,
514                    // matrix, scalar) applied to their leaf scalars, so
515                    // reconciling their scalars is always the right thing to
516                    // do. And if this array construction is not well-typed,
517                    // these conversions will not make it so, and we can let
518                    // validation catch the error.
519                    ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
520                } else {
521                    // There's no consensus scalar. Emit the `Compose`
522                    // expression anyway, and let validation catch the problem.
523                }
524
525                let base = ctx.register_type(components[0])?;
526
527                let inner = crate::TypeInner::Array {
528                    base,
529                    size: crate::ArraySize::Constant(
530                        NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
531                    ),
532                    stride: {
533                        ctx.layouter.update(ctx.module.to_ctx()).unwrap();
534                        ctx.layouter[base].to_stride()
535                    },
536                };
537                let ty = ctx.ensure_type_exists(inner);
538
539                expr = crate::Expression::Compose { ty, components };
540            }
541
542            // Array constructor, explicit type
543            (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
544                let mut components = components.into_components_vec();
545                ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
546                expr = crate::Expression::Compose { ty, components };
547            }
548
549            // Struct constructor
550            (
551                components,
552                Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
553            ) => {
554                let mut components = components.into_components_vec();
555                let struct_ty_span = ctx.module.types.get_span(ty);
556
557                // Make a vector of the members' type handles in advance, to
558                // avoid borrowing `members` from `ctx` while we generate
559                // new code.
560                let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();
561
562                for (component, &ty) in components.iter_mut().zip(&members) {
563                    *component =
564                        ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
565                }
566                expr = crate::Expression::Compose { ty, components };
567            }
568
569            // ERRORS
570
571            // Bad conversion (type cast)
572            (
573                Components::One {
574                    span, component, ..
575                },
576                constructor,
577            ) => {
578                let component_ty = &ctx.typifier()[component];
579                let from_type = ctx.type_resolution_to_string(component_ty);
580                return Err(Box::new(Error::BadTypeCast {
581                    span,
582                    from_type,
583                    to_type: constructor.to_error_string(ctx),
584                }));
585            }
586
587            // Too many parameters for scalar constructor
588            (
589                Components::Many { spans, .. },
590                Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
591            ) => {
592                let span = spans[1].until(spans.last().unwrap());
593                return Err(Box::new(Error::UnexpectedComponents(span)));
594            }
595
596            // Other types can't be constructed
597            _ => return Err(Box::new(Error::TypeNotConstructible(ty_span))),
598        }
599
600        let expr = ctx.append_expression(expr, span)?;
601        Ok(expr)
602    }
603
604    /// Build a [`Constructor`] for a WGSL construction expression.
605    ///
606    /// If `constructor` conveys enough information to determine which Naga [`Type`]
607    /// we're actually building (i.e., it's not a partial constructor), then
608    /// ensure the `Type` exists in [`ctx.module`], and return
609    /// [`Constructor::Type`].
610    ///
611    /// Otherwise, return the [`Constructor`] partial variant corresponding to
612    /// `constructor`.
613    ///
614    /// [`Type`]: crate::Type
615    /// [`ctx.module`]: ExpressionContext::module
616    fn constructor<'out>(
617        &mut self,
618        constructor: &ast::ConstructorType<'source>,
619        ctx: &mut ExpressionContext<'source, '_, 'out>,
620    ) -> Result<'source, Constructor<Handle<crate::Type>>> {
621        let handle = match *constructor {
622            ast::ConstructorType::Scalar(scalar) => {
623                let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
624                Constructor::Type(ty)
625            }
626            ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
627            ast::ConstructorType::Vector { size, ty, ty_span } => {
628                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
629                let scalar = match ctx.module.types[ty].inner {
630                    crate::TypeInner::Scalar(sc) => sc,
631                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
632                };
633                let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, scalar });
634                Constructor::Type(ty)
635            }
636            ast::ConstructorType::PartialMatrix { columns, rows } => {
637                Constructor::PartialMatrix { columns, rows }
638            }
639            ast::ConstructorType::Matrix {
640                rows,
641                columns,
642                ty,
643                ty_span,
644            } => {
645                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
646                let scalar = match ctx.module.types[ty].inner {
647                    crate::TypeInner::Scalar(sc) => sc,
648                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
649                };
650                let ty = match scalar.kind {
651                    crate::ScalarKind::Float => ctx.ensure_type_exists(crate::TypeInner::Matrix {
652                        columns,
653                        rows,
654                        scalar,
655                    }),
656                    _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))),
657                };
658                Constructor::Type(ty)
659            }
660            ast::ConstructorType::PartialCooperativeMatrix { .. } => {
661                return Err(Box::new(Error::UnderspecifiedCooperativeMatrix));
662            }
663            ast::ConstructorType::CooperativeMatrix {
664                rows,
665                columns,
666                ty,
667                ty_span,
668                role,
669            } => {
670                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
671                let scalar = match ctx.module.types[ty].inner {
672                    crate::TypeInner::Scalar(s) => s,
673                    _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))),
674                };
675                let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix {
676                    columns,
677                    rows,
678                    scalar,
679                    role,
680                });
681                Constructor::Type(ty)
682            }
683            ast::ConstructorType::PartialArray => Constructor::PartialArray,
684            ast::ConstructorType::Array { base, size } => {
685                let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
686                let size = self.array_size(size, &mut ctx.as_const())?;
687
688                ctx.layouter.update(ctx.module.to_ctx()).unwrap();
689                let stride = ctx.layouter[base].to_stride();
690
691                let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
692                Constructor::Type(ty)
693            }
694            ast::ConstructorType::Type(ty) => Constructor::Type(ty),
695        };
696
697        Ok(handle)
698    }
699}