1use alloc::{format, string::String};
2
3use thiserror::Error;
4
5use crate::{
6 arena::{Arena, Handle, UniqueArena},
7 common::ForDebugWithTypes,
8 ir,
9};
10
/// The type of an expression, as computed by [`ResolveContext::resolve`].
///
/// A resolution either names a type that already lives in the module's type
/// arena, or carries a [`TypeInner`](crate::TypeInner) by value. The by-value
/// form is used for types the resolver builds on the fly, for example the
/// scalar produced by indexing a vector. These types are not added to the
/// arena (`resolve` only holds a shared `&UniqueArena`).
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum TypeResolution {
    /// A type stored in the module's type arena.
    Handle(Handle<crate::Type>),

    /// A type described directly, not necessarily present in the arena.
    ///
    /// `Clone` is implemented by hand for this enum and only supports the
    /// scalar, vector, matrix, pointer, value-pointer and array variants of
    /// `TypeInner` here. Cloning a `Value` holding any other variant panics.
    Value(crate::TypeInner),
}
111
112impl TypeResolution {
113 pub const fn handle(&self) -> Option<Handle<crate::Type>> {
114 match *self {
115 Self::Handle(handle) => Some(handle),
116 Self::Value(_) => None,
117 }
118 }
119
120 pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
121 match *self {
122 Self::Handle(handle) => &arena[handle].inner,
123 Self::Value(ref inner) => inner,
124 }
125 }
126}
127
128impl Clone for TypeResolution {
130 fn clone(&self) -> Self {
131 use crate::TypeInner as Ti;
132 match *self {
133 Self::Handle(handle) => Self::Handle(handle),
134 Self::Value(ref v) => Self::Value(match *v {
135 Ti::Scalar(scalar) => Ti::Scalar(scalar),
136 Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
137 Ti::Matrix {
138 rows,
139 columns,
140 scalar,
141 } => Ti::Matrix {
142 rows,
143 columns,
144 scalar,
145 },
146 Ti::Pointer { base, space } => Ti::Pointer { base, space },
147 Ti::ValuePointer {
148 size,
149 scalar,
150 space,
151 } => Ti::ValuePointer {
152 size,
153 scalar,
154 space,
155 },
156 Ti::Array { base, size, stride } => Ti::Array { base, size, stride },
157 _ => unreachable!("Unexpected clone type: {:?}", v),
158 }),
159 }
160 }
161}
162
/// Errors produced while resolving the type of an expression.
///
/// The size of this type is pinned by `test_error_size` at the bottom of this
/// file.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
    /// An `AccessIndex` index is past the end of the base.
    ///
    /// The bound is the vector size, matrix column count, value-pointer size
    /// or struct member count of the base. `expr` is the base expression.
    #[error("Index {index} is out of bounds for expression {expr:?}")]
    OutOfBoundsIndex {
        expr: Handle<crate::Expression>,
        index: u32,
    },
    /// The base expression of an access has a type that cannot be indexed.
    ///
    /// `indexed` is `true` for `AccessIndex` and `false` for `Access`.
    #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
    InvalidAccess {
        expr: Handle<crate::Expression>,
        indexed: bool,
    },
    /// An access through a pointer reached a pointee type that cannot be indexed.
    ///
    /// `ty` is the pointee type. `indexed` is `true` for `AccessIndex` and
    /// `false` for `Access`.
    #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
    InvalidSubAccess {
        ty: Handle<crate::Type>,
        indexed: bool,
    },
    /// The operand of a `Splat` is not a scalar.
    #[error("Invalid scalar {0:?}")]
    InvalidScalar(Handle<crate::Expression>),
    /// The operand of a `Swizzle` is not a vector.
    #[error("Invalid vector {0:?}")]
    InvalidVector(Handle<crate::Expression>),
    /// The operand of a `Load` is not a pointer or value pointer.
    #[error("Invalid pointer {0:?}")]
    InvalidPointer(Handle<crate::Expression>),
    /// The image operand of a sample, load or query is not an image.
    #[error("Invalid image {0:?}")]
    InvalidImage(Handle<crate::Expression>),
    /// A function was referenced by a name that is not defined.
    #[error("Function {name} not defined")]
    FunctionNotDefined { name: String },
    /// A `CallResult` refers to a function that has no result type.
    #[error("Function without return type")]
    FunctionReturnsVoid,
    /// The operand types are not valid for the operation.
    ///
    /// The payload is a human-readable description of the offending operands.
    #[error("Incompatible operands: {0}")]
    IncompatibleOperands(String),
    /// A `FunctionArgument` index is outside the current function's argument list.
    #[error("Function argument {0} doesn't exist")]
    FunctionArgumentNotFound(u32),
    /// A required special type (e.g. the ray-intersection type) is not
    /// registered in the module.
    #[error("Special type is not registered within the module")]
    MissingSpecialType,
    /// No overload of a math builtin accepts the given argument types.
    ///
    /// The payload is the `Debug` form of the builtin.
    #[error("Call to builtin {0} has incorrect or ambiguous arguments")]
    BuiltinArgumentsInvalid(String),
}
201
202impl From<crate::proc::MissingSpecialType> for ResolveError {
203 fn from(_unit_struct: crate::proc::MissingSpecialType) -> Self {
204 ResolveError::MissingSpecialType
205 }
206}
207
/// Everything [`ResolveContext::resolve`] needs to look up in order to type
/// an expression.
///
/// Module-level arenas are shared by every function. `local_vars` and
/// `arguments` belong to the function whose expressions are being resolved.
pub struct ResolveContext<'a> {
    /// Module constants; used to type `Expression::Constant`.
    pub constants: &'a Arena<crate::Constant>,
    /// Module overrides; used to type `Expression::Override`.
    pub overrides: &'a Arena<crate::Override>,
    /// The module's type arena that `TypeResolution::Handle` values point into.
    pub types: &'a UniqueArena<crate::Type>,
    /// Special types (e.g. ray-query result types) registered in the module.
    pub special_types: &'a crate::SpecialTypes,
    /// Module global variables; used to type `Expression::GlobalVariable`.
    pub global_vars: &'a Arena<crate::GlobalVariable>,
    /// Local variables of the function being resolved.
    pub local_vars: &'a Arena<crate::LocalVariable>,
    /// Module functions; used to type `Expression::CallResult`.
    pub functions: &'a Arena<crate::Function>,
    /// Arguments of the function being resolved, indexed by `Expression::FunctionArgument`.
    pub arguments: &'a [crate::FunctionArgument],
}
218
impl<'a> ResolveContext<'a> {
    /// Builds a context that borrows every module-level arena from `module`,
    /// together with the given function-local variables and arguments.
    pub const fn with_locals(
        module: &'a crate::Module,
        local_vars: &'a Arena<crate::LocalVariable>,
        arguments: &'a [crate::FunctionArgument],
    ) -> Self {
        Self {
            constants: &module.constants,
            overrides: &module.overrides,
            types: &module.types,
            special_types: &module.special_types,
            global_vars: &module.global_variables,
            local_vars,
            functions: &module.functions,
            arguments,
        }
    }

