1use alloc::{
2 boxed::Box,
3 format,
4 string::{String, ToString},
5 vec,
6 vec::Vec,
7};
8use core::num::NonZeroU32;
9
10use crate::common::wgsl::TypeContext;
11use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
12use crate::front::wgsl::parse::ast;
13use crate::front::wgsl::{Error, Result};
14use crate::{Handle, Span};
15
/// The head of a WGSL constructor expression, resolved as far as possible to
/// Naga types.
///
/// `Lowerer::constructor` builds this from an `ast::ConstructorType`. The
/// `Partial*` variants are used when the WGSL source leaves the element type
/// (and, for arrays, the length) to be inferred from the arguments.
enum Constructor<T> {
    /// A vector whose element type is inferred from the arguments: `vec3(1.0)`.
    PartialVector { size: crate::VectorSize },

    /// A matrix whose element type is inferred from the arguments:
    /// `mat2x2(1, 2, 3, 4)`.
    PartialMatrix {
        columns: crate::VectorSize,
        rows: crate::VectorSize,
    },

    /// An array whose element type and length are inferred from the
    /// arguments: `array(3, 4, 5)`.
    PartialArray,

    /// A fully known Naga type.
    ///
    /// `T` starts out as `Handle<Type>`, so that building the value does not
    /// keep the module borrowed. `borrow_inner` later turns it into
    /// `(Handle<Type>, &TypeInner)`, so that callers can match on the type's
    /// structure.
    Type(T),
}
43
44impl Constructor<Handle<crate::Type>> {
45 fn borrow_inner(
51 self,
52 module: &crate::Module,
53 ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
54 match self {
55 Constructor::PartialVector { size } => Constructor::PartialVector { size },
56 Constructor::PartialMatrix { columns, rows } => {
57 Constructor::PartialMatrix { columns, rows }
58 }
59 Constructor::PartialArray => Constructor::PartialArray,
60 Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
61 }
62 }
63}
64
65impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
66 fn to_error_string(&self, ctx: &ExpressionContext) -> String {
67 match *self {
68 Self::PartialVector { size } => {
69 format!("vec{}<?>", size as u32,)
70 }
71 Self::PartialMatrix { columns, rows } => {
72 format!("mat{}x{}<?>", columns as u32, rows as u32,)
73 }
74 Self::PartialArray => "array<?, ?>".to_string(),
75 Self::Type((handle, _inner)) => ctx.type_to_string(handle),
76 }
77 }
78}
79
/// The already-lowered arguments of a constructor expression, grouped by how
/// many there are so that `Lowerer::construct` can match on the shape.
enum Components<'a> {
    /// No arguments: the constructor produces a zero value.
    None,
    /// Exactly one argument.
    One {
        /// The lowered argument expression.
        component: Handle<crate::Expression>,
        /// Source span of the argument, used for diagnostics.
        span: Span,
        /// The argument's resolved type (obtained via `resolve_inner!`).
        ty_inner: &'a crate::TypeInner,
    },
    /// Two or more arguments.
    Many {
        /// The lowered argument expressions, in source order.
        components: Vec<Handle<crate::Expression>>,
        /// Source spans of the arguments; `spans[i]` belongs to
        /// `components[i]`.
        spans: Vec<Span>,
    },
}
92
93impl Components<'_> {
94 fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
95 match self {
96 Self::None => vec![],
97 Self::One { component, .. } => vec![component],
98 Self::Many { components, .. } => components,
99 }
100 }
101}
102
103impl<'source> Lowerer<'source, '_> {
104 pub fn construct(
118 &mut self,
119 span: Span,
120 constructor: &ast::ConstructorType<'source>,
121 ty_span: Span,
122 components: &[Handle<ast::Expression<'source>>],
123 ctx: &mut ExpressionContext<'source, '_, '_>,
124 ) -> Result<'source, Handle<crate::Expression>> {
125 use crate::proc::TypeResolution as Tr;
126
127 let constructor_h = self.constructor(constructor, ctx)?;
128
129 let components = match *components {
130 [] => Components::None,
131 [component] => {
132 let span = ctx.ast_expressions.get_span(component);
133 let component = self.expression_for_abstract(component, ctx)?;
134 let ty_inner = super::resolve_inner!(ctx, component);
135
136 Components::One {
137 component,
138 span,
139 ty_inner,
140 }
141 }
142 ref ast_components @ [_, _, ..] => {
143 let components = ast_components
144 .iter()
145 .map(|&expr| self.expression_for_abstract(expr, ctx))
146 .collect::<Result<_>>()?;
147 let spans = ast_components
148 .iter()
149 .map(|&expr| ctx.ast_expressions.get_span(expr))
150 .collect();
151
152 for &component in &components {
153 ctx.grow_types(component)?;
154 }
155
156 Components::Many { components, spans }
157 }
158 };
159
160 let constructor = constructor_h.borrow_inner(ctx.module);
164
165 let expr;
166 match (components, constructor) {
167 (Components::None, dst_ty) => match dst_ty {
169 Constructor::Type((result_ty, _)) => {
170 expr = crate::Expression::ZeroValue(result_ty);
171 }
172 Constructor::PartialVector { size } => {
173 let result_ty = ctx.module.types.insert(
177 crate::Type {
178 name: None,
179 inner: crate::TypeInner::Vector {
180 size,
181 scalar: crate::Scalar::ABSTRACT_INT,
182 },
183 },
184 span,
185 );
186 expr = crate::Expression::ZeroValue(result_ty);
187 }
188 Constructor::PartialMatrix { .. } | Constructor::PartialArray => {
189 return Err(Box::new(Error::TypeNotInferable(ty_span)));
192 }
193 },
194
195 (
197 Components::One {
198 component,
199 ty_inner: &crate::TypeInner::Scalar { .. },
200 ..
201 },
202 Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
203 ) => {
204 expr = crate::Expression::As {
205 expr: component,
206 kind: scalar.kind,
207 convert: Some(scalar.width),
208 };
209 }
210
211 (
213 Components::One {
214 component,
215 ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
216 ..
217 },
218 Constructor::Type((
219 _,
220 &crate::TypeInner::Vector {
221 size: dst_size,
222 scalar: dst_scalar,
223 },
224 )),
225 ) if dst_size == src_size => {
226 expr = crate::Expression::As {
227 expr: component,
228 kind: dst_scalar.kind,
229 convert: Some(dst_scalar.width),
230 };
231 }
232
233 (
235 Components::One {
236 component,
237 ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
238 ..
239 },
240 Constructor::PartialVector { size: dst_size },
241 ) if dst_size == src_size => {
242 return Ok(component);
246 }
247
248 (
250 Components::One {
251 component,
252 ty_inner:
253 &crate::TypeInner::Matrix {
254 columns: src_columns,
255 rows: src_rows,
256 ..
257 },
258 ..
259 },
260 Constructor::Type((
261 _,
262 &crate::TypeInner::Matrix {
263 columns: dst_columns,
264 rows: dst_rows,
265 scalar: dst_scalar,
266 },
267 )),
268 ) if dst_columns == src_columns && dst_rows == src_rows => {
269 expr = crate::Expression::As {
270 expr: component,
271 kind: dst_scalar.kind,
272 convert: Some(dst_scalar.width),
273 };
274 }
275
276 (
278 Components::One {
279 component,
280 ty_inner:
281 &crate::TypeInner::Matrix {
282 columns: src_columns,
283 rows: src_rows,
284 ..
285 },
286 ..
287 },
288 Constructor::PartialMatrix {
289 columns: dst_columns,
290 rows: dst_rows,
291 },
292 ) if dst_columns == src_columns && dst_rows == src_rows => {
293 return Ok(component);
297 }
298
299 (
301 Components::One {
302 component,
303 ty_inner: &crate::TypeInner::Scalar { .. },
304 ..
305 },
306 Constructor::PartialVector { size },
307 ) => {
308 expr = crate::Expression::Splat {
309 size,
310 value: component,
311 };
312 }
313
314 (
316 Components::One {
317 mut component,
318 ty_inner: &crate::TypeInner::Scalar(_),
319 ..
320 },
321 Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })),
322 ) => {
323 ctx.convert_slice_to_common_leaf_scalar(
324 core::slice::from_mut(&mut component),
325 scalar,
326 )?;
327 expr = crate::Expression::Splat {
328 size,
329 value: component,
330 };
331 }
332
333 (
335 Components::Many {
336 mut components,
337 spans,
338 },
339 Constructor::PartialVector { size },
340 ) => {
341 let consensus_scalar =
342 ctx.automatic_conversion_consensus(&components)
343 .map_err(|index| {
344 Error::InvalidConstructorComponentType(spans[index], index as i32)
345 })?;
346 ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
347 let inner = consensus_scalar.to_inner_vector(size);
348 let ty = ctx.ensure_type_exists(inner);
349 expr = crate::Expression::Compose { ty, components };
350 }
351
352 (
354 Components::Many { mut components, .. },
355 Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
356 ) => {
357 ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
358 expr = crate::Expression::Compose { ty, components };
359 }
360
361 (
363 Components::Many {
364 mut components,
365 spans,
366 },
367 Constructor::PartialMatrix { columns, rows },
368 ) if components.len() == columns as usize * rows as usize => {
369 let consensus_scalar =
370 ctx.automatic_conversion_consensus(&components)
371 .map_err(|index| {
372 Error::InvalidConstructorComponentType(spans[index], index as i32)
373 })?;
374 let consensus_scalar = consensus_scalar
376 .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
377 .unwrap_or(consensus_scalar);
378 ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
379 let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));
380
381 let components = components
382 .chunks(rows as usize)
383 .map(|vec_components| {
384 ctx.append_expression(
385 crate::Expression::Compose {
386 ty: vec_ty,
387 components: Vec::from(vec_components),
388 },
389 Default::default(),
390 )
391 })
392 .collect::<Result<Vec<_>>>()?;
393
394 let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
395 columns,
396 rows,
397 scalar: consensus_scalar,
398 });
399 expr = crate::Expression::Compose { ty, components };
400 }
401
402 (
404 Components::Many { mut components, .. },
405 Constructor::Type((
406 _,
407 &crate::TypeInner::Matrix {
408 columns,
409 rows,
410 scalar,
411 },
412 )),
413 ) if components.len() == columns as usize * rows as usize => {
414 let element = Tr::Value(crate::TypeInner::Scalar(scalar));
415 ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
416 let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));
417
418 let components = components
419 .chunks(rows as usize)
420 .map(|vec_components| {
421 ctx.append_expression(
422 crate::Expression::Compose {
423 ty: vec_ty,
424 components: Vec::from(vec_components),
425 },
426 Default::default(),
427 )
428 })
429 .collect::<Result<Vec<_>>>()?;
430
431 let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
432 columns,
433 rows,
434 scalar,
435 });
436 expr = crate::Expression::Compose { ty, components };
437 }
438
439 (
441 Components::Many {
442 mut components,
443 spans,
444 },
445 Constructor::PartialMatrix { columns, rows },
446 ) => {
447 let consensus_scalar =
448 ctx.automatic_conversion_consensus(&components)
449 .map_err(|index| {
450 Error::InvalidConstructorComponentType(spans[index], index as i32)
451 })?;
452 ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
453 let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
454 columns,
455 rows,
456 scalar: consensus_scalar,
457 });
458 expr = crate::Expression::Compose { ty, components };
459 }
460
461 (
463 Components::Many { mut components, .. },
464 Constructor::Type((
465 ty,
466 &crate::TypeInner::Matrix {
467 columns: _,
468 rows,
469 scalar,
470 },
471 )),
472 ) => {
473 let component_ty = crate::TypeInner::Vector { size: rows, scalar };
474 ctx.try_automatic_conversions_slice(
475 &mut components,
476 &Tr::Value(component_ty),
477 ty_span,
478 )?;
479 expr = crate::Expression::Compose { ty, components };
480 }
481
482 (components, Constructor::PartialArray) => {
484 let mut components = components.into_components_vec();
485 if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
486 ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
501 } else {
502 }
505
506 let base = ctx.register_type(components[0])?;
507
508 let inner = crate::TypeInner::Array {
509 base,
510 size: crate::ArraySize::Constant(
511 NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
512 ),
513 stride: {
514 ctx.layouter.update(ctx.module.to_ctx()).unwrap();
515 ctx.layouter[base].to_stride()
516 },
517 };
518 let ty = ctx.ensure_type_exists(inner);
519
520 expr = crate::Expression::Compose { ty, components };
521 }
522
523 (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
525 let mut components = components.into_components_vec();
526 ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
527 expr = crate::Expression::Compose { ty, components };
528 }
529
530 (
532 components,
533 Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
534 ) => {
535 let mut components = components.into_components_vec();
536 let struct_ty_span = ctx.module.types.get_span(ty);
537
538 let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();
542
543 for (component, &ty) in components.iter_mut().zip(&members) {
544 *component =
545 ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
546 }
547 expr = crate::Expression::Compose { ty, components };
548 }
549
550 (
554 Components::One {
555 span, component, ..
556 },
557 constructor,
558 ) => {
559 let component_ty = &ctx.typifier()[component];
560 let from_type = ctx.type_resolution_to_string(component_ty);
561 return Err(Box::new(Error::BadTypeCast {
562 span,
563 from_type,
564 to_type: constructor.to_error_string(ctx),
565 }));
566 }
567
568 (
570 Components::Many { spans, .. },
571 Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
572 ) => {
573 let span = spans[1].until(spans.last().unwrap());
574 return Err(Box::new(Error::UnexpectedComponents(span)));
575 }
576
577 _ => return Err(Box::new(Error::TypeNotConstructible(ty_span))),
579 }
580
581 let expr = ctx.append_expression(expr, span)?;
582 Ok(expr)
583 }
584
585 fn constructor<'out>(
598 &mut self,
599 constructor: &ast::ConstructorType<'source>,
600 ctx: &mut ExpressionContext<'source, '_, 'out>,
601 ) -> Result<'source, Constructor<Handle<crate::Type>>> {
602 let handle = match *constructor {
603 ast::ConstructorType::Scalar(scalar) => {
604 let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
605 Constructor::Type(ty)
606 }
607 ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
608 ast::ConstructorType::Vector { size, ty, ty_span } => {
609 let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
610 let scalar = match ctx.module.types[ty].inner {
611 crate::TypeInner::Scalar(sc) => sc,
612 _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
613 };
614 let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, scalar });
615 Constructor::Type(ty)
616 }
617 ast::ConstructorType::PartialMatrix { columns, rows } => {
618 Constructor::PartialMatrix { columns, rows }
619 }
620 ast::ConstructorType::Matrix {
621 rows,
622 columns,
623 ty,
624 ty_span,
625 } => {
626 let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
627 let scalar = match ctx.module.types[ty].inner {
628 crate::TypeInner::Scalar(sc) => sc,
629 _ => return Err(Box::new(Error::UnknownScalarType(ty_span))),
630 };
631 let ty = match scalar.kind {
632 crate::ScalarKind::Float => ctx.ensure_type_exists(crate::TypeInner::Matrix {
633 columns,
634 rows,
635 scalar,
636 }),
637 _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))),
638 };
639 Constructor::Type(ty)
640 }
641 ast::ConstructorType::PartialArray => Constructor::PartialArray,
642 ast::ConstructorType::Array { base, size } => {
643 let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
644 let size = self.array_size(size, &mut ctx.as_const())?;
645
646 ctx.layouter.update(ctx.module.to_ctx()).unwrap();
647 let stride = ctx.layouter[base].to_stride();
648
649 let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
650 Constructor::Type(ty)
651 }
652 ast::ConstructorType::Type(ty) => Constructor::Type(ty),
653 };
654
655 Ok(handle)
656 }
657}