1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4 arena::Handle,
5 proc::OverloadSet as _,
6 proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13 NotInScope,
14 #[error("Base type {0:?} is not compatible with this expression")]
15 InvalidBaseType(Handle<crate::Expression>),
16 #[error("Accessing with index {0:?} can't be done")]
17 InvalidIndexType(Handle<crate::Expression>),
18 #[error("Accessing {0:?} via a negative index is invalid")]
19 NegativeIndex(Handle<crate::Expression>),
20 #[error("Accessing index {1} is out of {0:?} bounds")]
21 IndexOutOfBounds(Handle<crate::Expression>, u32),
22 #[error("Function argument {0:?} doesn't exist")]
23 FunctionArgumentDoesntExist(u32),
24 #[error("Loading of {0:?} can't be done")]
25 InvalidPointerType(Handle<crate::Expression>),
26 #[error("Array length of {0:?} can't be done")]
27 InvalidArrayType(Handle<crate::Expression>),
28 #[error("Get intersection of {0:?} can't be done")]
29 InvalidRayQueryType(Handle<crate::Expression>),
30 #[error("Splatting {0:?} can't be done")]
31 InvalidSplatType(Handle<crate::Expression>),
32 #[error("Swizzling {0:?} can't be done")]
33 InvalidVectorType(Handle<crate::Expression>),
34 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36 #[error(transparent)]
37 Compose(#[from] super::ComposeError),
38 #[error(transparent)]
39 IndexableLength(#[from] IndexableLengthError),
40 #[error("Operation {0:?} can't work with {1:?}")]
41 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42 #[error(
43 "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
44 op,
45 lhs_expr,
46 lhs_type,
47 rhs_expr,
48 rhs_type
49 )]
50 InvalidBinaryOperandTypes {
51 op: crate::BinaryOperator,
52 lhs_expr: Handle<crate::Expression>,
53 lhs_type: crate::TypeInner,
54 rhs_expr: Handle<crate::Expression>,
55 rhs_type: crate::TypeInner,
56 },
57 #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
58 SelectValuesTypeMismatch {
59 accept: crate::TypeInner,
60 reject: crate::TypeInner,
61 },
62 #[error("Expected selection condition to be a boolean value, got {actual:?}")]
63 SelectConditionNotABool { actual: crate::TypeInner },
64 #[error("Relational argument {0:?} is not a boolean vector")]
65 InvalidBooleanVector(Handle<crate::Expression>),
66 #[error("Relational argument {0:?} is not a float")]
67 InvalidFloatArgument(Handle<crate::Expression>),
68 #[error("Type resolution failed")]
69 Type(#[from] ResolveError),
70 #[error("Not a global variable")]
71 ExpectedGlobalVariable,
72 #[error("Not a global variable or a function argument")]
73 ExpectedGlobalOrArgument,
74 #[error("Needs to be an binding array instead of {0:?}")]
75 ExpectedBindingArrayType(Handle<crate::Type>),
76 #[error("Needs to be an image instead of {0:?}")]
77 ExpectedImageType(Handle<crate::Type>),
78 #[error("Needs to be an image instead of {0:?}")]
79 ExpectedSamplerType(Handle<crate::Type>),
80 #[error("Unable to operate on image class {0:?}")]
81 InvalidImageClass(crate::ImageClass),
82 #[error("Image atomics are not supported for storage format {0:?}")]
83 InvalidImageFormat(crate::StorageFormat),
84 #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
85 InvalidImageStorageAccess(crate::StorageAccess),
86 #[error("Derivatives can only be taken from scalar and vector floats")]
87 InvalidDerivative,
88 #[error("Image array index parameter is misplaced")]
89 InvalidImageArrayIndex,
90 #[error("Inappropriate sample or level-of-detail index for texel access")]
91 InvalidImageOtherIndex,
92 #[error("Image array index type of {0:?} is not an integer scalar")]
93 InvalidImageArrayIndexType(Handle<crate::Expression>),
94 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
95 InvalidImageOtherIndexType(Handle<crate::Expression>),
96 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
97 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
98 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
99 ComparisonSamplingMismatch {
100 image: crate::ImageClass,
101 sampler: bool,
102 has_ref: bool,
103 },
104 #[error("Sample offset must be a const-expression")]
105 InvalidSampleOffsetExprType,
106 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
107 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
108 #[error("Depth reference {0:?} is not a scalar float")]
109 InvalidDepthReference(Handle<crate::Expression>),
110 #[error("Depth sample level can only be Auto or Zero")]
111 InvalidDepthSampleLevel,
112 #[error("Gather level can only be Zero")]
113 InvalidGatherLevel,
114 #[error("Gather component {0:?} doesn't exist in the image")]
115 InvalidGatherComponent(crate::SwizzleComponent),
116 #[error("Gather can't be done for image dimension {0:?}")]
117 InvalidGatherDimension(crate::ImageDimension),
118 #[error("Sample level (exact) type {0:?} has an invalid type")]
119 InvalidSampleLevelExactType(Handle<crate::Expression>),
120 #[error("Sample level (bias) type {0:?} is not a scalar float")]
121 InvalidSampleLevelBiasType(Handle<crate::Expression>),
122 #[error("Bias can't be done for image dimension {0:?}")]
123 InvalidSampleLevelBiasDimension(crate::ImageDimension),
124 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
125 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
126 #[error("Clamping sample coordinate to edge is not supported with {0}")]
127 InvalidSampleClampCoordinateToEdge(alloc::string::String),
128 #[error("Unable to cast")]
129 InvalidCastArgument,
130 #[error("Invalid argument count for {0:?}")]
131 WrongArgumentCount(crate::MathFunction),
132 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
133 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
134 #[error(
135 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
136 )]
137 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
138 #[error("Shader requires capability {0:?}")]
139 MissingCapabilities(super::Capabilities),
140 #[error(transparent)]
141 Literal(#[from] LiteralError),
142 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
143 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
144}
145
146#[derive(Clone, Debug, thiserror::Error)]
147#[cfg_attr(test, derive(PartialEq))]
148pub enum ConstExpressionError {
149 #[error("The expression is not a constant or override expression")]
150 NonConstOrOverride,
151 #[error("The expression is not a fully evaluated constant expression")]
152 NonFullyEvaluatedConst,
153 #[error(transparent)]
154 Compose(#[from] super::ComposeError),
155 #[error("Splatting {0:?} can't be done")]
156 InvalidSplatType(Handle<crate::Expression>),
157 #[error("Type resolution failed")]
158 Type(#[from] ResolveError),
159 #[error(transparent)]
160 Literal(#[from] LiteralError),
161 #[error(transparent)]
162 Width(#[from] super::r#type::WidthError),
163}
164
165#[derive(Clone, Debug, thiserror::Error)]
166#[cfg_attr(test, derive(PartialEq))]
167pub enum LiteralError {
168 #[error("Float literal is NaN")]
169 NaN,
170 #[error("Float literal is infinite")]
171 Infinity,
172 #[error(transparent)]
173 Width(#[from] super::r#type::WidthError),
174}
175
176struct ExpressionTypeResolver<'a> {
177 root: Handle<crate::Expression>,
178 types: &'a UniqueArena<crate::Type>,
179 info: &'a FunctionInfo,
180}
181
182impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
183 type Output = crate::TypeInner;
184
185 #[allow(clippy::panic)]
186 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
187 if handle < self.root {
188 self.info[handle].ty.inner_with(self.types)
189 } else {
190 panic!(
192 "Depends on {:?}, which has not been processed yet",
193 self.root
194 )
195 }
196 }
197}
198
199impl super::Validator {
200 pub(super) fn validate_const_expression(
201 &self,
202 handle: Handle<crate::Expression>,
203 gctx: crate::proc::GlobalCtx,
204 mod_info: &ModuleInfo,
205 global_expr_kind: &crate::proc::ExpressionKindTracker,
206 ) -> Result<(), ConstExpressionError> {
207 use crate::Expression as E;
208
209 if !global_expr_kind.is_const_or_override(handle) {
210 return Err(ConstExpressionError::NonConstOrOverride);
211 }
212
213 match gctx.global_expressions[handle] {
214 E::Literal(literal) => {
215 self.validate_literal(literal)?;
216 }
217 E::Constant(_) | E::ZeroValue(_) => {}
218 E::Compose { ref components, ty } => {
219 validate_compose(
220 ty,
221 gctx,
222 components.iter().map(|&handle| mod_info[handle].clone()),
223 )?;
224 }
225 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
226 crate::TypeInner::Scalar { .. } => {}
227 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
228 },
229 _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
230 return Err(ConstExpressionError::NonFullyEvaluatedConst)
231 }
232 _ => {}
234 }
235
236 Ok(())
237 }
238
239 #[allow(clippy::too_many_arguments)]
240 pub(super) fn validate_expression(
241 &self,
242 root: Handle<crate::Expression>,
243 expression: &crate::Expression,
244 function: &crate::Function,
245 module: &crate::Module,
246 info: &FunctionInfo,
247 mod_info: &ModuleInfo,
248 expr_kind: &crate::proc::ExpressionKindTracker,
249 ) -> Result<ShaderStages, ExpressionError> {
250 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
251
252 let resolver = ExpressionTypeResolver {
253 root,
254 types: &module.types,
255 info,
256 };
257
258 let stages = match *expression {
259 E::Access { base, index } => {
260 let base_type = &resolver[base];
261 match *base_type {
262 Ti::Matrix { .. }
263 | Ti::Vector { .. }
264 | Ti::Array { .. }
265 | Ti::Pointer { .. }
266 | Ti::ValuePointer { size: Some(_), .. }
267 | Ti::BindingArray { .. } => {}
268 ref other => {
269 log::error!("Indexing of {other:?}");
270 return Err(ExpressionError::InvalidBaseType(base));
271 }
272 };
273 match resolver[index] {
274 Ti::Scalar(Sc {
276 kind: Sk::Sint | Sk::Uint,
277 ..
278 }) => {}
279 ref other => {
280 log::error!("Indexing by {other:?}");
281 return Err(ExpressionError::InvalidIndexType(index));
282 }
283 }
284
285 match module
287 .to_ctx()
288 .eval_expr_to_u32_from(index, &function.expressions)
289 {
290 Ok(value) => {
291 let length = if self.overrides_resolved {
292 base_type.indexable_length_resolved(module)
293 } else {
294 base_type.indexable_length_pending(module)
295 }?;
296 if let crate::proc::IndexableLength::Known(known_length) = length {
299 if value >= known_length {
300 return Err(ExpressionError::IndexOutOfBounds(base, value));
301 }
302 }
303 }
304 Err(crate::proc::U32EvalError::Negative) => {
305 return Err(ExpressionError::NegativeIndex(base))
306 }
307 Err(crate::proc::U32EvalError::NonConst) => {}
308 }
309
310 ShaderStages::all()
311 }
312 E::AccessIndex { base, index } => {
313 fn resolve_index_limit(
314 module: &crate::Module,
315 top: Handle<crate::Expression>,
316 ty: &crate::TypeInner,
317 top_level: bool,
318 ) -> Result<u32, ExpressionError> {
319 let limit = match *ty {
320 Ti::Vector { size, .. }
321 | Ti::ValuePointer {
322 size: Some(size), ..
323 } => size as u32,
324 Ti::Matrix { columns, .. } => columns as u32,
325 Ti::Array {
326 size: crate::ArraySize::Constant(len),
327 ..
328 } => len.get(),
329 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
331 resolve_index_limit(module, top, &module.types[base].inner, false)?
332 }
333 Ti::Struct { ref members, .. } => members.len() as u32,
334 ref other => {
335 log::error!("Indexing of {other:?}");
336 return Err(ExpressionError::InvalidBaseType(top));
337 }
338 };
339 Ok(limit)
340 }
341
342 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
343 if index >= limit {
344 return Err(ExpressionError::IndexOutOfBounds(base, limit));
345 }
346 ShaderStages::all()
347 }
348 E::Splat { size: _, value } => match resolver[value] {
349 Ti::Scalar { .. } => ShaderStages::all(),
350 ref other => {
351 log::error!("Splat scalar type {other:?}");
352 return Err(ExpressionError::InvalidSplatType(value));
353 }
354 },
355 E::Swizzle {
356 size,
357 vector,
358 pattern,
359 } => {
360 let vec_size = match resolver[vector] {
361 Ti::Vector { size: vec_size, .. } => vec_size,
362 ref other => {
363 log::error!("Swizzle vector type {other:?}");
364 return Err(ExpressionError::InvalidVectorType(vector));
365 }
366 };
367 for &sc in pattern[..size as usize].iter() {
368 if sc as u8 >= vec_size as u8 {
369 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
370 }
371 }
372 ShaderStages::all()
373 }
374 E::Literal(literal) => {
375 self.validate_literal(literal)?;
376 ShaderStages::all()
377 }
378 E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
379 E::Compose { ref components, ty } => {
380 validate_compose(
381 ty,
382 module.to_ctx(),
383 components.iter().map(|&handle| info[handle].ty.clone()),
384 )?;
385 ShaderStages::all()
386 }
387 E::FunctionArgument(index) => {
388 if index >= function.arguments.len() as u32 {
389 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
390 }
391 ShaderStages::all()
392 }
393 E::GlobalVariable(_handle) => ShaderStages::all(),
394 E::LocalVariable(_handle) => ShaderStages::all(),
395 E::Load { pointer } => {
396 match resolver[pointer] {
397 Ti::Pointer { base, .. }
398 if self.types[base.index()]
399 .flags
400 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
401 Ti::ValuePointer { .. } => {}
402 ref other => {
403 log::error!("Loading {other:?}");
404 return Err(ExpressionError::InvalidPointerType(pointer));
405 }
406 }
407 ShaderStages::all()
408 }
409 E::ImageSample {
410 image,
411 sampler,
412 gather,
413 coordinate,
414 array_index,
415 offset,
416 level,
417 depth_ref,
418 clamp_to_edge,
419 } => {
420 let image_ty = Self::global_var_ty(module, function, image)?;
422 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
423
424 let comparison = match module.types[sampler_ty].inner {
425 Ti::Sampler { comparison } => comparison,
426 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
427 };
428
429 let (class, dim) = match module.types[image_ty].inner {
430 Ti::Image {
431 class,
432 arrayed,
433 dim,
434 } => {
435 if arrayed != array_index.is_some() {
437 return Err(ExpressionError::InvalidImageArrayIndex);
438 }
439 if let Some(expr) = array_index {
440 match resolver[expr] {
441 Ti::Scalar(Sc {
442 kind: Sk::Sint | Sk::Uint,
443 ..
444 }) => {}
445 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
446 }
447 }
448 (class, dim)
449 }
450 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
451 };
452
453 let image_depth = match class {
455 crate::ImageClass::Sampled {
456 kind: crate::ScalarKind::Float,
457 multi: false,
458 } => false,
459 crate::ImageClass::Sampled {
460 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
461 multi: false,
462 } if gather.is_some() => false,
463 crate::ImageClass::External => false,
464 crate::ImageClass::Depth { multi: false } => true,
465 _ => return Err(ExpressionError::InvalidImageClass(class)),
466 };
467 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
468 return Err(ExpressionError::ComparisonSamplingMismatch {
469 image: class,
470 sampler: comparison,
471 has_ref: depth_ref.is_some(),
472 });
473 }
474
475 let num_components = match dim {
477 crate::ImageDimension::D1 => 1,
478 crate::ImageDimension::D2 => 2,
479 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
480 };
481 match resolver[coordinate] {
482 Ti::Scalar(Sc {
483 kind: Sk::Float, ..
484 }) if num_components == 1 => {}
485 Ti::Vector {
486 size,
487 scalar:
488 Sc {
489 kind: Sk::Float, ..
490 },
491 } if size as u32 == num_components => {}
492 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
493 }
494
495 if let Some(const_expr) = offset {
497 if !expr_kind.is_const(const_expr) {
498 return Err(ExpressionError::InvalidSampleOffsetExprType);
499 }
500
501 match resolver[const_expr] {
502 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
503 Ti::Vector {
504 size,
505 scalar: Sc { kind: Sk::Sint, .. },
506 } if size as u32 == num_components => {}
507 _ => {
508 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
509 }
510 }
511 }
512
513 if let Some(expr) = depth_ref {
515 match resolver[expr] {
516 Ti::Scalar(Sc {
517 kind: Sk::Float, ..
518 }) => {}
519 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
520 }
521 match level {
522 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
523 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
524 }
525 }
526
527 if let Some(component) = gather {
528 match dim {
529 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
530 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
531 return Err(ExpressionError::InvalidGatherDimension(dim))
532 }
533 };
534 let max_component = match class {
535 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
536 _ => crate::SwizzleComponent::W,
537 };
538 if component > max_component {
539 return Err(ExpressionError::InvalidGatherComponent(component));
540 }
541 match level {
542 crate::SampleLevel::Zero => {}
543 _ => return Err(ExpressionError::InvalidGatherLevel),
544 }
545 }
546
547 if clamp_to_edge {
550 if !matches!(
551 class,
552 crate::ImageClass::Sampled {
553 kind: crate::ScalarKind::Float,
554 multi: false
555 } | crate::ImageClass::External
556 ) {
557 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
558 alloc::format!("image class `{class:?}`"),
559 ));
560 }
561 if dim != crate::ImageDimension::D2 {
562 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
563 alloc::format!("image dimension `{dim:?}`"),
564 ));
565 }
566 if gather.is_some() {
567 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
568 "gather".into(),
569 ));
570 }
571 if array_index.is_some() {
572 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
573 "array index".into(),
574 ));
575 }
576 if offset.is_some() {
577 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
578 "offset".into(),
579 ));
580 }
581 if level != crate::SampleLevel::Zero {
582 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
583 "non-zero level".into(),
584 ));
585 }
586 if depth_ref.is_some() {
587 return Err(ExpressionError::InvalidSampleClampCoordinateToEdge(
588 "depth comparison".into(),
589 ));
590 }
591 }
592
593 if matches!(class, crate::ImageClass::External) && !clamp_to_edge {
595 return Err(ExpressionError::InvalidImageClass(class));
596 }
597
598 match level {
600 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
601 crate::SampleLevel::Zero => ShaderStages::all(),
602 crate::SampleLevel::Exact(expr) => {
603 match class {
604 crate::ImageClass::Depth { .. } => match resolver[expr] {
605 Ti::Scalar(Sc {
606 kind: Sk::Sint | Sk::Uint,
607 ..
608 }) => {}
609 _ => {
610 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
611 }
612 },
613 _ => match resolver[expr] {
614 Ti::Scalar(Sc {
615 kind: Sk::Float, ..
616 }) => {}
617 _ => {
618 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
619 }
620 },
621 }
622 ShaderStages::all()
623 }
624 crate::SampleLevel::Bias(expr) => {
625 match resolver[expr] {
626 Ti::Scalar(Sc {
627 kind: Sk::Float, ..
628 }) => {}
629 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
630 }
631 match class {
632 crate::ImageClass::Sampled {
633 kind: Sk::Float,
634 multi: false,
635 } => {
636 if dim == crate::ImageDimension::D1 {
637 return Err(ExpressionError::InvalidSampleLevelBiasDimension(
638 dim,
639 ));
640 }
641 }
642 _ => return Err(ExpressionError::InvalidImageClass(class)),
643 }
644 ShaderStages::FRAGMENT
645 }
646 crate::SampleLevel::Gradient { x, y } => {
647 match resolver[x] {
648 Ti::Scalar(Sc {
649 kind: Sk::Float, ..
650 }) if num_components == 1 => {}
651 Ti::Vector {
652 size,
653 scalar:
654 Sc {
655 kind: Sk::Float, ..
656 },
657 } if size as u32 == num_components => {}
658 _ => {
659 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
660 }
661 }
662 match resolver[y] {
663 Ti::Scalar(Sc {
664 kind: Sk::Float, ..
665 }) if num_components == 1 => {}
666 Ti::Vector {
667 size,
668 scalar:
669 Sc {
670 kind: Sk::Float, ..
671 },
672 } if size as u32 == num_components => {}
673 _ => {
674 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
675 }
676 }
677 ShaderStages::all()
678 }
679 }
680 }
681 E::ImageLoad {
682 image,
683 coordinate,
684 array_index,
685 sample,
686 level,
687 } => {
688 let ty = Self::global_var_ty(module, function, image)?;
689 let Ti::Image {
690 class,
691 arrayed,
692 dim,
693 } = module.types[ty].inner
694 else {
695 return Err(ExpressionError::ExpectedImageType(ty));
696 };
697
698 match resolver[coordinate].image_storage_coordinates() {
699 Some(coord_dim) if coord_dim == dim => {}
700 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
701 };
702 if arrayed != array_index.is_some() {
703 return Err(ExpressionError::InvalidImageArrayIndex);
704 }
705 if let Some(expr) = array_index {
706 if !matches!(resolver[expr], Ti::Scalar(Sc::I32 | Sc::U32)) {
707 return Err(ExpressionError::InvalidImageArrayIndexType(expr));
708 }
709 }
710
711 match (sample, class.is_multisampled()) {
712 (None, false) => {}
713 (Some(sample), true) => {
714 if !matches!(resolver[sample], Ti::Scalar(Sc::I32 | Sc::U32)) {
715 return Err(ExpressionError::InvalidImageOtherIndexType(sample));
716 }
717 }
718 _ => {
719 return Err(ExpressionError::InvalidImageOtherIndex);
720 }
721 }
722
723 match (level, class.is_mipmapped()) {
724 (None, false) => {}
725 (Some(level), true) => match resolver[level] {
726 Ti::Scalar(Sc {
727 kind: Sk::Sint | Sk::Uint,
728 width: _,
729 }) => {}
730 _ => return Err(ExpressionError::InvalidImageArrayIndexType(level)),
731 },
732 _ => {
733 return Err(ExpressionError::InvalidImageOtherIndex);
734 }
735 }
736 ShaderStages::all()
737 }
738 E::ImageQuery { image, query } => {
739 let ty = Self::global_var_ty(module, function, image)?;
740 match module.types[ty].inner {
741 Ti::Image { class, arrayed, .. } => {
742 let good = match query {
743 crate::ImageQuery::NumLayers => arrayed,
744 crate::ImageQuery::Size { level: None } => true,
745 crate::ImageQuery::Size { level: Some(level) } => {
746 match resolver[level] {
747 Ti::Scalar(Sc::I32 | Sc::U32) => {}
748 _ => {
749 return Err(ExpressionError::InvalidImageOtherIndexType(
750 level,
751 ))
752 }
753 }
754 class.is_mipmapped()
755 }
756 crate::ImageQuery::NumLevels => class.is_mipmapped(),
757 crate::ImageQuery::NumSamples => class.is_multisampled(),
758 };
759 if !good {
760 return Err(ExpressionError::InvalidImageClass(class));
761 }
762 }
763 _ => return Err(ExpressionError::ExpectedImageType(ty)),
764 }
765 ShaderStages::all()
766 }
767 E::Unary { op, expr } => {
768 use crate::UnaryOperator as Uo;
769 let inner = &resolver[expr];
770 match (op, inner.scalar_kind()) {
771 (Uo::Negate, Some(Sk::Float | Sk::Sint))
772 | (Uo::LogicalNot, Some(Sk::Bool))
773 | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
774 other => {
775 log::error!("Op {op:?} kind {other:?}");
776 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
777 }
778 }
779 ShaderStages::all()
780 }
781 E::Binary { op, left, right } => {
782 use crate::BinaryOperator as Bo;
783 let left_inner = &resolver[left];
784 let right_inner = &resolver[right];
785 let good = match op {
786 Bo::Add | Bo::Subtract => match *left_inner {
787 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
788 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
789 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
790 },
791 Ti::Matrix { .. } => left_inner == right_inner,
792 _ => false,
793 },
794 Bo::Divide | Bo::Modulo => match *left_inner {
795 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
796 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
797 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
798 },
799 _ => false,
800 },
801 Bo::Multiply => {
802 let kind_allowed = match left_inner.scalar_kind() {
803 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
804 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
805 };
806 let types_match = match (left_inner, right_inner) {
807 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
809 | (
810 &Ti::Vector {
811 scalar: scalar1, ..
812 },
813 &Ti::Scalar(scalar2),
814 )
815 | (
816 &Ti::Scalar(scalar1),
817 &Ti::Vector {
818 scalar: scalar2, ..
819 },
820 ) => scalar1 == scalar2,
821 (
823 &Ti::Scalar(Sc {
824 kind: Sk::Float, ..
825 }),
826 &Ti::Matrix { .. },
827 )
828 | (
829 &Ti::Matrix { .. },
830 &Ti::Scalar(Sc {
831 kind: Sk::Float, ..
832 }),
833 ) => true,
834 (
836 &Ti::Vector {
837 size: size1,
838 scalar: scalar1,
839 },
840 &Ti::Vector {
841 size: size2,
842 scalar: scalar2,
843 },
844 ) => scalar1 == scalar2 && size1 == size2,
845 (
847 &Ti::Matrix { columns, .. },
848 &Ti::Vector {
849 size,
850 scalar:
851 Sc {
852 kind: Sk::Float, ..
853 },
854 },
855 ) => columns == size,
856 (
858 &Ti::Vector {
859 size,
860 scalar:
861 Sc {
862 kind: Sk::Float, ..
863 },
864 },
865 &Ti::Matrix { rows, .. },
866 ) => size == rows,
867 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
868 columns == rows
869 }
870 _ => false,
871 };
872 let left_width = left_inner.scalar_width().unwrap_or(0);
873 let right_width = right_inner.scalar_width().unwrap_or(0);
874 kind_allowed && types_match && left_width == right_width
875 }
876 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
877 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
878 match *left_inner {
879 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
880 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
881 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
882 },
883 ref other => {
884 log::error!("Op {op:?} left type {other:?}");
885 false
886 }
887 }
888 }
889 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
890 Ti::Scalar(Sc { kind: Sk::Bool, .. })
891 | Ti::Vector {
892 scalar: Sc { kind: Sk::Bool, .. },
893 ..
894 } => left_inner == right_inner,
895 ref other => {
896 log::error!("Op {op:?} left type {other:?}");
897 false
898 }
899 },
900 Bo::And | Bo::InclusiveOr => match *left_inner {
901 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
902 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
903 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
904 },
905 ref other => {
906 log::error!("Op {op:?} left type {other:?}");
907 false
908 }
909 },
910 Bo::ExclusiveOr => match *left_inner {
911 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
912 Sk::Sint | Sk::Uint => left_inner == right_inner,
913 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
914 },
915 ref other => {
916 log::error!("Op {op:?} left type {other:?}");
917 false
918 }
919 },
920 Bo::ShiftLeft | Bo::ShiftRight => {
921 let (base_size, base_scalar) = match *left_inner {
922 Ti::Scalar(scalar) => (Ok(None), scalar),
923 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
924 ref other => {
925 log::error!("Op {op:?} base type {other:?}");
926 (Err(()), Sc::BOOL)
927 }
928 };
929 let shift_size = match *right_inner {
930 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
931 Ti::Vector {
932 size,
933 scalar: Sc { kind: Sk::Uint, .. },
934 } => Ok(Some(size)),
935 ref other => {
936 log::error!("Op {op:?} shift type {other:?}");
937 Err(())
938 }
939 };
940 match base_scalar.kind {
941 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
942 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
943 }
944 }
945 };
946 if !good {
947 log::error!(
948 "Left: {:?} of type {:?}",
949 function.expressions[left],
950 left_inner
951 );
952 log::error!(
953 "Right: {:?} of type {:?}",
954 function.expressions[right],
955 right_inner
956 );
957 return Err(ExpressionError::InvalidBinaryOperandTypes {
958 op,
959 lhs_expr: left,
960 lhs_type: left_inner.clone(),
961 rhs_expr: right,
962 rhs_type: right_inner.clone(),
963 });
964 }
965 ShaderStages::all()
966 }
967 E::Select {
968 condition,
969 accept,
970 reject,
971 } => {
972 let accept_inner = &resolver[accept];
973 let reject_inner = &resolver[reject];
974 let condition_ty = &resolver[condition];
975 let condition_good = match *condition_ty {
976 Ti::Scalar(Sc {
977 kind: Sk::Bool,
978 width: _,
979 }) => {
980 match *accept_inner {
983 Ti::Scalar { .. } | Ti::Vector { .. } => true,
984 _ => false,
985 }
986 }
987 Ti::Vector {
988 size,
989 scalar:
990 Sc {
991 kind: Sk::Bool,
992 width: _,
993 },
994 } => match *accept_inner {
995 Ti::Vector {
996 size: other_size, ..
997 } => size == other_size,
998 _ => false,
999 },
1000 _ => false,
1001 };
1002 if accept_inner != reject_inner {
1003 return Err(ExpressionError::SelectValuesTypeMismatch {
1004 accept: accept_inner.clone(),
1005 reject: reject_inner.clone(),
1006 });
1007 }
1008 if !condition_good {
1009 return Err(ExpressionError::SelectConditionNotABool {
1010 actual: condition_ty.clone(),
1011 });
1012 }
1013 ShaderStages::all()
1014 }
1015 E::Derivative { expr, .. } => {
1016 match resolver[expr] {
1017 Ti::Scalar(Sc {
1018 kind: Sk::Float, ..
1019 })
1020 | Ti::Vector {
1021 scalar:
1022 Sc {
1023 kind: Sk::Float, ..
1024 },
1025 ..
1026 } => {}
1027 _ => return Err(ExpressionError::InvalidDerivative),
1028 }
1029 ShaderStages::FRAGMENT
1030 }
1031 E::Relational { fun, argument } => {
1032 use crate::RelationalFunction as Rf;
1033 let argument_inner = &resolver[argument];
1034 match fun {
1035 Rf::All | Rf::Any => match *argument_inner {
1036 Ti::Vector {
1037 scalar: Sc { kind: Sk::Bool, .. },
1038 ..
1039 } => {}
1040 ref other => {
1041 log::error!("All/Any of type {other:?}");
1042 return Err(ExpressionError::InvalidBooleanVector(argument));
1043 }
1044 },
1045 Rf::IsNan | Rf::IsInf => match *argument_inner {
1046 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1047 if scalar.kind == Sk::Float => {}
1048 ref other => {
1049 log::error!("Float test of type {other:?}");
1050 return Err(ExpressionError::InvalidFloatArgument(argument));
1051 }
1052 },
1053 }
1054 ShaderStages::all()
1055 }
1056 E::Math {
1057 fun,
1058 arg,
1059 arg1,
1060 arg2,
1061 arg3,
1062 } => {
1063 if matches!(
1064 fun,
1065 crate::MathFunction::QuantizeToF16
1066 | crate::MathFunction::Pack2x16float
1067 | crate::MathFunction::Unpack2x16float
1068 ) && !self
1069 .capabilities
1070 .contains(crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32)
1071 {
1072 return Err(ExpressionError::MissingCapabilities(
1073 crate::valid::Capabilities::SHADER_FLOAT16_IN_FLOAT32,
1074 ));
1075 }
1076
1077 let actuals: &[_] = match (arg1, arg2, arg3) {
1078 (None, None, None) => &[arg],
1079 (Some(arg1), None, None) => &[arg, arg1],
1080 (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1081 (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1082 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1083 };
1084
1085 let resolve = |arg| &resolver[arg];
1086 let actual_types: &[_] = match *actuals {
1087 [arg0] => &[resolve(arg0)],
1088 [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1089 [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1090 [arg0, arg1, arg2, arg3] => {
1091 &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1092 }
1093 _ => unreachable!(),
1094 };
1095
1096 let mut overloads = fun.overloads();
1098 log::debug!(
1099 "initial overloads for {:?}: {:#?}",
1100 fun,
1101 overloads.for_debug(&module.types)
1102 );
1103
1104 for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1112 overloads = overloads.arg(i, ty, &module.types);
1115 log::debug!(
1116 "overloads after arg {i}: {:#?}",
1117 overloads.for_debug(&module.types)
1118 );
1119
1120 if overloads.is_empty() {
1121 log::debug!("all overloads eliminated");
1122 return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1123 }
1124 }
1125
1126 if actuals.len() < overloads.min_arguments() {
1127 return Err(ExpressionError::WrongArgumentCount(fun));
1128 }
1129
1130 ShaderStages::all()
1131 }
1132 E::As {
1133 expr,
1134 kind,
1135 convert,
1136 } => {
1137 let mut base_scalar = match resolver[expr] {
1138 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1139 scalar
1140 }
1141 crate::TypeInner::Matrix { scalar, .. } => scalar,
1142 _ => return Err(ExpressionError::InvalidCastArgument),
1143 };
1144 base_scalar.kind = kind;
1145 if let Some(width) = convert {
1146 base_scalar.width = width;
1147 }
1148 if self.check_width(base_scalar).is_err() {
1149 return Err(ExpressionError::InvalidCastArgument);
1150 }
1151 ShaderStages::all()
1152 }
1153 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1154 E::AtomicResult { .. } => {
1155 ShaderStages::all()
1160 }
1161 E::WorkGroupUniformLoadResult { ty } => {
1162 if self.types[ty.index()]
1163 .flags
1164 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1167 {
1168 ShaderStages::COMPUTE
1169 } else {
1170 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1171 }
1172 }
1173 E::ArrayLength(expr) => match resolver[expr] {
1174 Ti::Pointer { base, .. } => {
1175 let base_ty = &resolver.types[base];
1176 if let Ti::Array {
1177 size: crate::ArraySize::Dynamic,
1178 ..
1179 } = base_ty.inner
1180 {
1181 ShaderStages::all()
1182 } else {
1183 return Err(ExpressionError::InvalidArrayType(expr));
1184 }
1185 }
1186 ref other => {
1187 log::error!("Array length of {other:?}");
1188 return Err(ExpressionError::InvalidArrayType(expr));
1189 }
1190 },
1191 E::RayQueryProceedResult => ShaderStages::all(),
1192 E::RayQueryGetIntersection {
1193 query,
1194 committed: _,
1195 } => match resolver[query] {
1196 Ti::Pointer {
1197 base,
1198 space: crate::AddressSpace::Function,
1199 } => match resolver.types[base].inner {
1200 Ti::RayQuery { .. } => ShaderStages::all(),
1201 ref other => {
1202 log::error!("Intersection result of a pointer to {other:?}");
1203 return Err(ExpressionError::InvalidRayQueryType(query));
1204 }
1205 },
1206 ref other => {
1207 log::error!("Intersection result of {other:?}");
1208 return Err(ExpressionError::InvalidRayQueryType(query));
1209 }
1210 },
1211 E::RayQueryVertexPositions {
1212 query,
1213 committed: _,
1214 } => match resolver[query] {
1215 Ti::Pointer {
1216 base,
1217 space: crate::AddressSpace::Function,
1218 } => match resolver.types[base].inner {
1219 Ti::RayQuery {
1220 vertex_return: true,
1221 } => ShaderStages::all(),
1222 ref other => {
1223 log::error!("Intersection result of a pointer to {other:?}");
1224 return Err(ExpressionError::InvalidRayQueryType(query));
1225 }
1226 },
1227 ref other => {
1228 log::error!("Intersection result of {other:?}");
1229 return Err(ExpressionError::InvalidRayQueryType(query));
1230 }
1231 },
1232 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1233 };
1234 Ok(stages)
1235 }
1236
1237 fn global_var_ty(
1238 module: &crate::Module,
1239 function: &crate::Function,
1240 expr: Handle<crate::Expression>,
1241 ) -> Result<Handle<crate::Type>, ExpressionError> {
1242 use crate::Expression as Ex;
1243
1244 match function.expressions[expr] {
1245 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1246 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1247 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1248 match function.expressions[base] {
1249 Ex::GlobalVariable(var_handle) => {
1250 let array_ty = module.global_variables[var_handle].ty;
1251
1252 match module.types[array_ty].inner {
1253 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1254 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1255 }
1256 }
1257 _ => Err(ExpressionError::ExpectedGlobalVariable),
1258 }
1259 }
1260 _ => Err(ExpressionError::ExpectedGlobalVariable),
1261 }
1262 }
1263
1264 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1265 let _ = self.check_width(literal.scalar())?;
1266 check_literal_value(literal)?;
1267
1268 Ok(())
1269 }
1270}
1271
1272pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1273 let is_nan = match literal {
1274 crate::Literal::F64(v) => v.is_nan(),
1275 crate::Literal::F32(v) => v.is_nan(),
1276 _ => false,
1277 };
1278 if is_nan {
1279 return Err(LiteralError::NaN);
1280 }
1281
1282 let is_infinite = match literal {
1283 crate::Literal::F64(v) => v.is_infinite(),
1284 crate::Literal::F32(v) => v.is_infinite(),
1285 _ => false,
1286 };
1287 if is_infinite {
1288 return Err(LiteralError::Infinity);
1289 }
1290
1291 Ok(())
1292}
1293
/// Build a module whose single function appends `expr` and emits it, then
/// validate that module with only `ValidationFlags::EXPRESSIONS` and `caps`.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut func = crate::Function::default();
    func.expressions.append(expr, Span::default());
    // Emit everything appended so far (just `expr`) so that it is validated.
    let emitted = func.expressions.range_from(0);
    func.body.push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(func, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1316
/// Build a module containing only `expr` as a global expression, then
/// validate it with only `ValidationFlags::CONSTANTS` and `caps`.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let module = {
        let mut m = crate::Module::default();
        m.global_expressions
            .append(expr, crate::span::Span::default());
        m
    };

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
    validator.validate(&module)
}
1331}
1332
/// A runtime `f64` literal requires `Capabilities::FLOAT64`. Without it,
/// expression validation reports a missing-capability width error; with it,
/// the module validates.
#[test]
fn f64_runtime_literals() {
    // The value (Euler–Mascheroni constant) is arbitrary; only the width matters.
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    let error = validate_with_expression(make_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                ),),
                ..
            },
            ..
        }
    ));

    let caps_with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(make_literal(), caps_with_f64).is_ok());
}
1363
/// An `f64` literal in a global (constant) expression requires
/// `Capabilities::FLOAT64`. Without it, constant validation reports a
/// missing-capability width error; with it, the module validates.
#[test]
fn f64_const_literals() {
    // The value (Euler–Mascheroni constant) is arbitrary; only the width matters.
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    let error = validate_with_const_expression(make_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    let caps_with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(make_literal(), caps_with_f64).is_ok());
}