1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
38bitflags::bitflags! {
39 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
53 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
54 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
55 pub struct ValidationFlags: u8 {
56 const EXPRESSIONS = 0x1;
58 const BLOCKS = 0x2;
60 const CONTROL_FLOW_UNIFORMITY = 0x4;
62 const STRUCT_LAYOUTS = 0x8;
64 const CONSTANTS = 0x10;
66 const BINDINGS = 0x20;
68 }
69}
70
71impl Default for ValidationFlags {
72 fn default() -> Self {
73 Self::all()
74 }
75}
76
bitflags::bitflags! {
    /// Capabilities granted to a [`Validator`].
    ///
    /// NOTE(review): each flag presumably gates an optional IR feature that the target
    /// may not support; the gating itself lives outside this file — confirm per flag.
    /// The `Default` value enables only `MULTISAMPLED_SHADING` and `CUBE_ARRAY_TEXTURES`.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        const PUSH_CONSTANT = 1 << 0;
        const FLOAT64 = 1 << 1;
        const PRIMITIVE_INDEX = 1 << 2;
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        const CLIP_DISTANCE = 1 << 7;
        const CULL_DISTANCE = 1 << 8;
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        const MULTIVIEW = 1 << 10;
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Enabled by `Capabilities::default()`.
        const MULTISAMPLED_SHADING = 1 << 12;
        const RAY_QUERY = 1 << 13;
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Enabled by `Capabilities::default()`.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        const SHADER_INT64 = 1 << 16;
        /// Subgroup operations. `Validator::new` turns this into the full
        /// `SubgroupOperationSet` and permits it in fragment and compute stages.
        const SUBGROUP = 1 << 17;
        const SUBGROUP_BARRIER = 1 << 18;
        /// When set, `Validator::new` also permits subgroup operations in vertex stages.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        const TEXTURE_ATOMIC = 1 << 23;
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        const SHADER_FLOAT16 = 1 << 26;
    }
}
170
171impl Default for Capabilities {
172 fn default() -> Self {
173 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
174 }
175}
176
177bitflags::bitflags! {
178 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
180 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
181 #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
182 pub struct SubgroupOperationSet: u8 {
183 const BASIC = 1 << 0;
185 const VOTE = 1 << 1;
187 const ARITHMETIC = 1 << 2;
189 const BALLOT = 1 << 3;
191 const SHUFFLE = 1 << 4;
193 const SHUFFLE_RELATIVE = 1 << 5;
195 const QUAD_FRAGMENT_COMPUTE = 1 << 7;
200 }
203}
204
205impl super::SubgroupOperation {
206 const fn required_operations(&self) -> SubgroupOperationSet {
207 use SubgroupOperationSet as S;
208 match *self {
209 Self::All | Self::Any => S::VOTE,
210 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
211 S::ARITHMETIC
212 }
213 }
214 }
215}
216
217impl super::GatherMode {
218 const fn required_operations(&self) -> SubgroupOperationSet {
219 use SubgroupOperationSet as S;
220 match *self {
221 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
222 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
223 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
224 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
225 }
226 }
227}
228
229bitflags::bitflags! {
230 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
232 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
233 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
234 pub struct ShaderStages: u8 {
235 const VERTEX = 0x1;
236 const FRAGMENT = 0x2;
237 const COMPUTE = 0x4;
238 }
239}
240
/// Information gathered while validating a module.
///
/// Returned by `Validator::validate` and `Validator::validate_resolved_overrides`.
/// Entries are looked up through `ops::Index` by type, function, or
/// global-expression handle.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, pushed in `module.types` arena order.
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function, pushed in `module.functions` arena order.
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, pushed in `module.entry_points` order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each global expression, indexed by handle index.
    ///
    /// Allocated with one placeholder entry per global expression before the
    /// expressions are processed.
    const_expression_types: Box<[TypeResolution]>,
}
250
251impl ops::Index<Handle<crate::Type>> for ModuleInfo {
252 type Output = TypeFlags;
253 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
254 &self.type_flags[handle.index()]
255 }
256}
257
258impl ops::Index<Handle<crate::Function>> for ModuleInfo {
259 type Output = FunctionInfo;
260 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
261 &self.functions[handle.index()]
262 }
263}
264
265impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
266 type Output = TypeResolution;
267 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
268 &self.const_expression_types[handle.index()]
269 }
270}
271
/// Validates a [`crate::Module`] and produces a [`ModuleInfo`] describing it.
///
/// A validator can be reused: every `validate` / `validate_resolved_overrides` call
/// begins by calling `reset`.
#[derive(Debug)]
pub struct Validator {
    /// Which validation passes to run.
    flags: ValidationFlags,
    /// Capabilities the module is allowed to use.
    capabilities: Capabilities,
    /// Stages where subgroup operations are permitted. Derived from `capabilities` in
    /// `new` and replaceable through the `subgroup_stages` setter.
    subgroup_stages: ShaderStages,
    /// Subgroup operation classes that are permitted. Derived from `capabilities` in
    /// `new` and replaceable through the `subgroup_operations` setter.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation info, indexed by the type handle's index.
    types: Vec<r#type::TypeInfo>,
    /// Type layout computation, refreshed for each module in `validate_impl`.
    layouter: Layouter,
    // NOTE(review): the next three collections are only cleared in this file; they are
    // presumably filled during entry-point interface validation elsewhere — confirm.
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): `dead_code` is allowed here; whether that is still needed depends
    // on uses outside this file — confirm before removing the attribute.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // Scratch state for expression validation; cleared by `reset`.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far in the current module, used to reject duplicates.
    override_ids: FastHashSet<u16>,

    /// When `true`, an override without an initializer is a validation error.
    /// Set by `validate` (false) and `validate_resolved_overrides` (true).
    overrides_resolved: bool,

    // NOTE(review): not read or written in this file apart from `new`; presumably
    // per-function scratch state used by the function validator — confirm.
    needs_visit: HandleSet<crate::Expression>,
}
313
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer expression is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type is not flagged `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
324
/// Reasons a pipeline-overridable constant (`override`) can fail validation.
///
/// `validate_override` in this file raises `DuplicateID`, `NonConstructibleType`,
/// `TypeNotScalar`, `InvalidType` and `UninitializedOverride`. The remaining variants
/// are presumably raised elsewhere — NOTE(review): confirm.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type is not flagged `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of bool, i32, u32, f16, f32 or f64.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// No initializer, while the validator requires overrides to be resolved.
    #[error("Override is uninitialized")]
    UninitializedOverride,
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
350
/// Top-level error returned by module validation, identifying the failing item.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// A handle in the module is out of range or otherwise invalid.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type in `module.types` failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global (module-scope) expression failed validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    // NOTE(review): not constructed in this file; presumably raised during type validation.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// A pipeline-overridable constant failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or duplicates another (stage, name) pair.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    // NOTE(review): not constructed in this file.
    #[error("Module is corrupted")]
    Corrupted,
}
404
405impl crate::TypeInner {
406 const fn is_sized(&self) -> bool {
407 match *self {
408 Self::Scalar { .. }
409 | Self::Vector { .. }
410 | Self::Matrix { .. }
411 | Self::Array {
412 size: crate::ArraySize::Constant(_),
413 ..
414 }
415 | Self::Atomic { .. }
416 | Self::Pointer { .. }
417 | Self::ValuePointer { .. }
418 | Self::Struct { .. } => true,
419 Self::Array { .. }
420 | Self::Image { .. }
421 | Self::Sampler { .. }
422 | Self::AccelerationStructure { .. }
423 | Self::RayQuery { .. }
424 | Self::BindingArray { .. } => false,
425 }
426 }
427
428 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
430 match *self {
431 Self::Scalar(crate::Scalar {
432 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
433 ..
434 }) => Some(crate::ImageDimension::D1),
435 Self::Vector {
436 size: crate::VectorSize::Bi,
437 scalar:
438 crate::Scalar {
439 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
440 ..
441 },
442 } => Some(crate::ImageDimension::D2),
443 Self::Vector {
444 size: crate::VectorSize::Tri,
445 scalar:
446 crate::Scalar {
447 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
448 ..
449 },
450 } => Some(crate::ImageDimension::D3),
451 _ => None,
452 }
453 }
454}
455
456impl Validator {
457 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
459 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
460 use SubgroupOperationSet as S;
461 S::BASIC
462 | S::VOTE
463 | S::ARITHMETIC
464 | S::BALLOT
465 | S::SHUFFLE
466 | S::SHUFFLE_RELATIVE
467 | S::QUAD_FRAGMENT_COMPUTE
468 } else {
469 SubgroupOperationSet::empty()
470 };
471 let subgroup_stages = {
472 let mut stages = ShaderStages::empty();
473 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
474 stages |= ShaderStages::VERTEX;
475 }
476 if capabilities.contains(Capabilities::SUBGROUP) {
477 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
478 }
479 stages
480 };
481
482 Validator {
483 flags,
484 capabilities,
485 subgroup_stages,
486 subgroup_operations,
487 types: Vec::new(),
488 layouter: Layouter::default(),
489 location_mask: BitSet::new(),
490 blend_src_mask: BitSet::new(),
491 ep_resource_bindings: FastHashSet::default(),
492 switch_values: FastHashSet::default(),
493 valid_expression_list: Vec::new(),
494 valid_expression_set: HandleSet::new(),
495 override_ids: FastHashSet::default(),
496 overrides_resolved: false,
497 needs_visit: HandleSet::new(),
498 }
499 }
500
501 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
502 self.subgroup_stages = stages;
503 self
504 }
505
506 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
507 self.subgroup_operations = operations;
508 self
509 }
510
511 pub fn reset(&mut self) {
513 self.types.clear();
514 self.layouter.clear();
515 self.location_mask.clear();
516 self.blend_src_mask.clear();
517 self.ep_resource_bindings.clear();
518 self.switch_values.clear();
519 self.valid_expression_list.clear();
520 self.valid_expression_set.clear();
521 self.override_ids.clear();
522 }
523
524 fn validate_constant(
525 &self,
526 handle: Handle<crate::Constant>,
527 gctx: crate::proc::GlobalCtx,
528 mod_info: &ModuleInfo,
529 global_expr_kind: &ExpressionKindTracker,
530 ) -> Result<(), ConstantError> {
531 let con = &gctx.constants[handle];
532
533 let type_info = &self.types[con.ty.index()];
534 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
535 return Err(ConstantError::NonConstructibleType);
536 }
537
538 if !global_expr_kind.is_const(con.init) {
539 return Err(ConstantError::InitializerExprType);
540 }
541
542 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
543 return Err(ConstantError::InvalidType);
544 }
545
546 Ok(())
547 }
548
549 fn validate_override(
550 &mut self,
551 handle: Handle<crate::Override>,
552 gctx: crate::proc::GlobalCtx,
553 mod_info: &ModuleInfo,
554 ) -> Result<(), OverrideError> {
555 let o = &gctx.overrides[handle];
556
557 if let Some(id) = o.id {
558 if !self.override_ids.insert(id) {
559 return Err(OverrideError::DuplicateID);
560 }
561 }
562
563 let type_info = &self.types[o.ty.index()];
564 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
565 return Err(OverrideError::NonConstructibleType);
566 }
567
568 match gctx.types[o.ty].inner {
569 crate::TypeInner::Scalar(
570 crate::Scalar::BOOL
571 | crate::Scalar::I32
572 | crate::Scalar::U32
573 | crate::Scalar::F16
574 | crate::Scalar::F32
575 | crate::Scalar::F64,
576 ) => {}
577 _ => return Err(OverrideError::TypeNotScalar),
578 }
579
580 if let Some(init) = o.init {
581 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
582 return Err(OverrideError::InvalidType);
583 }
584 } else if self.overrides_resolved {
585 return Err(OverrideError::UninitializedOverride);
586 }
587
588 Ok(())
589 }
590
591 pub fn validate(
593 &mut self,
594 module: &crate::Module,
595 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
596 self.overrides_resolved = false;
597 self.validate_impl(module)
598 }
599
600 pub fn validate_resolved_overrides(
608 &mut self,
609 module: &crate::Module,
610 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
611 self.overrides_resolved = true;
612 self.validate_impl(module)
613 }
614
615 fn validate_impl(
616 &mut self,
617 module: &crate::Module,
618 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
619 self.reset();
620 self.reset_types(module.types.len());
621
622 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
623
624 self.layouter.update(module.to_ctx()).map_err(|e| {
625 let handle = e.ty;
626 ValidationError::from(e).with_span_handle(handle, &module.types)
627 })?;
628
629 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
631 kind: crate::ScalarKind::Bool,
632 width: 0,
633 }));
634
635 let mut mod_info = ModuleInfo {
636 type_flags: Vec::with_capacity(module.types.len()),
637 functions: Vec::with_capacity(module.functions.len()),
638 entry_points: Vec::with_capacity(module.entry_points.len()),
639 const_expression_types: vec![placeholder; module.global_expressions.len()]
640 .into_boxed_slice(),
641 };
642
643 for (handle, ty) in module.types.iter() {
644 let ty_info = self
645 .validate_type(handle, module.to_ctx())
646 .map_err(|source| {
647 ValidationError::Type {
648 handle,
649 name: ty.name.clone().unwrap_or_default(),
650 source,
651 }
652 .with_span_handle(handle, &module.types)
653 })?;
654 mod_info.type_flags.push(ty_info.flags);
655 self.types[handle.index()] = ty_info;
656 }
657
658 {
659 let t = crate::Arena::new();
660 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
661 for (handle, _) in module.global_expressions.iter() {
662 mod_info
663 .process_const_expression(handle, &resolve_context, module.to_ctx())
664 .map_err(|source| {
665 ValidationError::ConstExpression { handle, source }
666 .with_span_handle(handle, &module.global_expressions)
667 })?
668 }
669 }
670
671 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
672
673 if self.flags.contains(ValidationFlags::CONSTANTS) {
674 for (handle, _) in module.global_expressions.iter() {
675 self.validate_const_expression(
676 handle,
677 module.to_ctx(),
678 &mod_info,
679 &global_expr_kind,
680 )
681 .map_err(|source| {
682 ValidationError::ConstExpression { handle, source }
683 .with_span_handle(handle, &module.global_expressions)
684 })?
685 }
686
687 for (handle, constant) in module.constants.iter() {
688 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
689 .map_err(|source| {
690 ValidationError::Constant {
691 handle,
692 name: constant.name.clone().unwrap_or_default(),
693 source,
694 }
695 .with_span_handle(handle, &module.constants)
696 })?
697 }
698
699 for (handle, r#override) in module.overrides.iter() {
700 self.validate_override(handle, module.to_ctx(), &mod_info)
701 .map_err(|source| {
702 ValidationError::Override {
703 handle,
704 name: r#override.name.clone().unwrap_or_default(),
705 source,
706 }
707 .with_span_handle(handle, &module.overrides)
708 })?;
709 }
710 }
711
712 for (var_handle, var) in module.global_variables.iter() {
713 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
714 .map_err(|source| {
715 ValidationError::GlobalVariable {
716 handle: var_handle,
717 name: var.name.clone().unwrap_or_default(),
718 source,
719 }
720 .with_span_handle(var_handle, &module.global_variables)
721 })?;
722 }
723
724 for (handle, fun) in module.functions.iter() {
725 match self.validate_function(fun, module, &mod_info, false) {
726 Ok(info) => mod_info.functions.push(info),
727 Err(error) => {
728 return Err(error.and_then(|source| {
729 ValidationError::Function {
730 handle,
731 name: fun.name.clone().unwrap_or_default(),
732 source,
733 }
734 .with_span_handle(handle, &module.functions)
735 }))
736 }
737 }
738 }
739
740 let mut ep_map = FastHashSet::default();
741 for ep in module.entry_points.iter() {
742 if !ep_map.insert((ep.stage, &ep.name)) {
743 return Err(ValidationError::EntryPoint {
744 stage: ep.stage,
745 name: ep.name.clone(),
746 source: EntryPointError::Conflict,
747 }
748 .with_span()); }
750
751 match self.validate_entry_point(ep, module, &mod_info) {
752 Ok(info) => mod_info.entry_points.push(info),
753 Err(error) => {
754 return Err(error.and_then(|source| {
755 ValidationError::EntryPoint {
756 stage: ep.stage,
757 name: ep.name.clone(),
758 source,
759 }
760 .with_span()
761 }));
762 }
763 }
764 }
765
766 Ok(mod_info)
767 }
768}
769
770fn validate_atomic_compare_exchange_struct(
771 types: &crate::UniqueArena<crate::Type>,
772 members: &[crate::StructMember],
773 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
774) -> bool {
775 members.len() == 2
776 && members[0].name.as_deref() == Some("old_value")
777 && scalar_predicate(&types[members[0].ty].inner)
778 && members[1].name.as_deref() == Some("exchanged")
779 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
780}