1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4 arena::Handle,
5 proc::OverloadSet as _,
6 proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13 NotInScope,
14 #[error("Base type {0:?} is not compatible with this expression")]
15 InvalidBaseType(Handle<crate::Expression>),
16 #[error("Accessing with index {0:?} can't be done")]
17 InvalidIndexType(Handle<crate::Expression>),
18 #[error("Accessing {0:?} via a negative index is invalid")]
19 NegativeIndex(Handle<crate::Expression>),
20 #[error("Accessing index {1} is out of {0:?} bounds")]
21 IndexOutOfBounds(Handle<crate::Expression>, u32),
22 #[error("Function argument {0:?} doesn't exist")]
23 FunctionArgumentDoesntExist(u32),
24 #[error("Loading of {0:?} can't be done")]
25 InvalidPointerType(Handle<crate::Expression>),
26 #[error("Array length of {0:?} can't be done")]
27 InvalidArrayType(Handle<crate::Expression>),
28 #[error("Get intersection of {0:?} can't be done")]
29 InvalidRayQueryType(Handle<crate::Expression>),
30 #[error("Splatting {0:?} can't be done")]
31 InvalidSplatType(Handle<crate::Expression>),
32 #[error("Swizzling {0:?} can't be done")]
33 InvalidVectorType(Handle<crate::Expression>),
34 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36 #[error(transparent)]
37 Compose(#[from] super::ComposeError),
38 #[error("Cannot construct zero value of {0:?} because it is not a constructible type")]
39 InvalidZeroValue(Handle<crate::Type>),
40 #[error(transparent)]
41 IndexableLength(#[from] IndexableLengthError),
42 #[error("Operation {0:?} can't work with {1:?}")]
43 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
44 #[error(
45 "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
46 op,
47 lhs_expr,
48 lhs_type,
49 rhs_expr,
50 rhs_type
51 )]
52 InvalidBinaryOperandTypes {
53 op: crate::BinaryOperator,
54 lhs_expr: Handle<crate::Expression>,
55 lhs_type: crate::TypeInner,
56 rhs_expr: Handle<crate::Expression>,
57 rhs_type: crate::TypeInner,
58 },
59 #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
60 SelectValuesTypeMismatch {
61 accept: crate::TypeInner,
62 reject: crate::TypeInner,
63 },
64 #[error("Expected selection condition to be a boolean value, got {actual:?}")]
65 SelectConditionNotABool { actual: crate::TypeInner },
66 #[error("Relational argument {0:?} is not a boolean vector")]
67 InvalidBooleanVector(Handle<crate::Expression>),
68 #[error("Relational argument {0:?} is not a float")]
69 InvalidFloatArgument(Handle<crate::Expression>),
70 #[error("Type resolution failed")]
71 Type(#[from] ResolveError),
72 #[error("Not a global variable")]
73 ExpectedGlobalVariable,
74 #[error("Not a global variable or a function argument")]
75 ExpectedGlobalOrArgument,
76 #[error("Needs to be an binding array instead of {0:?}")]
77 ExpectedBindingArrayType(Handle<crate::Type>),
78 #[error("Needs to be an image instead of {0:?}")]
79 ExpectedImageType(Handle<crate::Type>),
80 #[error("Needs to be an image instead of {0:?}")]
81 ExpectedSamplerType(Handle<crate::Type>),
82 #[error("Unable to operate on image class {0:?}")]
83 InvalidImageClass(crate::ImageClass),
84 #[error("Image atomics are not supported for storage format {0:?}")]
85 InvalidImageFormat(crate::StorageFormat),
86 #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
87 InvalidImageStorageAccess(crate::StorageAccess),
88 #[error("Derivatives can only be taken from scalar and vector floats")]
89 InvalidDerivative,
90 #[error("Image array index parameter is misplaced")]
91 InvalidImageArrayIndex,
92 #[error("Cannot textureLoad from a specific multisample sample on a non-multisampled image.")]
93 InvalidImageSampleSelector,
94 #[error("Cannot textureLoad from a multisampled image without specifying a sample.")]
95 MissingImageSampleSelector,
96 #[error("Cannot textureLoad with a specific mip level on a non-mipmapped image.")]
97 InvalidImageLevelSelector,
98 #[error("Cannot textureLoad from a mipmapped image without specifying a level.")]
99 MissingImageLevelSelector,
100 #[error("Image array index type of {0:?} is not an integer scalar")]
101 InvalidImageArrayIndexType(Handle<crate::Expression>),
102 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
103 InvalidImageOtherIndexType(Handle<crate::Expression>),
104 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
105 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
106 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
107 ComparisonSamplingMismatch {
108 image: crate::ImageClass,
109 sampler: bool,
110 has_ref: bool,
111 },
112 #[error("Sample offset must be a const-expression")]
113 InvalidSampleOffsetExprType,
114 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
115 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
116 #[error("Depth reference {0:?} is not a scalar float")]
117 InvalidDepthReference(Handle<crate::Expression>),
118 #[error("Depth sample level can only be Auto or Zero")]
119 InvalidDepthSampleLevel,
120 #[error("Gather level can only be Zero")]
121 InvalidGatherLevel,
122 #[error("Gather component {0:?} doesn't exist in the image")]
123 InvalidGatherComponent(crate::SwizzleComponent),
124 #[error("Gather can't be done for image dimension {0:?}")]
125 InvalidGatherDimension(crate::ImageDimension),
126 #[error("Sample level (exact) type {0:?} has an invalid type")]
127 InvalidSampleLevelExactType(Handle<crate::Expression>),
128 #[error("Sample level (bias) type {0:?} is not a scalar float")]
129 InvalidSampleLevelBiasType(Handle<crate::Expression>),
130 #[error("Bias can't be done for image dimension {0:?}")]
131 InvalidSampleLevelBiasDimension(crate::ImageDimension),
132 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
133 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
134 #[error("Clamping sample coordinate to edge is not supported with {0}")]
135 InvalidSampleClampCoordinateToEdge(alloc::string::String),
136 #[error("Unable to cast")]
137 InvalidCastArgument,
138 #[error("Invalid argument count for {0:?}")]
139 WrongArgumentCount(crate::MathFunction),
140 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
141 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
142 #[error(
143 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
144 )]
145 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
146 #[error("Shader requires capability {0:?}")]
147 MissingCapabilities(super::Capabilities),
148 #[error(transparent)]
149 Literal(#[from] LiteralError),
150 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
151 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
152 #[error("Invalid operand for cooperative op")]
153 InvalidCooperativeOperand(Handle<crate::Expression>),
154 #[error("Shift amount exceeds the bit width of {lhs_type:?}")]
155 ShiftAmountTooLarge {
156 lhs_type: crate::TypeInner,
157 rhs_expr: Handle<crate::Expression>,
158 },
159 #[error("Division by zero")]
160 DivideByZero,
161}
162
163#[derive(Clone, Debug, thiserror::Error)]
164#[cfg_attr(test, derive(PartialEq))]
165pub enum ConstExpressionError {
166 #[error("The expression is not a constant or override expression")]
167 NonConstOrOverride,
168 #[error("The expression is not a fully evaluated constant expression")]
169 NonFullyEvaluatedConst,
170 #[error(transparent)]
171 Compose(#[from] super::ComposeError),
172 #[error("Splatting {0:?} can't be done")]
173 InvalidSplatType(Handle<crate::Expression>),
174 #[error("Type resolution failed")]
175 Type(#[from] ResolveError),
176 #[error(transparent)]
177 Literal(#[from] LiteralError),
178 #[error(transparent)]
179 Width(#[from] super::r#type::WidthError),
180}
181
182#[derive(Clone, Debug, thiserror::Error)]
183#[cfg_attr(test, derive(PartialEq))]
184pub enum LiteralError {
185 #[error("Float literal is NaN")]
186 NaN,
187 #[error("Float literal is infinite")]
188 Infinity,
189 #[error(transparent)]
190 Width(#[from] super::r#type::WidthError),
191}
192
193struct ExpressionTypeResolver<'a> {
194 root: Handle<crate::Expression>,
195 types: &'a UniqueArena<crate::Type>,
196 info: &'a FunctionInfo,
197}
198
199impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
200 type Output = crate::TypeInner;
201
202 #[allow(clippy::panic)]
203 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
204 if handle < self.root {
205 self.info[handle].ty.inner_with(self.types)
206 } else {
207 panic!(
209 "Depends on {:?}, which has not been processed yet",
210 self.root
211 )
212 }
213 }
214}
215
216impl super::Validator {
217 pub(super) fn validate_const_expression(
218 &self,
219 handle: Handle<crate::Expression>,
220 gctx: crate::proc::GlobalCtx,
221 mod_info: &ModuleInfo,
222 global_expr_kind: &crate::proc::ExpressionKindTracker,
223 ) -> Result<(), ConstExpressionError> {
224 use crate::Expression as E;
225
226 if !global_expr_kind.is_const_or_override(handle) {
227 return Err(ConstExpressionError::NonConstOrOverride);
228 }
229
230 match gctx.global_expressions[handle] {
231 E::Literal(literal) => {
232 self.validate_literal(literal)?;
233 }
234 E::Constant(_) | E::ZeroValue(_) => {}
235 E::Compose { ref components, ty } => {
236 validate_compose(
237 ty,
238 gctx,
239 components.iter().map(|&handle| mod_info[handle].clone()),
240 )?;
241 }
242 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
243 crate::TypeInner::Scalar { .. } => {}
244 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
245 },
246 _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
247 return Err(ConstExpressionError::NonFullyEvaluatedConst)
248 }
249 _ => {}
251 }
252
253 Ok(())
254 }
255
256 fn validate_constant_shift_amounts(
265 left_ty: &crate::TypeInner,
266 right: Handle<crate::Expression>,
267 module: &crate::Module,
268 function: &crate::Function,
269 ) -> Result<(), ExpressionError> {
270 fn is_overflowing_shift(
271 left_ty: &crate::TypeInner,
272 right: Handle<crate::Expression>,
273 module: &crate::Module,
274 function: &crate::Function,
275 ) -> bool {
276 let Some((vec_size, scalar)) = left_ty.vector_size_and_scalar() else {
277 return false;
278 };
279 if !matches!(
280 scalar.kind,
281 crate::ScalarKind::Sint | crate::ScalarKind::Uint
282 ) {
283 return false;
284 }
285 let lhs_bits = u32::from(8 * scalar.width);
286 if vec_size.is_none() {
287 let shift_amount = module
288 .to_ctx()
289 .get_const_val_from::<u32, _>(right, &function.expressions);
290 shift_amount.ok().is_some_and(|s| s >= lhs_bits)
291 } else {
292 match function.expressions[right] {
293 crate::Expression::ZeroValue(_) => false, crate::Expression::Splat { value, .. } => module
295 .to_ctx()
296 .get_const_val_from::<u32, _>(value, &function.expressions)
297 .ok()
298 .is_some_and(|s| s >= lhs_bits),
299 crate::Expression::Compose {
300 ty: _,
301 ref components,
302 } => components.iter().any(|comp| {
303 module
304 .to_ctx()
305 .get_const_val_from::<u32, _>(*comp, &function.expressions)
306 .ok()
307 .is_some_and(|s| s >= lhs_bits)
308 }),
309 _ => false,
310 }
311 }
312 }
313
314 if is_overflowing_shift(left_ty, right, module, function) {
315 Err(ExpressionError::ShiftAmountTooLarge {
316 lhs_type: left_ty.clone(),
317 rhs_expr: right,
318 })
319 } else {
320 Ok(())
321 }
322 }
323
324 fn validate_constant_divisor(
334 left_ty: &crate::TypeInner,
335 right: Handle<crate::Expression>,
336 module: &crate::Module,
337 function: &crate::Function,
338 ) -> Result<(), ExpressionError> {
339 fn contains_zero(
340 handle: Handle<crate::Expression>,
341 expressions: &crate::Arena<crate::Expression>,
342 module: &crate::Module,
343 ) -> bool {
344 match expressions[handle] {
345 crate::Expression::Literal(_) | crate::Expression::ZeroValue(_) => module
346 .to_ctx()
347 .get_const_val_from::<u32, _>(handle, expressions)
348 .ok()
349 .is_some_and(|v| v == 0),
350 crate::Expression::Splat { value, .. } => contains_zero(value, expressions, module),
351 crate::Expression::Compose { ref components, .. } => components
352 .iter()
353 .any(|&comp| contains_zero(comp, expressions, module)),
354 crate::Expression::Constant(c) => {
355 contains_zero(module.constants[c].init, &module.global_expressions, module)
356 }
357 _ => false,
358 }
359 }
360
361 let Some((_, scalar)) = left_ty.vector_size_and_scalar() else {
362 return Ok(());
363 };
364 if !matches!(
365 scalar.kind,
366 crate::ScalarKind::Sint | crate::ScalarKind::Uint
367 ) {
368 return Ok(());
369 }
370
371 if contains_zero(right, &function.expressions, module) {
372 Err(ExpressionError::DivideByZero)
373 } else {
374 Ok(())
375 }
376 }
377
378 #[allow(clippy::too_many_arguments)]
379 pub(super) fn validate_expression(
380 &self,
381 root: Handle<crate::Expression>,
382 expression: &crate::Expression,
383 function: &crate::Function,
384 module: &crate::Module,
385 info: &FunctionInfo,
386 mod_info: &ModuleInfo,
387 expr_kind: &crate::proc::ExpressionKindTracker,
388 ) -> Result<ShaderStages, ExpressionError> {
389 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
390
391 let resolver = ExpressionTypeResolver {
392 root,
393 types: &module.types,
394 info,
395 };
396
397 let stages = match *expression {
398 E::Access { base, index } => {
399 let base_type = &resolver[base];
400 match *base_type {
401 Ti::Matrix { .. }
402 | Ti::Vector { .. }
403 | Ti::Array { .. }
404 | Ti::Pointer { .. }
405 | Ti::ValuePointer { size: Some(_), .. }
406 | Ti::BindingArray { .. } => {}
407 ref other => {
408 log::debug!("Indexing of {other:?}");
409 return Err(ExpressionError::InvalidBaseType(base));
410 }
411 };
412 match resolver[index] {
413 Ti::Scalar(Sc {
415 kind: Sk::Sint | Sk::Uint,
416 ..
417 }) => {}
418 ref other => {
419 log::debug!("Indexing by {other:?}");
420 return Err(ExpressionError::InvalidIndexType(index));
421 }
422 }
423
424 match module
426 .to_ctx()
427 .get_const_val_from(index, &function.expressions)
428 {
429 Ok(value) => {
430 let length = if self.overrides_resolved {
431 base_type.indexable_length_resolved(module)
432 } else {
433 base_type.indexable_length_pending(module)
434 }?;
435 if let crate::proc::IndexableLength::Known(known_length) = length {
438 if value >= known_length {
439 return Err(ExpressionError::IndexOutOfBounds(base, value));
440 }
441 }
442 }
443 Err(crate::proc::ConstValueError::Negative) => {
444 return Err(ExpressionError::NegativeIndex(base))
445 }
446 Err(crate::proc::ConstValueError::NonConst) => {}
447 Err(crate::proc::ConstValueError::InvalidType) => {
448 return Err(ExpressionError::InvalidIndexType(index))
449 }
450 }
451
452 ShaderStages::all()
453 }
454 E::AccessIndex { base, index } => {
455 fn resolve_index_limit(
456 module: &crate::Module,
457 top: Handle<crate::Expression>,
458 ty: &crate::TypeInner,
459 top_level: bool,
460 ) -> Result<u32, ExpressionError> {
461 let limit = match *ty {
462 Ti::Vector { size, .. }
463 | Ti::ValuePointer {
464 size: Some(size), ..
465 } => size as u32,
466 Ti::Matrix { columns, .. } => columns as u32,
467 Ti::Array {
468 size: crate::ArraySize::Constant(len),
469 ..
470 } => len.get(),
471 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
473 resolve_index_limit(module, top, &module.types[base].inner, false)?
474 }
475 Ti::Struct { ref members, .. } => members.len() as u32,
476 ref other => {
477 log::debug!("Indexing of {other:?}");
478 return Err(ExpressionError::InvalidBaseType(top));
479 }
480 };
481 Ok(limit)
482 }
483
484 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
485 if index >= limit {
486 return Err(ExpressionError::IndexOutOfBounds(base, limit));
487 }
488 ShaderStages::all()
489 }
490 E::Splat { size: _, value } => match resolver[value] {
491 Ti::Scalar { .. } => ShaderStages::all(),
492 ref other => {
493 log::debug!("Splat scalar type {other:?}");
494 return Err(ExpressionError::InvalidSplatType(value));
495 }
496 },
497 E::Swizzle {
498 size,
499 vector,
500 pattern,
501 } => {
502 let vec_size = match resolver[vector] {
503 Ti::Vector { size: vec_size, .. } => vec_size,
504 ref other => {
505 log::debug!("Swizzle vector type {other:?}");
506 return Err(ExpressionError::InvalidVectorType(vector));
507 }
508 };
509 for &sc in pattern[..size as usize].iter() {
510 if sc as u8 >= vec_size as u8 {
511 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
512 }
513 }
514 ShaderStages::all()
515 }
516 E::Literal(literal) => {
517 self.validate_literal(literal)?;
518 ShaderStages::all()
519 }
520 E::Constant(_) | E::Override(_) => ShaderStages::all(),
521 E::ZeroValue(ty) => {
522 if !mod_info[ty].contains(TypeFlags::CONSTRUCTIBLE) {
523 return Err(ExpressionError::InvalidZeroValue(ty));
524 }
525 ShaderStages::all()
526 }
527 E::Compose { ref components, ty } => {
528 validate_compose(
529 ty,
530 module.to_ctx(),
531 components.iter().map(|&handle| info[handle].ty.clone()),
532 )?;
533 ShaderStages::all()
534 }
535 E::FunctionArgument(index) => {
536 if index >= function.arguments.len() as u32 {
537 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
538 }
539 ShaderStages::all()
540 }
541 E::GlobalVariable(_handle) => ShaderStages::all(),
542 E::LocalVariable(_handle) => ShaderStages::all(),
543 E::Load { pointer } => {
544 match resolver[pointer] {
545 Ti::Pointer { base, .. }
546 if self.types[base.index()]
547 .flags
548 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
549 Ti::ValuePointer { .. } => {}
550 ref other => {
551 log::debug!("Loading {other:?}");
552 return Err(ExpressionError::InvalidPointerType(pointer));
553 }
554 }
555 ShaderStages::all()
556 }
557 E::ImageSample {
558 image,
559 sampler,
560 gather,
561 coordinate,
562 array_index,
563 offset,
564 level,
565 depth_ref,
566 clamp_to_edge,
567 } => {
568 let image_ty = Self::global_var_ty(module, function, image)?;
570 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
571
572 let comparison = match module.types[sampler_ty].inner {
573 Ti::Sampler { comparison } => comparison,
574 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
575 };
576
577 let (class, dim) = match module.types[image_ty].inner {
578 Ti::Image {
579 class,
580 arrayed,
581 dim,
582 } => {
583 if arrayed != array_index.is_some() {
585 return Err(ExpressionError::InvalidImageArrayIndex);
586 }
587 if let Some(expr) = array_index {
588 match resolver[expr] {
589 Ti::Scalar(Sc {
590 kind: Sk::Sint | Sk::Uint,
591 ..
592 }) => {}
593 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
594 }
595 }
596 (class, dim)
597 }
598 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
599 };
600
601 let image_depth = match class {
603 crate::ImageClass::Sampled {
604 kind: crate::ScalarKind::Float,
605 multi: false,
606 } => false,
607 crate::ImageClass::Sampled {
608 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
609 multi: false,
610 } if gather.is_some() => false,
611 crate::ImageClass::External => false,
612 crate::ImageClass::Depth { multi: false } => true,
613 _ => return Err(ExpressionError::InvalidImageClass(class)),
614 };
615 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
616 return Err(ExpressionError::ComparisonSamplingMismatch {
617 image: class,
618 sampler: comparison,
619 has_ref: depth_ref.is_some(),
620 });
621 }
622
623 let num_components = match dim {
625 crate::ImageDimension::D1 => 1,
626 crate::ImageDimension::D2 => 2,
627 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
628 };
629 match resolver[coordinate] {
630 Ti::Scalar(Sc {
631 kind: Sk::Float, ..
632 }) if num_components == 1 => {}
633 Ti::Vector {
634 size,
635 scalar:
636 Sc {
637 kind: Sk::Float, ..
638 },
639 } if size as u32 == num_components => {}
640 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
641 }
642
643 if let Some(const_expr) = offset {
645 if !expr_kind.is_const(const_expr) {
646 return Err(ExpressionError::InvalidSampleOffsetExprType);
647 }
648
649 match resolver[const_expr] {
650 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
651 Ti::Vector {
652 size,
653 scalar: Sc { kind: Sk::Sint, .. },
654 } if size as u32 == num_components => {}
655 _ => {
656 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
657 }
658 }
659 }
660
661 if let Some(expr) = depth_ref {
663 match resolver[expr] {
664 Ti::Scalar(Sc {
665 kind: Sk::Float, ..
666 }) => {}
667 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
668 }
669 match level {
670 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
671 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
672 }
673 }
674
675 if let Some(component) = gather {
676 match dim {
677 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
678 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
679 return Err(ExpressionError::InvalidGatherDimension(dim))
680 }
681 };
682 let max_component = match class {
683 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
684 _ => crate::SwizzleComponent::W,
685 };
686 if component > max_component {
687 return Err(ExpressionError::InvalidGatherComponent(component));
688 }
689 match level {
690 crate::SampleLevel::Zero => {}
691 _ => return Err(ExpressionError::InvalidGatherLevel),
692 }
693 }
694
695 if clamp_to_edge {
698 if !matches!(
699 class,
700 crate::ImageClass::Sampled {
701 kind: crate::ScalarKind::Float,
702 multi: false
703 } | crate::ImageClass::External
704 ) {
705 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
706 alloc::format!("image class `{class:?}`"),
707 ));
708 }
709 if dim != crate::ImageDimension::D2 {
710 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
711 alloc::format!("image dimension `{dim:?}`"),
712 ));
713 }
714 if gather.is_some() {
715 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
716 "gather".into(),
717 ));
718 }
719 if array_index.is_some() {
720 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
721 "array index".into(),
722 ));
723 }
724 if offset.is_some() {
725 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
726 "offset".into(),
727 ));
728 }
729 if level != crate::SampleLevel::Zero {
730 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
731 "non-zero level".into(),
732 ));
733 }
734 if depth_ref.is_some() {
735 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
736 "depth comparison".into(),
737 ));
738 }
739 }
740
741 if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
743 return Err(ExpressionError::InvalidImageClass(class));
744 }
745
746 match level {
748 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
749 crate::SampleLevel::Zero => ShaderStages::all(),
750 crate::SampleLevel::Exact(expr) => {
751 match class {
752 crate::ImageClass::Depth { .. } => match resolver[expr] {
753 Ti::Scalar(Sc {
754 kind: Sk::Sint | Sk::Uint,
755 ..
756 }) => {}
757 _ => {
758 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
759 }
760 },
761 _ => match resolver[expr] {
762 Ti::Scalar(Sc {
763 kind: Sk::Float, ..
764 }) => {}
765 _ => {
766 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
767 }
768 },
769 }
770 ShaderStages::all()
771 }
772 crate::SampleLevel::Bias(expr) => {
773 match resolver[expr] {
774 Ti::Scalar(Sc {
775 kind: Sk::Float, ..
776 }) => {}
777 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
778 }
779 match class {
780 crate::ImageClass::Sampled {
781 kind: Sk::Float,
782 multi: false,
783 } => {
784 if dim == crate::ImageDimension::D1 {
785 return Err(ExpressionError::InvalidSampleLevelBiasDimension(
786 dim,
787 ));
788 }
789 }
790 _ => return Err(ExpressionError::InvalidImageClass(class)),
791 }
792 ShaderStages::FRAGMENT
793 }
794 crate::SampleLevel::Gradient { x, y } => {
795 match resolver[x] {
796 Ti::Scalar(Sc {
797 kind: Sk::Float, ..
798 }) if num_components == 1 => {}
799 Ti::Vector {
800 size,
801 scalar:
802 Sc {
803 kind: Sk::Float, ..
804 },
805 } if size as u32 == num_components => {}
806 _ => {
807 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
808 }
809 }
810 match resolver[y] {
811 Ti::Scalar(Sc {
812 kind: Sk::Float, ..
813 }) if num_components == 1 => {}
814 Ti::Vector {
815 size,
816 scalar:
817 Sc {
818 kind: Sk::Float, ..
819 },
820 } if size as u32 == num_components => {}
821 _ => {
822 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
823 }
824 }
825 ShaderStages::all()
826 }
827 }
828 }
829 E::ImageLoad {
830 image,
831 coordinate,
832 array_index,
833 sample,
834 level,
835 } => {
836 let ty = Self::global_var_ty(module, function, image)?;
837 let Ti::Image {
838 class,
839 arrayed,
840 dim,
841 } = module.types[ty].inner
842 else {
843 return Err(ExpressionError::ExpectedImageType(ty));
844 };
845
846 match resolver[coordinate].image_storage_coordinates() {
847 Some(coord_dim) if coord_dim == dim => {}
848 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
849 };
850 if arrayed != array_index.is_some() {
851 return Err(ExpressionError::InvalidImageArrayIndex);
852 }
853 if let Some(expr) = array_index {
854 if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
855 return Err(ExpressionError::InvalidImageArrayIndexType(expr));
856 }
857 }
858
859 match (sample, class.is_multisampled()) {
860 (None, false) => {}
861 (Some(sample), true) => {
862 if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
863 return Err(ExpressionError::InvalidImageOtherIndexType(sample));
864 }
865 }
866 (Some(_), false) => {
867 return Err(ExpressionError::InvalidImageSampleSelector);
868 }
869 (None, true) => {
870 return Err(ExpressionError::MissingImageSampleSelector);
871 }
872 }
873
874 match (level, class.is_mipmapped()) {
875 (None, false) => {}
876 (Some(level), true) => match resolver[level] {
877 Ti::Scalar(Sc {
878 kind: Sk::Sint | Sk::Uint,
879 width: _,
880 }) => {}
881 _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
882 },
883 (Some(_), false) => {
884 return Err(ExpressionError::InvalidImageLevelSelector);
885 }
886 (None, true) => {
887 return Err(ExpressionError::MissingImageLevelSelector);
888 }
889 }
890 ShaderStages::all()
891 }
892 E::ImageQuery { image, query } => {
893 let ty = Self::global_var_ty(module, function, image)?;
894 match module.types[ty].inner {
895 Ti::Image { class, arrayed, .. } => {
896 let good = match query {
897 crate::ImageQuery::NumLayers => arrayed,
898 crate::ImageQuery::Size { level: None } => true,
899 crate::ImageQuery::Size { level: Some(level) } => {
900 match resolver[level] {
901 Ti::Scalar(Sc::I32 | Sc::U32) => {}
902 _ => {
903 return Err(ExpressionError::InvalidImageOtherIndexType(
904 level,
905 ))
906 }
907 }
908 class.is_mipmapped()
909 }
910 crate::ImageQuery::NumLevels => class.is_mipmapped(),
911 crate::ImageQuery::NumSamples => class.is_multisampled(),
912 };
913 if !good {
914 return Err(ExpressionError::InvalidImageClass(class));
915 }
916 }
917 _ => return Err(ExpressionError::ExpectedImageType(ty)),
918 }
919 ShaderStages::all()
920 }
921 E::Unary { op, expr } => {
922 use crate::UnaryOperator as Uo;
923 let Some((_, scalar)) = resolver[expr].vector_size_and_scalar() else {
924 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
925 };
926 match (op, scalar.kind) {
927 (Uo::Negate, Sk::Float | Sk::Sint) => {}
928 (Uo::LogicalNot, Sk::Bool) => {}
929 (Uo::BitwiseNot, Sk::Sint | Sk::Uint) => {}
930 _ => return Err(ExpressionError::InvalidUnaryOperandType(op, expr)),
931 }
932 ShaderStages::all()
933 }
934 E::Binary { op, left, right } => {
935 use crate::BinaryOperator as Bo;
936 let left_inner = &resolver[left];
937 let right_inner = &resolver[right];
938 let good = match op {
939 Bo::Add | Bo::Subtract => match *left_inner {
940 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
941 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
942 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
943 },
944 Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => {
945 left_inner == right_inner
946 }
947 _ => false,
948 },
949 Bo::Divide | Bo::Modulo => match *left_inner {
950 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
951 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
952 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
953 },
954 _ => false,
955 },
956 Bo::Multiply => {
957 let kind_allowed = match left_inner.scalar_kind() {
958 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
959 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
960 };
961 let types_match = match (left_inner, right_inner) {
962 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
964 | (
965 &Ti::Vector {
966 scalar: scalar1, ..
967 },
968 &Ti::Scalar(scalar2),
969 )
970 | (
971 &Ti::Scalar(scalar1),
972 &Ti::Vector {
973 scalar: scalar2, ..
974 },
975 ) => scalar1 == scalar2,
976 (
978 &Ti::Scalar(Sc {
979 kind: Sk::Float, ..
980 }),
981 &Ti::Matrix { .. },
982 )
983 | (
984 &Ti::Matrix { .. },
985 &Ti::Scalar(Sc {
986 kind: Sk::Float, ..
987 }),
988 ) => true,
989 (
991 &Ti::Vector {
992 size: size1,
993 scalar: scalar1,
994 },
995 &Ti::Vector {
996 size: size2,
997 scalar: scalar2,
998 },
999 ) => scalar1 == scalar2 && size1 == size2,
1000 (
1002 &Ti::Matrix { columns, .. },
1003 &Ti::Vector {
1004 size,
1005 scalar:
1006 Sc {
1007 kind: Sk::Float, ..
1008 },
1009 },
1010 ) => columns == size,
1011 (
1013 &Ti::Vector {
1014 size,
1015 scalar:
1016 Sc {
1017 kind: Sk::Float, ..
1018 },
1019 },
1020 &Ti::Matrix { rows, .. },
1021 ) => size == rows,
1022 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
1024 columns == rows
1025 }
1026 (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. })
1028 | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => {
1029 s1 == s2
1030 }
1031 _ => false,
1032 };
1033 let left_width = left_inner.scalar_width().unwrap_or(0);
1034 let right_width = right_inner.scalar_width().unwrap_or(0);
1035 kind_allowed && types_match && left_width == right_width
1036 }
1037 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
1038 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
1039 match *left_inner {
1040 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1041 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
1042 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
1043 },
1044 ref other => {
1045 log::debug!("Op {op:?} left type {other:?}");
1046 false
1047 }
1048 }
1049 }
1050 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
1051 Ti::Scalar(Sc { kind: Sk::Bool, .. })
1052 | Ti::Vector {
1053 scalar: Sc { kind: Sk::Bool, .. },
1054 ..
1055 } => left_inner == right_inner,
1056 ref other => {
1057 log::debug!("Op {op:?} left type {other:?}");
1058 false
1059 }
1060 },
1061 Bo::And | Bo::InclusiveOr => match *left_inner {
1062 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1063 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
1064 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1065 },
1066 ref other => {
1067 log::debug!("Op {op:?} left type {other:?}");
1068 false
1069 }
1070 },
1071 Bo::ExclusiveOr => match *left_inner {
1072 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1073 Sk::Sint | Sk::Uint => left_inner == right_inner,
1074 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
1075 },
1076 ref other => {
1077 log::debug!("Op {op:?} left type {other:?}");
1078 false
1079 }
1080 },
1081 Bo::ShiftLeft | Bo::ShiftRight => {
1082 let (base_size, base_scalar) = match *left_inner {
1083 Ti::Scalar(scalar) => (Ok(None), scalar),
1084 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
1085 ref other => {
1086 log::debug!("Op {op:?} base type {other:?}");
1087 (Err(()), Sc::BOOL)
1088 }
1089 };
1090 let shift_size = match *right_inner {
1091 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
1092 Ti::Vector {
1093 size,
1094 scalar: Sc { kind: Sk::Uint, .. },
1095 } => Ok(Some(size)),
1096 ref other => {
1097 log::debug!("Op {op:?} shift type {other:?}");
1098 Err(())
1099 }
1100 };
1101 match base_scalar.kind {
1102 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
1103 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
1104 }
1105 }
1106 };
1107 if !good {
1108 log::debug!(
1109 "Left: {:?} of type {:?}",
1110 function.expressions[left],
1111 left_inner
1112 );
1113 log::debug!(
1114 "Right: {:?} of type {:?}",
1115 function.expressions[right],
1116 right_inner
1117 );
1118 return Err(ExpressionError::InvalidBinaryOperandTypes {
1119 op,
1120 lhs_expr: left,
1121 lhs_type: left_inner.clone(),
1122 rhs_expr: right,
1123 rhs_type: right_inner.clone(),
1124 });
1125 }
1126 if matches!(op, Bo::ShiftLeft | Bo::ShiftRight) {
1128 Self::validate_constant_shift_amounts(left_inner, right, module, function)?;
1129 }
1130 if matches!(op, Bo::Divide | Bo::Modulo) {
1132 Self::validate_constant_divisor(left_inner, right, module, function)?;
1133 }
1134 ShaderStages::all()
1135 }
1136 E::Select {
1137 condition,
1138 accept,
1139 reject,
1140 } => {
1141 let accept_inner = &resolver[accept];
1142 let reject_inner = &resolver[reject];
1143 let condition_ty = &resolver[condition];
1144 let condition_good = match *condition_ty {
1145 Ti::Scalar(Sc {
1146 kind: Sk::Bool,
1147 width: _,
1148 }) => {
1149 match *accept_inner {
1152 Ti::Scalar { .. } | Ti::Vector { .. } => true,
1153 _ => false,
1154 }
1155 }
1156 Ti::Vector {
1157 size,
1158 scalar:
1159 Sc {
1160 kind: Sk::Bool,
1161 width: _,
1162 },
1163 } => match *accept_inner {
1164 Ti::Vector {
1165 size: other_size, ..
1166 } => size == other_size,
1167 _ => false,
1168 },
1169 _ => false,
1170 };
1171 if accept_inner != reject_inner {
1172 return Err(ExpressionError::SelectValuesTypeMismatch {
1173 accept: accept_inner.clone(),
1174 reject: reject_inner.clone(),
1175 });
1176 }
1177 if !condition_good {
1178 return Err(ExpressionError::SelectConditionNotABool {
1179 actual: condition_ty.clone(),
1180 });
1181 }
1182 ShaderStages::all()
1183 }
1184 E::Derivative { expr, .. } => {
1185 let Some((_, scalar)) = resolver[expr].vector_size_and_scalar() else {
1186 return Err(ExpressionError::InvalidDerivative);
1187 };
1188 if scalar.kind != Sk::Float || scalar.width < 4 {
1189 return Err(ExpressionError::InvalidDerivative);
1192 }
1193 ShaderStages::FRAGMENT
1194 }
1195 E::Relational { fun, argument } => {
1196 use crate::RelationalFunction as Rf;
1197 let argument_inner = &resolver[argument];
1198 match fun {
1199 Rf::All | Rf::Any => match *argument_inner {
1200 Ti::Vector {
1201 scalar: Sc { kind: Sk::Bool, .. },
1202 ..
1203 } => {}
1204 ref other => {
1205 log::debug!("All/Any of type {other:?}");
1206 return Err(ExpressionError::InvalidBooleanVector(argument));
1207 }
1208 },
1209 Rf::IsNan | Rf::IsInf => match *argument_inner {
1210 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1211 if scalar.kind == Sk::Float => {}
1212 ref other => {
1213 log::debug!("Float test of type {other:?}");
1214 return Err(ExpressionError::InvalidFloatArgument(argument));
1215 }
1216 },
1217 }
1218 ShaderStages::all()
1219 }
1220 E::Math {
1221 fun,
1222 arg,
1223 arg1,
1224 arg2,
1225 arg3,
1226 } => {
1227 if matches!(
1228 fun,
1229 crate::MathFunction::QuantizeToF16
1230 | crate::MathFunction::Pack2x16float
1231 | crate::MathFunction::Unpack2x16float
1232 ) && !self
1233 .capabilities
1234 .contains(crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32)
1235 {
1236 return Err(ExpressionError::MissingCapabilities(
1237 crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32,
1238 ));
1239 }
1240
1241 let actuals: &[_] = match (arg1, arg2, arg3) {
1242 (None, None, None) => &[arg],
1243 (Some(arg1), None, None) => &[arg, arg1],
1244 (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1245 (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1246 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1247 };
1248
1249 let resolve = |arg| &resolver[arg];
1250 let actual_types: &[_] = match *actuals {
1251 [arg0] => &[resolve(arg0)],
1252 [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1253 [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1254 [arg0, arg1, arg2, arg3] => {
1255 &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1256 }
1257 _ => unreachable!(),
1258 };
1259
1260 let mut overloads = fun.overloads();
1262 log::debug!(
1263 "initial overloads for {:?}: {:#?}",
1264 fun,
1265 overloads.for_debug(&module.types)
1266 );
1267
1268 for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1276 overloads = overloads.arg(i, ty, &module.types);
1279 log::debug!(
1280 "overloads after arg {i}: {:#?}",
1281 overloads.for_debug(&module.types)
1282 );
1283
1284 if overloads.is_empty() {
1285 log::debug!("all overloads eliminated");
1286 return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1287 }
1288 }
1289
1290 if actuals.len() < overloads.min_arguments() {
1291 return Err(ExpressionError::WrongArgumentCount(fun));
1292 }
1293
1294 ShaderStages::all()
1295 }
1296 E::As {
1297 expr,
1298 kind,
1299 convert,
1300 } => {
1301 let mut base_scalar = match resolver[expr] {
1302 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1303 scalar
1304 }
1305 crate::TypeInner::Matrix { scalar, .. } => scalar,
1306 _ => return Err(ExpressionError::InvalidCastArgument),
1307 };
1308 base_scalar.kind = kind;
1309 if let Some(width) = convert {
1310 base_scalar.width = width;
1311 }
1312 if self.check_width(base_scalar).is_err() {
1313 return Err(ExpressionError::InvalidCastArgument);
1314 }
1315 ShaderStages::all()
1316 }
1317 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1318 E::AtomicResult { .. } => {
1319 ShaderStages::all()
1324 }
1325 E::WorkGroupUniformLoadResult { ty } => {
1326 if self.types[ty.index()]
1327 .flags
1328 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1331 {
1332 ShaderStages::COMPUTE_LIKE
1333 } else {
1334 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1335 }
1336 }
1337 E::ArrayLength(expr) => match resolver[expr] {
1338 Ti::Pointer { base, .. } => {
1339 let base_ty = &resolver.types[base];
1340 if let Ti::Array {
1341 size: crate::ArraySize::Dynamic,
1342 ..
1343 } = base_ty.inner
1344 {
1345 ShaderStages::all()
1346 } else {
1347 return Err(ExpressionError::InvalidArrayType(expr));
1348 }
1349 }
1350 ref other => {
1351 log::debug!("Array length of {other:?}");
1352 return Err(ExpressionError::InvalidArrayType(expr));
1353 }
1354 },
1355 E::RayQueryProceedResult => ShaderStages::all(),
1356 E::RayQueryGetIntersection {
1357 query,
1358 committed: _,
1359 } => match resolver[query] {
1360 Ti::Pointer {
1361 base,
1362 space: crate::AddressSpace::Function,
1363 } => match resolver.types[base].inner {
1364 Ti::RayQuery { .. } => ShaderStages::all(),
1365 ref other => {
1366 log::debug!("Intersection result of a pointer to {other:?}");
1367 return Err(ExpressionError::InvalidRayQueryType(query));
1368 }
1369 },
1370 ref other => {
1371 log::debug!("Intersection result of {other:?}");
1372 return Err(ExpressionError::InvalidRayQueryType(query));
1373 }
1374 },
1375 E::RayQueryVertexPositions {
1376 query,
1377 committed: _,
1378 } => match resolver[query] {
1379 Ti::Pointer {
1380 base,
1381 space: crate::AddressSpace::Function,
1382 } => match resolver.types[base].inner {
1383 Ti::RayQuery {
1384 vertex_return: true,
1385 } => ShaderStages::all(),
1386 ref other => {
1387 log::debug!("Intersection result of a pointer to {other:?}");
1388 return Err(ExpressionError::InvalidRayQueryType(query));
1389 }
1390 },
1391 ref other => {
1392 log::debug!("Intersection result of {other:?}");
1393 return Err(ExpressionError::InvalidRayQueryType(query));
1394 }
1395 },
1396 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1397 E::CooperativeLoad { ref data, .. } => {
1398 if resolver[data.pointer]
1399 .pointer_base_type()
1400 .and_then(|tr| tr.inner_with(&module.types).scalar())
1401 .is_none()
1402 {
1403 return Err(ExpressionError::InvalidPointerType(data.pointer));
1404 }
1405 ShaderStages::COMPUTE
1406 }
1407 E::CooperativeMultiplyAdd { a, b, c } => {
1408 let roles = [
1409 crate::CooperativeRole::A,
1410 crate::CooperativeRole::B,
1411 crate::CooperativeRole::C,
1412 ];
1413 for (operand, expected_role) in [a, b, c].into_iter().zip(roles) {
1414 match resolver[operand] {
1415 Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
1416 ref other => {
1417 log::debug!("{expected_role:?} operand type: {other:?}");
1418 return Err(ExpressionError::InvalidCooperativeOperand(a));
1419 }
1420 }
1421 }
1422 ShaderStages::COMPUTE
1423 }
1424 };
1425 Ok(stages)
1426 }
1427
1428 fn global_var_ty(
1429 module: &crate::Module,
1430 function: &crate::Function,
1431 expr: Handle<crate::Expression>,
1432 ) -> Result<Handle<crate::Type>, ExpressionError> {
1433 use crate::Expression as Ex;
1434
1435 match function.expressions[expr] {
1436 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1437 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1438 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1439 match function.expressions[base] {
1440 Ex::GlobalVariable(var_handle) => {
1441 let array_ty = module.global_variables[var_handle].ty;
1442
1443 match module.types[array_ty].inner {
1444 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1445 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1446 }
1447 }
1448 _ => Err(ExpressionError::ExpectedGlobalVariable),
1449 }
1450 }
1451 _ => Err(ExpressionError::ExpectedGlobalVariable),
1452 }
1453 }
1454
1455 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1456 let _ = self.check_width(literal.scalar())?;
1457 check_literal_value(literal)?;
1458
1459 Ok(())
1460 }
1461}
1462
1463pub const fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1464 let is_nan = match literal {
1465 crate::Literal::F64(v) => v.is_nan(),
1466 crate::Literal::F32(v) => v.is_nan(),
1467 _ => false,
1468 };
1469 if is_nan {
1470 return Err(LiteralError::NaN);
1471 }
1472
1473 let is_infinite = match literal {
1474 crate::Literal::F64(v) => v.is_infinite(),
1475 crate::Literal::F32(v) => v.is_infinite(),
1476 _ => false,
1477 };
1478 if is_infinite {
1479 return Err(LiteralError::Infinity);
1480 }
1481
1482 Ok(())
1483}
1484
/// Test helper: validates a module containing one function whose only
/// expression is `expr`, emitted by the function body.
///
/// Validation runs with [`super::ValidationFlags::EXPRESSIONS`] and the given
/// capabilities.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut func = crate::Function::default();
    func.expressions.append(expr, Span::default());
    // The expression must be emitted for the validator to visit it.
    let emit_everything = crate::Statement::Emit(func.expressions.range_from(0));
    func.body.push(emit_everything, Span::default());

    let mut module = crate::Module::default();
    module.functions.append(func, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
    validator.validate(&module)
}
1507
/// Test helper: validates a module whose only global (constant) expression
/// is `expr`.
///
/// Validation runs with [`super::ValidationFlags::CONSTANTS`] and the given
/// capabilities.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

    let mut module = crate::Module::default();
    module
        .global_expressions
        .append(expr, crate::span::Span::default());

    validator.validate(&module)
}
1523
/// An `f64` literal inside a function must be rejected unless
/// `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_runtime_literals() {
    // Euler-Mascheroni constant; any finite f64 value would do.
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the literal's width is missing a capability.
    let rejection = validate_with_expression(make_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        rejection,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    // With FLOAT64 enabled the very same expression validates.
    let caps_with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    let accepted = validate_with_expression(make_literal(), caps_with_f64);
    assert!(accepted.is_ok());
}
1554
/// An `f64` literal in a global (constant) expression must be rejected unless
/// `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_const_literals() {
    // Euler-Mascheroni constant; any finite f64 value would do.
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the constant expression is rejected on width.
    let rejection =
        validate_with_const_expression(make_literal(), super::Capabilities::default())
            .unwrap_err()
            .into_inner();
    assert!(matches!(
        rejection,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With FLOAT64 enabled the very same expression validates.
    let caps_with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    let accepted = validate_with_const_expression(make_literal(), caps_with_f64);
    assert!(accepted.is_ok());
}