1use alloc::{
2 boxed::Box,
3 format,
4 string::{String, ToString},
5 vec,
6 vec::Vec,
7};
8use core::num::NonZeroU32;
9
10use crate::common::wgsl::{TryToWgsl, TypeContext};
11use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
12use crate::front::wgsl::parse::ast;
13use crate::front::wgsl::{Error, Result};
14use crate::{Handle, Span};
15
/// The type named by a constructor expression, as resolved by
/// `Lowerer::constructor`.
///
/// The "partial" variants are constructors whose element type (and, for
/// arrays, length) was not written out; `Lowerer::construct` infers those
/// from the arguments.
///
/// The type parameter starts out as `Handle<crate::Type>`. Once the arguments
/// have been lowered, `borrow_inner` turns it into
/// `(Handle<crate::Type>, &crate::TypeInner)` so that `construct` can match on
/// the type's structure. Keeping a plain handle until then avoids holding a
/// borrow of the module's type arena while `ctx` is still needed mutably.
enum Constructor<T> {
    /// A vector whose size is known but whose element type is not (rendered
    /// as `vecN<?>` in error messages).
    PartialVector { size: crate::VectorSize },

    /// A matrix whose dimensions are known but whose element type is not
    /// (rendered as `matCxR<?>` in error messages).
    PartialMatrix {
        columns: crate::VectorSize,
        rows: crate::VectorSize,
    },

    /// An array with neither element type nor length given (rendered as
    /// `array<?, ?>` in error messages).
    PartialArray,