    /// Computes the type of `expr`.
    ///
    /// `past` must return the already-computed resolution of any operand
    /// expression that `expr` refers to. Errors it returns are propagated
    /// unchanged.
    ///
    /// The result is a `TypeResolution::Handle` when the type already exists
    /// in `self.types`. Otherwise it is a `TypeResolution::Value` built here.
    /// The type arena is never modified: only a shared reference is held.
    ///
    /// # Errors
    ///
    /// Returns a [`ResolveError`] when an operand has a type the expression
    /// cannot accept, when a constant index is out of bounds, when a called
    /// function has no result, when a function-argument index is invalid, or
    /// when a needed special type is missing. Most "invalid type" cases are
    /// also reported via `log::error!` before returning.
    pub fn resolve(
        &self,
        expr: &crate::Expression,
        past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
    ) -> Result<TypeResolution, ResolveError> {
        use crate::TypeInner as Ti;
        let types = self.types;
        Ok(match *expr {
            // Dynamic indexing. Note that `base` is shadowed in several arms below:
            // in the error arms it is either the pointee type handle or the
            // original base expression handle, depending on the match depth.
            crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
                // Indexing an array yields its element type.
                Ti::Array { base, .. } => TypeResolution::Handle(base),
                // Indexing a matrix yields a vector with `rows` components.
                Ti::Matrix { rows, scalar, .. } => {
                    TypeResolution::Value(Ti::Vector { size: rows, scalar })
                }
                Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
                // Indexing a pointer-to-vector yields a pointer-to-scalar.
                Ti::ValuePointer {
                    size: Some(_),
                    scalar,
                    space,
                } => TypeResolution::Value(Ti::ValuePointer {
                    size: None,
                    scalar,
                    space,
                }),
                // Indexing through a pointer yields a pointer to the element,
                // in the same address space.
                Ti::Pointer { base, space } => {
                    TypeResolution::Value(match types[base].inner {
                        Ti::Array { base, .. } => Ti::Pointer { base, space },
                        Ti::Vector { size: _, scalar } => Ti::ValuePointer {
                            size: None,
                            scalar,
                            space,
                        },
                        Ti::Matrix {
                            columns: _,
                            rows,
                            scalar,
                        } => Ti::ValuePointer {
                            size: Some(rows),
                            scalar,
                            space,
                        },
                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
                        ref other => {
                            log::error!("Access sub-type {other:?}");
                            // `base` here is the pointee type handle.
                            return Err(ResolveError::InvalidSubAccess {
                                ty: base,
                                indexed: false,
                            });
                        }
                    })
                }
                Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
                ref other => {
                    log::error!("Access type {other:?}");
                    // `base` here is the base expression handle.
                    return Err(ResolveError::InvalidAccess {
                        expr: base,
                        indexed: false,
                    });
                }
            },
            // Like `Access`, but the index is a `u32` stored in the expression
            // itself. This allows bounds-checking against vector sizes, matrix
            // column counts and struct member counts.
            crate::Expression::AccessIndex { base, index } => {
                match *past(base)?.inner_with(types) {
                    Ti::Vector { size, scalar } => {
                        if index >= size as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(Ti::Scalar(scalar))
                    }
                    Ti::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        if index >= columns as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
                    }
                    Ti::Array { base, .. } => TypeResolution::Handle(base),
                    Ti::Struct { ref members, .. } => {
                        let member = members
                            .get(index as usize)
                            .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
                        TypeResolution::Handle(member.ty)
                    }
                    Ti::ValuePointer {
                        size: Some(size),
                        scalar,
                        space,
                    } => {
                        if index >= size as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(Ti::ValuePointer {
                            size: None,
                            scalar,
                            space,
                        })
                    }
                    // The pointee type is bound as `ty_base` so that `base` still names
                    // the base expression in the bounds-check errors below.
                    Ti::Pointer {
                        base: ty_base,
                        space,
                    } => TypeResolution::Value(match types[ty_base].inner {
                        Ti::Array { base, .. } => Ti::Pointer { base, space },
                        Ti::Vector { size, scalar } => {
                            if index >= size as u32 {
                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                            }
                            Ti::ValuePointer {
                                size: None,
                                scalar,
                                space,
                            }
                        }
                        Ti::Matrix {
                            rows,
                            columns,
                            scalar,
                        } => {
                            if index >= columns as u32 {
                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                            }
                            Ti::ValuePointer {
                                size: Some(rows),
                                scalar,
                                space,
                            }
                        }
                        Ti::Struct { ref members, .. } => {
                            let member = members
                                .get(index as usize)
                                .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
                            Ti::Pointer {
                                base: member.ty,
                                space,
                            }
                        }
                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
                        ref other => {
                            log::error!("Access index sub-type {other:?}");
                            return Err(ResolveError::InvalidSubAccess {
                                ty: ty_base,
                                indexed: true,
                            });
                        }
                    }),
                    Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
                    ref other => {
                        log::error!("Access index type {other:?}");
                        return Err(ResolveError::InvalidAccess {
                            expr: base,
                            indexed: true,
                        });
                    }
                }
            }
            // Splatting a scalar yields a vector of the requested size.
            crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
                Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
                ref other => {
                    log::error!("Scalar type {other:?}");
                    return Err(ResolveError::InvalidScalar(value));
                }
            },
            // A swizzle keeps the scalar type and takes the size from the expression.
            crate::Expression::Swizzle {
                size,
                vector,
                pattern: _,
            } => match *past(vector)?.inner_with(types) {
                Ti::Vector { size: _, scalar } => {
                    TypeResolution::Value(Ti::Vector { size, scalar })
                }
                ref other => {
                    log::error!("Vector type {other:?}");
                    return Err(ResolveError::InvalidVector(vector));
                }
            },
            crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
            crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
            crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
            crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
            crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
            crate::Expression::FunctionArgument(index) => {
                let arg = self
                    .arguments
                    .get(index as usize)
                    .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
                TypeResolution::Handle(arg.ty)
            }
            // Globals in the `Handle` address space resolve to their type
            // directly. All other globals resolve to a pointer in their space.
            crate::Expression::GlobalVariable(h) => {
                let var = &self.global_vars[h];
                if var.space == crate::AddressSpace::Handle {
                    TypeResolution::Handle(var.ty)
                } else {
                    TypeResolution::Value(Ti::Pointer {
                        base: var.ty,
                        space: var.space,
                    })
                }
            }
            // Local variables are always pointers in the `Function` address space.
            crate::Expression::LocalVariable(h) => {
                let var = &self.local_vars[h];
                TypeResolution::Value(Ti::Pointer {
                    base: var.ty,
                    space: crate::AddressSpace::Function,
                })
            }
            crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
                Ti::Pointer { base, space: _ } => {
                    // Loading from an atomic yields its underlying scalar.
                    if let Ti::Atomic(scalar) = types[base].inner {
                        TypeResolution::Value(Ti::Scalar(scalar))
                    } else {
                        TypeResolution::Handle(base)
                    }
                }
                Ti::ValuePointer {
                    size,
                    scalar,
                    space: _,
                } => TypeResolution::Value(match size {
                    Some(size) => Ti::Vector { size, scalar },
                    None => Ti::Scalar(scalar),
                }),
                ref other => {
                    log::error!("Pointer type {other:?}");
                    return Err(ResolveError::InvalidPointer(pointer));
                }
            },
            // Gather always yields a 4-component vector. Its scalar kind follows
            // a sampled image's kind and is `Float` for other image classes.
            // This arm must stay before the general `ImageSample` arm below.
            crate::Expression::ImageSample {
                image,
                gather: Some(_),
                ..
            } => match *past(image)?.inner_with(types) {
                Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
                    scalar: crate::Scalar {
                        kind: match class {
                            crate::ImageClass::Sampled { kind, multi: _ } => kind,
                            _ => crate::ScalarKind::Float,
                        },
                        width: 4,
                    },
                    size: crate::VectorSize::Quad,
                }),
                ref other => {
                    log::error!("Image type {other:?}");
                    return Err(ResolveError::InvalidImage(image));
                }
            },
            // Depth images yield a single `F32`. Sampled and storage images
            // yield a 4-component vector.
            crate::Expression::ImageSample { image, .. }
            | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
                Ti::Image { class, .. } => TypeResolution::Value(match class {
                    crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
                    crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
                        scalar: crate::Scalar { kind, width: 4 },
                        size: crate::VectorSize::Quad,
                    },
                    crate::ImageClass::Storage { format, .. } => Ti::Vector {
                        scalar: format.into(),
                        size: crate::VectorSize::Quad,
                    },
                }),
                ref other => {
                    log::error!("Image type {other:?}");
                    return Err(ResolveError::InvalidImage(image));
                }
            },
            // Size queries depend on the image dimension. Every other query is a `U32`.
            crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
                crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
                    Ti::Image { dim, .. } => match dim {
                        crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
                            size: crate::VectorSize::Bi,
                            scalar: crate::Scalar::U32,
                        },
                        crate::ImageDimension::D3 => Ti::Vector {
                            size: crate::VectorSize::Tri,
                            scalar: crate::Scalar::U32,
                        },
                    },
                    ref other => {
                        log::error!("Image type {other:?}");
                        return Err(ResolveError::InvalidImage(image));
                    }
                },
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
            }),
            crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
            crate::Expression::Binary { op, left, right } => match op {
                // Arithmetic (other than `*`) takes the left operand's type.
                crate::BinaryOperator::Add
                | crate::BinaryOperator::Subtract
                | crate::BinaryOperator::Divide
                | crate::BinaryOperator::Modulo => past(left)?.clone(),
                // `*` follows linear-algebra shapes for matrix/vector operands.
                // A scalar operand takes the other operand's type. `vec * vec`
                // keeps the left type.
                crate::BinaryOperator::Multiply => {
                    let (res_left, res_right) = (past(left)?, past(right)?);
                    match (res_left.inner_with(types), res_right.inner_with(types)) {
                        (
                            &Ti::Matrix {
                                columns: _,
                                rows,
                                scalar,
                            },
                            &Ti::Matrix { columns, .. },
                        ) => TypeResolution::Value(Ti::Matrix {
                            columns,
                            rows,
                            scalar,
                        }),
                        (
                            &Ti::Matrix {
                                columns: _,
                                rows,
                                scalar,
                            },
                            &Ti::Vector { .. },
                        ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
                        (
                            &Ti::Vector { .. },
                            &Ti::Matrix {
                                columns,
                                rows: _,
                                scalar,
                            },
                        ) => TypeResolution::Value(Ti::Vector {
                            size: columns,
                            scalar,
                        }),
                        (&Ti::Scalar { .. }, _) => res_right.clone(),
                        (_, &Ti::Scalar { .. }) => res_left.clone(),
                        (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
                        (tl, tr) => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{tl:?} * {tr:?}"
                            )))
                        }
                    }
                }
                // Comparisons yield `bool` with the left operand's shape
                // (a scalar or a vector of the same size).
                crate::BinaryOperator::Equal
                | crate::BinaryOperator::NotEqual
                | crate::BinaryOperator::Less
                | crate::BinaryOperator::LessEqual
                | crate::BinaryOperator::Greater
                | crate::BinaryOperator::GreaterEqual => {
                    let scalar = crate::Scalar::BOOL;
                    let inner = match *past(left)?.inner_with(types) {
                        Ti::Scalar { .. } => Ti::Scalar(scalar),
                        Ti::Vector { size, .. } => Ti::Vector { size, scalar },
                        ref other => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{op:?}({other:?}, _)"
                            )))
                        }
                    };
                    TypeResolution::Value(inner)
                }
                // Logical operators only accept a scalar `bool` left operand.
                crate::BinaryOperator::LogicalAnd | crate::BinaryOperator::LogicalOr => {
                    let bool = Ti::Scalar(crate::Scalar::BOOL);
                    let ty = past(left)?.inner_with(types);
                    if *ty == bool {
                        TypeResolution::Value(bool)
                    } else {
                        return Err(ResolveError::IncompatibleOperands(format!(
                            "{op:?}({:?}, _)",
                            ty.for_debug(types),
                        )));
                    }
                }
                // Bitwise and shift operators take the left operand's type.
                crate::BinaryOperator::And
                | crate::BinaryOperator::ExclusiveOr
                | crate::BinaryOperator::InclusiveOr
                | crate::BinaryOperator::ShiftLeft
                | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
            },
            crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
            crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
            crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
            crate::Expression::Select { accept, .. } => past(accept)?.clone(),
            crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
            crate::Expression::Relational { fun, argument } => match fun {
                crate::RelationalFunction::All | crate::RelationalFunction::Any => {
                    TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
                }
                // Element-wise tests yield `bool` in the argument's shape.
                crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
                    match *past(argument)?.inner_with(types) {
                        Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
                        Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
                            scalar: crate::Scalar::BOOL,
                            size,
                        }),
                        ref other => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{fun:?}({other:?})"
                            )))
                        }
                    }
                }
            },
            // Math builtins are typed by narrowing the function's overload set
            // with the types of `arg` and (if present) `arg1`. `arg2` and `arg3`
            // are not consulted. The most preferred remaining rule supplies the
            // result type.
            crate::Expression::Math {
                fun,
                arg,
                arg1,
                arg2: _,
                arg3: _,
            } => {
                use crate::proc::OverloadSet as _;

                let mut overloads = fun.overloads();
                log::debug!(
                    "initial overloads for {fun:?}, {:#?}",
                    overloads.for_debug(types)
                );

                let res_arg = past(arg)?;
                overloads = overloads.arg(0, res_arg.inner_with(types), types);
                log::debug!(
                    "overloads after arg 0 of type {:?}: {:#?}",
                    res_arg.for_debug(types),
                    overloads.for_debug(types)
                );

                if let Some(arg1) = arg1 {
                    let res_arg1 = past(arg1)?;
                    overloads = overloads.arg(1, res_arg1.inner_with(types), types);
                    log::debug!(
                        "overloads after arg 1 of type {:?}: {:#?}",
                        res_arg1.for_debug(types),
                        overloads.for_debug(types)
                    );
                }

                if overloads.is_empty() {
                    return Err(ResolveError::BuiltinArgumentsInvalid(format!("{fun:?}")));
                }

                let rule = overloads.most_preferred();

                rule.conclusion.into_resolution(self.special_types)?
            }
            // Casts keep the operand's shape and replace the scalar kind. The width
            // changes only when `convert` supplies one. For matrices only the width
            // is updated; the scalar kind is left as-is.
            crate::Expression::As {
                expr,
                kind,
                convert,
            } => match *past(expr)?.inner_with(types) {
                Ti::Scalar(crate::Scalar { width, .. }) => {
                    TypeResolution::Value(Ti::Scalar(crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    }))
                }
                Ti::Vector {
                    size,
                    scalar: crate::Scalar { kind: _, width },
                } => TypeResolution::Value(Ti::Vector {
                    size,
                    scalar: crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    },
                }),
                Ti::Matrix {
                    columns,
                    rows,
                    mut scalar,
                } => {
                    if let Some(width) = convert {
                        scalar.width = width;
                    }
                    TypeResolution::Value(Ti::Matrix {
                        columns,
                        rows,
                        scalar,
                    })
                }
                ref other => {
                    return Err(ResolveError::IncompatibleOperands(format!(
                        "{other:?} as {kind:?}"
                    )))
                }
            },
            crate::Expression::CallResult(function) => {
                let result = self.functions[function]
                    .result
                    .as_ref()
                    .ok_or(ResolveError::FunctionReturnsVoid)?;
                TypeResolution::Handle(result.ty)
            }
            crate::Expression::ArrayLength(_) => {
                TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
            }
            crate::Expression::RayQueryProceedResult => {
                TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
            }
            // Ray-query results are typed by special types that must already
            // be registered in the module.
            crate::Expression::RayQueryGetIntersection { .. } => {
                let result = self
                    .special_types
                    .ray_intersection
                    .ok_or(ResolveError::MissingSpecialType)?;
                TypeResolution::Handle(result)
            }
            crate::Expression::RayQueryVertexPositions { .. } => {
                let result = self
                    .special_types
                    .ray_vertex_return
                    .ok_or(ResolveError::MissingSpecialType)?;
                TypeResolution::Handle(result)
            }
            crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
                scalar: crate::Scalar::U32,
                size: crate::VectorSize::Quad,
            }),
        })
    }
}
778
779pub fn compare_types(
793 lhs: &TypeResolution,
794 rhs: &TypeResolution,
795 types: &UniqueArena<crate::Type>,
796) -> bool {
797 match lhs {
798 &TypeResolution::Handle(lhs_handle)
799 if matches!(
800 types[lhs_handle],
801 ir::Type {
802 inner: ir::TypeInner::Struct { .. },
803 ..
804 }
805 ) =>
806 {
807 rhs.handle()
809 .is_some_and(|rhs_handle| lhs_handle == rhs_handle)
810 }
811 _ => lhs
812 .inner_with(types)
813 .non_struct_equivalent(rhs.inner_with(types), types),
814 }
815}
816
#[test]
fn test_error_size() {
    // `ResolveError` is the error type of every `resolve` call. Pinning its size
    // catches accidental growth, which would enlarge every such `Result`.
    let actual = core::mem::size_of::<ResolveError>();
    assert_eq!(actual, 32);
}
820}