    /// A fully known type.
    Type(T),
}
43
44impl Constructor<Handle<crate::Type>> {
45 fn borrow_inner(
51 self,
52 module: &crate::Module,
53 ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
54 match self {
55 Constructor::PartialVector { size } => Constructor::PartialVector { size },
56 Constructor::PartialMatrix { columns, rows } => {
57 Constructor::PartialMatrix { columns, rows }
58 }
59 Constructor::PartialArray => Constructor::PartialArray,
60 Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
61 }
62 }
63}
64
65impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
66 fn to_error_string(&self, ctx: &ExpressionContext) -> String {
67 match *self {
68 Self::PartialVector { size } => {
69 format!("vec{}<?>", size as u32,)
70 }
71 Self::PartialMatrix { columns, rows } => {
72 format!("mat{}x{}<?>", columns as u32, rows as u32,)
73 }
74 Self::PartialArray => "array<?, ?>".to_string(),
75 Self::Type((handle, _inner)) => ctx.type_to_string(handle),
76 }
77 }
78}
79
/// The lowered arguments of a constructor expression, grouped by how many
/// there are (see `Lowerer::construct`).
enum Components<'a> {
    /// No arguments were given.
    None,
    /// Exactly one argument was given.
    One {
        /// The lowered argument expression.
        component: Handle<crate::Expression>,
        /// Source span of the argument.
        span: Span,
        /// The argument's resolved type, used to pick a conversion strategy.
        ty_inner: &'a crate::TypeInner,
    },
    /// Two or more arguments were given; `spans[i]` is the span of
    /// `components[i]`.
    Many {
        components: Vec<Handle<crate::Expression>>,
        spans: Vec<Span>,
    },
}
92
93impl Components<'_> {
94 fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
95 match self {
96 Self::None => vec![],
97 Self::One { component, .. } => vec![component],
98 Self::Many { components, .. } => components,
99 }
100 }
101}
102
impl<'source> Lowerer<'source, '_> {
    /// Lower a constructor expression (a type applied to a list of argument
    /// expressions) to a Naga expression.
    ///
    /// - `span` covers the whole constructor expression. The result is
    ///   appended with `ctx.append_expression` under this span.
    /// - `constructor` is the AST form of the type being constructed. It is
    ///   resolved with [`Self::constructor`].
    /// - `ty_span` covers the constructed type and is the span used for most
    ///   error reports and automatic-conversion diagnostics.
    /// - `components` are the argument expressions. They are lowered here
    ///   with `expression_for_abstract`.
    ///
    /// Returns the handle of the new expression. When a single vector or
    /// matrix argument is passed to a partial constructor of the same shape,
    /// the argument's own handle is returned and nothing new is appended.
    ///
    /// # Errors
    ///
    /// Besides errors propagated from the lowering and conversion helpers:
    ///
    /// - `TypeNotInferable` for a partial matrix or array with no arguments.
    /// - `WrongArgumentType` when a scalar being splatted into a typed vector
    ///   cannot be automatically converted to the vector's element type.
    /// - `InvalidConstructorComponentType` when
    ///   `automatic_conversion_consensus` rejects an argument of a partial
    ///   vector or matrix. The error points at that argument.
    /// - `BadTypeCast` for an unsupported single-argument construction.
    /// - `UnexpectedComponents` for a scalar constructor given several
    ///   arguments.
    /// - `TypeNotConstructible` for anything else.
    ///
    /// The arms of the main `match` are order-sensitive. The guarded matrix
    /// arms must precede their unguarded counterparts, and the
    /// `Components::One` catch-all must follow every specific single-argument
    /// arm.
    pub fn construct(
        &mut self,
        span: Span,
        constructor: &ast::ConstructorType<'source>,
        ty_span: Span,
        components: &[Handle<ast::Expression<'source>>],
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<crate::Expression>> {
        use crate::proc::TypeResolution as Tr;

        // Resolve the type first, but keep only a handle: the arguments below
        // still need `ctx` mutably.
        let constructor_h = self.constructor(constructor, ctx)?;

        // Lower the arguments, grouped by count.
        let components = match *components {
            [] => Components::None,
            [component] => {
                let span = ctx.ast_expressions.get_span(component);
                let component = self.expression_for_abstract(component, ctx)?;
                let ty_inner = super::resolve_inner!(ctx, component);

                Components::One {
                    component,
                    span,
                    ty_inner,
                }
            }
            ref ast_components @ [_, _, ..] => {
                let components = ast_components
                    .iter()
                    .map(|&expr| self.expression_for_abstract(expr, ctx))
                    .collect::<Result<_>>()?;
                let spans = ast_components
                    .iter()
                    .map(|&expr| ctx.ast_expressions.get_span(expr))
                    .collect();

                // Make sure every argument's type is resolved in the typifier
                // before the conversion helpers below inspect it.
                for &component in &components {
                    ctx.grow_types(component)?;
                }

                Components::Many { components, spans }
            }
        };

        // Now that the arguments are lowered, look at the type's structure.
        let constructor = constructor_h.borrow_inner(ctx.module);

        // Every arm below that does not return early assigns `expr`; it is
        // appended to the arena after the match.
        let expr;
        match (components, constructor) {
            // No arguments: the zero value of the type.
            (Components::None, dst_ty) => match dst_ty {
                Constructor::Type((result_ty, _)) => {
                    expr = crate::Expression::ZeroValue(result_ty);
                }
                // A partial vector with no arguments is a zero vector of
                // abstract integers.
                Constructor::PartialVector { size } => {
                    let result_ty = ctx.module.types.insert(
                        crate::Type {
                            name: None,
                            inner: crate::TypeInner::Vector {
                                size,
                                scalar: crate::Scalar::ABSTRACT_INT,
                            },
                        },
                        span,
                    );
                    expr = crate::Expression::ZeroValue(result_ty);
                }
                // Partial matrices and arrays have nothing to infer their
                // element type from.
                Constructor::PartialMatrix { .. } | Constructor::PartialArray => {
                    return Err(Box::new(Error::TypeNotInferable(ty_span)));
                }
            },

            // Scalar to scalar: a value conversion to the target scalar.
            (
                Components::One {
                    component,
                    ty_inner: &crate::TypeInner::Scalar { .. },
                    ..
                },
                Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
            ) => {
                expr = crate::Expression::As {
                    expr: component,
                    kind: scalar.kind,
                    convert: Some(scalar.width),
                };
            }

            // Vector to a typed vector of the same size: convert the
            // element type.
            (
                Components::One {
                    component,
                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
                    ..
                },
                Constructor::Type((
                    _,
                    &crate::TypeInner::Vector {
                        size: dst_size,
                        scalar: dst_scalar,
                    },
                )),
            ) if dst_size == src_size => {
                expr = crate::Expression::As {
                    expr: component,
                    kind: dst_scalar.kind,
                    convert: Some(dst_scalar.width),
                };
            }

            // Vector to a partial vector of the same size: the argument is
            // already the result.
            (
                Components::One {
                    component,
                    ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
                    ..
                },
                Constructor::PartialVector { size: dst_size },
            ) if dst_size == src_size => {
                return Ok(component);
            }

            // Matrix to a typed matrix of the same shape: convert the
            // element type.
            (
                Components::One {
                    component,
                    ty_inner:
                        &crate::TypeInner::Matrix {
                            columns: src_columns,
                            rows: src_rows,
                            ..
                        },
                    ..
                },
                Constructor::Type((
                    _,
                    &crate::TypeInner::Matrix {
                        columns: dst_columns,
                        rows: dst_rows,
                        scalar: dst_scalar,
                    },
                )),
            ) if dst_columns == src_columns && dst_rows == src_rows => {
                expr = crate::Expression::As {
                    expr: component,
                    kind: dst_scalar.kind,
                    convert: Some(dst_scalar.width),
                };
            }

            // Matrix to a partial matrix of the same shape: the argument is
            // already the result.
            (
                Components::One {
                    component,
                    ty_inner:
                        &crate::TypeInner::Matrix {
                            columns: src_columns,
                            rows: src_rows,
                            ..
                        },
                    ..
                },
                Constructor::PartialMatrix {
                    columns: dst_columns,
                    rows: dst_rows,
                },
            ) if dst_columns == src_columns && dst_rows == src_rows => {
                return Ok(component);
            }

            // Scalar to a partial vector: splat. The element type is the
            // scalar's own type, so no conversion is applied.
            (
                Components::One {
                    component,
                    ty_inner: &crate::TypeInner::Scalar { .. },
                    ..
                },
                Constructor::PartialVector { size },
            ) => {
                expr = crate::Expression::Splat {
                    size,
                    value: component,
                };
            }

            // Scalar to a typed vector: check the scalar converts
            // automatically to the element type, convert it, then splat.
            (
                Components::One {
                    mut component,
                    ty_inner: &crate::TypeInner::Scalar(component_scalar),
                    span,
                },
                Constructor::Type((
                    type_handle,
                    &crate::TypeInner::Vector {
                        size,
                        scalar: vec_scalar,
                    },
                )),
            ) => {
                if !component_scalar.automatically_converts_to(vec_scalar) {
                    // Report the argument type as the typifier sees it.
                    let component_ty = &ctx.typifier()[component];
                    let arg_ty = ctx.type_resolution_to_string(component_ty);
                    return Err(Box::new(Error::WrongArgumentType {
                        function: ctx.type_to_string(type_handle),
                        call_span: ty_span,
                        arg_span: span,
                        arg_index: 0,
                        arg_ty,
                        allowed: vec![vec_scalar.to_wgsl_for_diagnostics()],
                    }));
                }
                ctx.convert_slice_to_common_leaf_scalar(
                    core::slice::from_mut(&mut component),
                    vec_scalar,
                )?;
                expr = crate::Expression::Splat {
                    size,
                    value: component,
                };
            }

            // Several arguments to a partial vector: convert them all to
            // their consensus leaf scalar and compose a vector of that type.
            (
                Components::Many {
                    mut components,
                    spans,
                },
                Constructor::PartialVector { size },
            ) => {
                let consensus_scalar =
                    ctx.automatic_conversion_consensus(&components)
                        .map_err(|index| {
                            Error::InvalidConstructorComponentType(spans[index], index as i32)
                        })?;
                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
                let inner = consensus_scalar.to_inner_vector(size);
                let ty = ctx.ensure_type_exists(inner);
                expr = crate::Expression::Compose { ty, components };
            }

            // Several arguments to a typed vector: convert them to the
            // vector's element scalar and compose.
            (
                Components::Many { mut components, .. },
                Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
            ) => {
                ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
                expr = crate::Expression::Compose { ty, components };
            }

            // A partial matrix given exactly `columns * rows` arguments: the
            // arguments are scalars, and each consecutive run of `rows` of them
            // becomes one column vector.
            (
                Components::Many {
                    mut components,
                    spans,
                },
                Constructor::PartialMatrix { columns, rows },
            ) if components.len() == columns as usize * rows as usize => {
                let consensus_scalar =
                    ctx.automatic_conversion_consensus(&components)
                        .map_err(|index| {
                            Error::InvalidConstructorComponentType(spans[index], index as i32)
                        })?;
                // Fold the consensus with ABSTRACT_FLOAT (presumably so that
                // all-abstract-int arguments yield an abstract-float matrix).
                // If the two cannot be combined, keep the consensus unchanged.
                let consensus_scalar = consensus_scalar
                    .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
                    .unwrap_or(consensus_scalar);
                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
                let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));

                // Build one column vector per chunk of `rows` scalars.
                let components = components
                    .chunks(rows as usize)
                    .map(|vec_components| {
                        ctx.append_expression(
                            crate::Expression::Compose {
                                ty: vec_ty,
                                components: Vec::from(vec_components),
                            },
                            Default::default(),
                        )
                    })
                    .collect::<Result<Vec<_>>>()?;

                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: consensus_scalar,
                });
                expr = crate::Expression::Compose { ty, components };
            }

            // A typed matrix given exactly `columns * rows` arguments: convert
            // each scalar to the matrix's scalar, then group them into column
            // vectors of `rows` elements.
            (
                Components::Many { mut components, .. },
                Constructor::Type((
                    _,
                    &crate::TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    },
                )),
            ) if components.len() == columns as usize * rows as usize => {
                let element = Tr::Value(crate::TypeInner::Scalar(scalar));
                ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
                let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));

                let components = components
                    .chunks(rows as usize)
                    .map(|vec_components| {
                        ctx.append_expression(
                            crate::Expression::Compose {
                                ty: vec_ty,
                                components: Vec::from(vec_components),
                            },
                            Default::default(),
                        )
                    })
                    .collect::<Result<Vec<_>>>()?;

                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                });
                expr = crate::Expression::Compose { ty, components };
            }

            // Any other argument count for a partial matrix: the arguments are
            // composed into the matrix directly (presumably as column vectors;
            // validation is assumed to reject anything else — TODO confirm).
            (
                Components::Many {
                    mut components,
                    spans,
                },
                Constructor::PartialMatrix { columns, rows },
            ) => {
                let consensus_scalar =
                    ctx.automatic_conversion_consensus(&components)
                        .map_err(|index| {
                            Error::InvalidConstructorComponentType(spans[index], index as i32)
                        })?;
                ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
                let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: consensus_scalar,
                });
                expr = crate::Expression::Compose { ty, components };
            }

            // Any other argument count for a typed matrix: each argument is
            // converted to a column vector type `vec<rows, scalar>`.
            (
                Components::Many { mut components, .. },
                Constructor::Type((
                    ty,
                    &crate::TypeInner::Matrix {
                        columns: _,
                        rows,
                        scalar,
                    },
                )),
            ) => {
                let component_ty = crate::TypeInner::Vector { size: rows, scalar };
                ctx.try_automatic_conversions_slice(
                    &mut components,
                    &Tr::Value(component_ty),
                    ty_span,
                )?;
                expr = crate::Expression::Compose { ty, components };
            }

            // Partial array: infer the element type and length from the
            // arguments. `Components::None` was handled by the first arm, so
            // there is at least one argument here.
            (components, Constructor::PartialArray) => {
                let mut components = components.into_components_vec();
                if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
                    // Reconcile only the leaf scalar types. The arguments may
                    // still differ in shape (scalar/vector/matrix).
                    ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
                } else {
                    // No consensus scalar: emit the `Compose` anyway and, presumably,
                    // let validation report the mismatch.
                }

                // The element type is the (possibly converted) type of the
                // first argument; indexing is safe because the list is non-empty.
                let base = ctx.register_type(components[0])?;

                let inner = crate::TypeInner::Array {
                    base,
                    // Non-zero for the same reason `components[0]` is valid.
                    size: crate::ArraySize::Constant(
                        NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
                    ),
                    stride: {
                        // Refresh the layouter so it knows `base` before
                        // asking for its stride.
                        ctx.layouter.update(ctx.module.to_ctx()).unwrap();
                        ctx.layouter[base].to_stride()
                    },
                };
                let ty = ctx.ensure_type_exists(inner);

                expr = crate::Expression::Compose { ty, components };
            }

            // Typed array: convert every argument to the element type.
            (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
                let mut components = components.into_components_vec();
                ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
                expr = crate::Expression::Compose { ty, components };
            }

            // Struct: convert each argument to the type of the member in the
            // same position. A count mismatch is not diagnosed here: `zip`
            // stops at the shorter list, and every argument is passed to
            // `Compose`.
            (
                components,
                Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
            ) => {
                let mut components = components.into_components_vec();
                let struct_ty_span = ctx.module.types.get_span(ty);

                // Copy the member types out so that the borrow of the module's
                // type arena ends before `ctx` is used mutably below.
                let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();

                for (component, &ty) in components.iter_mut().zip(&members) {
                    *component =
                        ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
                }
                expr = crate::Expression::Compose { ty, components };
            }

            // Any other single-argument construction is an unsupported
            // conversion from the argument's type to the constructed type.
            (
                Components::One {
                    span, component, ..
                },
                constructor,
            ) => {
                let component_ty = &ctx.typifier()[component];
                let from_type = ctx.type_resolution_to_string(component_ty);
                return Err(Box::new(Error::BadTypeCast {
                    span,
                    from_type,
                    to_type: constructor.to_error_string(ctx),
                }));
            }

            // A scalar constructor takes one argument: report the extra ones,
            // from the second argument through the last.
            (
                Components::Many { spans, .. },
                Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
            ) => {
                let span = spans[1].until(spans.last().unwrap());
                return Err(Box::new(Error::UnexpectedComponents(span)));
            }

            // Everything else (for example several arguments to a type that
            // has no constructor arm above).
            _ => return Err(Box::new(Error::TypeNotConstructible(ty_span))),
        }

        let expr = ctx.append_expression(expr, span)?;
        Ok(expr)
    }

    /// Resolve an AST constructor type to a [`Constructor`].
    ///
    /// Fully specified types are registered in the module's type arena with
    /// `ensure_type_exists` (arrays also get their stride from the layouter)
    /// and returned as [`Constructor::Type`]. The partial vector, matrix and
    /// array forms are passed through for `construct` to infer.
    ///
    /// # Errors
    ///
    /// - `UnknownScalarType` if a vector or matrix element type does not
    ///   resolve to a scalar.
    /// - `BadMatrixScalarKind` if a matrix element scalar is not of
    ///   `ScalarKind::Float`.
    /// - `UnsupportedCooperativeScalar` if a cooperative-matrix element type
    ///   does not resolve to a scalar.
    /// - `UnderspecifiedCooperativeMatrix` for the partial cooperative-matrix
    ///   form.
    /// - Errors from `resolve_ast_type` and `array_size`.
    fn constructor<'out>(
        &mut self,
        constructor: &ast::ConstructorType<'source>,
        ctx: &mut ExpressionContext<'source, '_, 'out>,
    ) -> Result<'source, Constructor<Handle<crate::Type>>> {
        let handle = match *constructor {
            ast::ConstructorType::Scalar(scalar) => {
                let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
                Constructor::Type(ty)
            }
            ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
            ast::ConstructorType::Vector { size, ty, ty_span } => {
                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
                // The element type must be a scalar.
                let scalar = match ctx.module.types[ty].inner {
                    crate::TypeInner::Scalar(sc) => sc,
                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
                };
                let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, scalar });
                Constructor::Type(ty)
            }
            ast::ConstructorType::PartialMatrix { columns, rows } => {
                Constructor::PartialMatrix { columns, rows }
            }
            ast::ConstructorType::Matrix {
                rows,
                columns,
                ty,
                ty_span,
            } => {
                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
                let scalar = match ctx.module.types[ty].inner {
                    crate::TypeInner::Scalar(sc) => sc,
                    _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
                };
                // Only float element scalars are accepted for matrices.
                let ty = match scalar.kind {
                    crate::ScalarKind::Float => ctx.ensure_type_exists(crate::TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    }),
                    _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))),
                };
                Constructor::Type(ty)
            }
            ast::ConstructorType::PartialCooperativeMatrix { .. } => {
                return Err(Box::new(Error::UnderspecifiedCooperativeMatrix));
            }
            ast::ConstructorType::CooperativeMatrix {
                rows,
                columns,
                ty,
                ty_span,
                role,
            } => {
                let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
                let scalar = match ctx.module.types[ty].inner {
                    crate::TypeInner::Scalar(s) => s,
                    _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))),
                };
                let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix {
                    columns,
                    rows,
                    scalar,
                    role,
                });
                Constructor::Type(ty)
            }
            ast::ConstructorType::PartialArray => Constructor::PartialArray,
            ast::ConstructorType::Array { base, size } => {
                let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
                let size = self.array_size(size, &mut ctx.as_const())?;

                // Refresh the layouter so it knows `base` before asking for
                // its stride. NOTE(review): `unwrap` panics if layout fails.
                ctx.layouter.update(ctx.module.to_ctx()).unwrap();
                let stride = ctx.layouter[base].to_stride();

                let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
                Constructor::Type(ty)
            }
            // Named or otherwise already-resolved types.
            ast::ConstructorType::Type(ty) => Constructor::Type(ty),
        };

        Ok(handle)
    }
}