1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
38pub const MAX_TYPE_SIZE: u32 = 0x4000_0000; bitflags::bitflags! {
42 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58 pub struct ValidationFlags: u8 {
59 const EXPRESSIONS = 0x1;
61 const BLOCKS = 0x2;
63 const CONTROL_FLOW_UNIFORMITY = 0x4;
65 const STRUCT_LAYOUTS = 0x8;
67 const CONSTANTS = 0x10;
69 const BINDINGS = 0x20;
71 }
72}
73
74impl Default for ValidationFlags {
75 fn default() -> Self {
76 Self::all()
77 }
78}
79
bitflags::bitflags! {
    /// Optional features that a module may use and that the [`Validator`] may accept.
    ///
    /// The `Default` impl yields `MULTISAMPLED_SHADING | CUBE_ARRAY_TEXTURES`.
    ///
    /// NOTE(review): apart from the flags whose docs below cite a use in this file, the
    /// descriptions are inferred from the flag names; the checks that consult them live
    /// in the validator submodules and should be used to confirm the exact meaning.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        /// Push constants.
        const PUSH_CONSTANT = 1 << 0;
        /// 64-bit floating-point types.
        const FLOAT64 = 1 << 1;
        /// The primitive-index built-in.
        const PRIMITIVE_INDEX = 1 << 2;
        /// Non-uniform indexing into arrays of sampled textures and storage buffers.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        /// Non-uniform indexing into arrays of storage textures.
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        /// Non-uniform indexing into arrays of uniform buffers.
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        /// Non-uniform indexing into arrays of samplers.
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        /// Clip distances. `Capabilities::extension` maps this flag to
        /// `ImplementedEnableExtension::ClipDistances`.
        const CLIP_DISTANCE = 1 << 7;
        /// Cull distances.
        const CULL_DISTANCE = 1 << 8;
        /// 16-bit normalized storage-texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Multiview rendering.
        const MULTIVIEW = 1 << 10;
        /// Early depth testing.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Multisampled shading. Part of the default set.
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Ray queries.
        const RAY_QUERY = 1 << 13;
        /// Dual-source blending. `Capabilities::extension` maps this flag to
        /// `ImplementedEnableExtension::DualSourceBlending`.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Cube-array textures. Part of the default set.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// 64-bit integer types.
        const SHADER_INT64 = 1 << 16;
        /// Subgroup operations. `Validator::new` turns this flag into "every
        /// `SubgroupOperationSet` operation, allowed in the fragment and compute stages".
        const SUBGROUP = 1 << 17;
        /// Subgroup barriers.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Makes `Validator::new` add the vertex stage to the stages in which subgroup
        /// operations are allowed.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// 64-bit integer atomic min/max.
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// All 64-bit integer atomic operations.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// 32-bit floating-point atomics.
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Texture atomics.
        const TEXTURE_ATOMIC = 1 << 23;
        /// 64-bit integer texture atomics.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Access to ray-hit vertex positions.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// 16-bit floating-point types. `Capabilities::extension` maps this flag to
        /// `ImplementedEnableExtension::F16`.
        const SHADER_FLOAT16 = 1 << 26;
        /// External textures.
        const TEXTURE_EXTERNAL = 1 << 27;
        /// A float16-related capability that, unlike `SHADER_FLOAT16`, has no
        /// enable-extension mapping in `Capabilities::extension`.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
    }
}
191
192impl Capabilities {
193 #[cfg(feature = "wgsl-in")]
197 #[doc(hidden)]
198 pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
199 use crate::front::wgsl::ImplementedEnableExtension as Ext;
200 match *self {
201 Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
202 Self::SHADER_FLOAT16 => Some(Ext::F16),
204 Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
205 _ => None,
206 }
207 }
208}
209
210impl Default for Capabilities {
211 fn default() -> Self {
212 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
213 }
214}
215
bitflags::bitflags! {
    /// A set of subgroup operation classes that the [`Validator`] accepts.
    ///
    /// Each `SubgroupOperation` and `GatherMode` maps to exactly one of these flags through
    /// its `required_operations` method. `Validator::new` enables all of them when
    /// `Capabilities::SUBGROUP` is present and none of them otherwise.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Baseline subgroup support. No operation mapped in this file requires it directly.
        const BASIC = 1 << 0;
        /// Required by `SubgroupOperation::All` and `SubgroupOperation::Any`.
        const VOTE = 1 << 1;
        /// Required by the `Add`, `Mul`, `Min`, `Max`, `And`, `Or` and `Xor` subgroup
        /// operations.
        const ARITHMETIC = 1 << 2;
        /// Required by `GatherMode::BroadcastFirst` and `GatherMode::Broadcast`.
        const BALLOT = 1 << 3;
        /// Required by `GatherMode::Shuffle` and `GatherMode::ShuffleXor`.
        const SHUFFLE = 1 << 4;
        /// Required by `GatherMode::ShuffleUp` and `GatherMode::ShuffleDown`.
        const SHUFFLE_RELATIVE = 1 << 5;
        // Bit 6 is not assigned here (presumably reserved for a future operation class).
        /// Required by `GatherMode::QuadBroadcast` and `GatherMode::QuadSwap`.
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
    }
}
247
248impl super::SubgroupOperation {
249 const fn required_operations(&self) -> SubgroupOperationSet {
250 use SubgroupOperationSet as S;
251 match *self {
252 Self::All | Self::Any => S::VOTE,
253 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
254 S::ARITHMETIC
255 }
256 }
257 }
258}
259
260impl super::GatherMode {
261 const fn required_operations(&self) -> SubgroupOperationSet {
262 use SubgroupOperationSet as S;
263 match *self {
264 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
265 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
266 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
267 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
268 }
269 }
270}
271
272bitflags::bitflags! {
273 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
275 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
276 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
277 pub struct ShaderStages: u8 {
278 const VERTEX = 0x1;
279 const FRAGMENT = 0x2;
280 const COMPUTE = 0x4;
281 }
282}
283
/// Per-module information gathered during validation.
///
/// Returned on success by `Validator::validate` and `Validator::validate_resolved_overrides`.
/// It can be indexed by type, function and global-expression handles.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type in `Module::types`, indexed by the type handle's index.
    type_flags: Vec<TypeFlags>,
    /// Analysis results for each function in `Module::functions`, in arena order.
    functions: Vec<FunctionInfo>,
    /// Analysis results for each entry point, in `Module::entry_points` order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each expression in `Module::global_expressions`, indexed by the
    /// expression handle's index.
    const_expression_types: Box<[TypeResolution]>,
}
293
294impl ops::Index<Handle<crate::Type>> for ModuleInfo {
295 type Output = TypeFlags;
296 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
297 &self.type_flags[handle.index()]
298 }
299}
300
301impl ops::Index<Handle<crate::Function>> for ModuleInfo {
302 type Output = FunctionInfo;
303 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
304 &self.functions[handle.index()]
305 }
306}
307
308impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
309 type Output = TypeResolution;
310 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
311 &self.const_expression_types[handle.index()]
312 }
313}
314
/// Validates a [`crate::Module`] and gathers [`ModuleInfo`] about it.
///
/// A validator can be reused across modules: every call to `validate` or
/// `validate_resolved_overrides` begins by resetting its scratch state.
#[derive(Debug)]
pub struct Validator {
    /// Which groups of checks to run.
    flags: ValidationFlags,
    /// Optional features the module is allowed to use.
    capabilities: Capabilities,
    /// Shader stages in which subgroup operations are allowed.
    subgroup_stages: ShaderStages,
    /// Subgroup operation classes that are allowed.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation info, indexed by type handle index; refilled for each module.
    types: Vec<r#type::TypeInfo>,
    /// Type layout calculator; updated from each module before its types are checked.
    layouter: Layouter,
    // The following fields are scratch state cleared by `reset`; they are presumably read
    // and written by the validator submodules (not visible here).
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): the reason for this lint allowance is not visible here; confirm whether
    // the field is still needed.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far in the current module; used to reject duplicates.
    override_ids: FastHashSet<u16>,

    /// When `true`, an override without an initializer is a validation error.
    overrides_resolved: bool,

    /// NOTE(review): only constructed here; presumably tracks expressions that must still
    /// be visited by a specific kind of statement — confirm in the function validator.
    needs_visit: HandleSet<crate::Expression>,
}
356
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The constant's initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's declared type lacks `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
367
/// Reasons a pipeline-overridable constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor an ID (not raised by the checks in this file).
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is neither a const- nor an override-expression (not raised by the
    /// checks in this file).
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the override's declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The override's declared type lacks `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The override's type is not one of `bool`, `i32`, `u32`, `f16`, `f32` or `f64`.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Override declarations are not permitted here (not raised by the checks in this file).
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer although overrides were expected to be resolved
    /// (see `Validator::validate_resolved_overrides`).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// A constant expression related to the override is invalid (not raised by the checks
    /// in this file).
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
393
/// Errors reported by `Validator::validate` and `Validator::validate_resolved_overrides`.
///
/// Most variants wrap a more specific error together with the handle and name of the
/// offending module item.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// The module failed handle validation, which runs before every other check.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing the module's type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression could not be processed or failed const-expression validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// An array-size expression is not strictly positive (not raised in this file).
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// A pipeline-overridable constant failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or another entry point with the same stage and
    /// name already exists (`EntryPointError::Conflict`).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is internally inconsistent (not raised in this file).
    #[error("Module is corrupted")]
    Corrupted,
}
447
448impl crate::TypeInner {
449 const fn is_sized(&self) -> bool {
450 match *self {
451 Self::Scalar { .. }
452 | Self::Vector { .. }
453 | Self::Matrix { .. }
454 | Self::Array {
455 size: crate::ArraySize::Constant(_),
456 ..
457 }
458 | Self::Atomic { .. }
459 | Self::Pointer { .. }
460 | Self::ValuePointer { .. }
461 | Self::Struct { .. } => true,
462 Self::Array { .. }
463 | Self::Image { .. }
464 | Self::Sampler { .. }
465 | Self::AccelerationStructure { .. }
466 | Self::RayQuery { .. }
467 | Self::BindingArray { .. } => false,
468 }
469 }
470
471 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
473 match *self {
474 Self::Scalar(crate::Scalar {
475 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
476 ..
477 }) => Some(crate::ImageDimension::D1),
478 Self::Vector {
479 size: crate::VectorSize::Bi,
480 scalar:
481 crate::Scalar {
482 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
483 ..
484 },
485 } => Some(crate::ImageDimension::D2),
486 Self::Vector {
487 size: crate::VectorSize::Tri,
488 scalar:
489 crate::Scalar {
490 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
491 ..
492 },
493 } => Some(crate::ImageDimension::D3),
494 _ => None,
495 }
496 }
497}
498
499impl Validator {
500 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
514 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
515 use SubgroupOperationSet as S;
516 S::BASIC
517 | S::VOTE
518 | S::ARITHMETIC
519 | S::BALLOT
520 | S::SHUFFLE
521 | S::SHUFFLE_RELATIVE
522 | S::QUAD_FRAGMENT_COMPUTE
523 } else {
524 SubgroupOperationSet::empty()
525 };
526 let subgroup_stages = {
527 let mut stages = ShaderStages::empty();
528 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
529 stages |= ShaderStages::VERTEX;
530 }
531 if capabilities.contains(Capabilities::SUBGROUP) {
532 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
533 }
534 stages
535 };
536
537 Validator {
538 flags,
539 capabilities,
540 subgroup_stages,
541 subgroup_operations,
542 types: Vec::new(),
543 layouter: Layouter::default(),
544 location_mask: BitSet::new(),
545 blend_src_mask: BitSet::new(),
546 ep_resource_bindings: FastHashSet::default(),
547 switch_values: FastHashSet::default(),
548 valid_expression_list: Vec::new(),
549 valid_expression_set: HandleSet::new(),
550 override_ids: FastHashSet::default(),
551 overrides_resolved: false,
552 needs_visit: HandleSet::new(),
553 }
554 }
555
556 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
557 self.subgroup_stages = stages;
558 self
559 }
560
561 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
562 self.subgroup_operations = operations;
563 self
564 }
565
566 pub fn reset(&mut self) {
568 self.types.clear();
569 self.layouter.clear();
570 self.location_mask.clear();
571 self.blend_src_mask.clear();
572 self.ep_resource_bindings.clear();
573 self.switch_values.clear();
574 self.valid_expression_list.clear();
575 self.valid_expression_set.clear();
576 self.override_ids.clear();
577 }
578
579 fn validate_constant(
580 &self,
581 handle: Handle<crate::Constant>,
582 gctx: crate::proc::GlobalCtx,
583 mod_info: &ModuleInfo,
584 global_expr_kind: &ExpressionKindTracker,
585 ) -> Result<(), ConstantError> {
586 let con = &gctx.constants[handle];
587
588 let type_info = &self.types[con.ty.index()];
589 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
590 return Err(ConstantError::NonConstructibleType);
591 }
592
593 if !global_expr_kind.is_const(con.init) {
594 return Err(ConstantError::InitializerExprType);
595 }
596
597 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
598 return Err(ConstantError::InvalidType);
599 }
600
601 Ok(())
602 }
603
604 fn validate_override(
605 &mut self,
606 handle: Handle<crate::Override>,
607 gctx: crate::proc::GlobalCtx,
608 mod_info: &ModuleInfo,
609 ) -> Result<(), OverrideError> {
610 let o = &gctx.overrides[handle];
611
612 if let Some(id) = o.id {
613 if !self.override_ids.insert(id) {
614 return Err(OverrideError::DuplicateID);
615 }
616 }
617
618 let type_info = &self.types[o.ty.index()];
619 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
620 return Err(OverrideError::NonConstructibleType);
621 }
622
623 match gctx.types[o.ty].inner {
624 crate::TypeInner::Scalar(
625 crate::Scalar::BOOL
626 | crate::Scalar::I32
627 | crate::Scalar::U32
628 | crate::Scalar::F16
629 | crate::Scalar::F32
630 | crate::Scalar::F64,
631 ) => {}
632 _ => return Err(OverrideError::TypeNotScalar),
633 }
634
635 if let Some(init) = o.init {
636 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
637 return Err(OverrideError::InvalidType);
638 }
639 } else if self.overrides_resolved {
640 return Err(OverrideError::UninitializedOverride);
641 }
642
643 Ok(())
644 }
645
646 pub fn validate(
648 &mut self,
649 module: &crate::Module,
650 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
651 self.overrides_resolved = false;
652 self.validate_impl(module)
653 }
654
655 pub fn validate_resolved_overrides(
663 &mut self,
664 module: &crate::Module,
665 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
666 self.overrides_resolved = true;
667 self.validate_impl(module)
668 }
669
670 fn validate_impl(
671 &mut self,
672 module: &crate::Module,
673 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
674 self.reset();
675 self.reset_types(module.types.len());
676
677 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
678
679 self.layouter.update(module.to_ctx()).map_err(|e| {
680 let handle = e.ty;
681 ValidationError::from(e).with_span_handle(handle, &module.types)
682 })?;
683
684 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
686 kind: crate::ScalarKind::Bool,
687 width: 0,
688 }));
689
690 let mut mod_info = ModuleInfo {
691 type_flags: Vec::with_capacity(module.types.len()),
692 functions: Vec::with_capacity(module.functions.len()),
693 entry_points: Vec::with_capacity(module.entry_points.len()),
694 const_expression_types: vec![placeholder; module.global_expressions.len()]
695 .into_boxed_slice(),
696 };
697
698 for (handle, ty) in module.types.iter() {
699 let ty_info = self
700 .validate_type(handle, module.to_ctx())
701 .map_err(|source| {
702 ValidationError::Type {
703 handle,
704 name: ty.name.clone().unwrap_or_default(),
705 source,
706 }
707 .with_span_handle(handle, &module.types)
708 })?;
709 mod_info.type_flags.push(ty_info.flags);
710 self.types[handle.index()] = ty_info;
711 }
712
713 {
714 let t = crate::Arena::new();
715 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
716 for (handle, _) in module.global_expressions.iter() {
717 mod_info
718 .process_const_expression(handle, &resolve_context, module.to_ctx())
719 .map_err(|source| {
720 ValidationError::ConstExpression { handle, source }
721 .with_span_handle(handle, &module.global_expressions)
722 })?
723 }
724 }
725
726 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
727
728 if self.flags.contains(ValidationFlags::CONSTANTS) {
729 for (handle, _) in module.global_expressions.iter() {
730 self.validate_const_expression(
731 handle,
732 module.to_ctx(),
733 &mod_info,
734 &global_expr_kind,
735 )
736 .map_err(|source| {
737 ValidationError::ConstExpression { handle, source }
738 .with_span_handle(handle, &module.global_expressions)
739 })?
740 }
741
742 for (handle, constant) in module.constants.iter() {
743 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
744 .map_err(|source| {
745 ValidationError::Constant {
746 handle,
747 name: constant.name.clone().unwrap_or_default(),
748 source,
749 }
750 .with_span_handle(handle, &module.constants)
751 })?
752 }
753
754 for (handle, r#override) in module.overrides.iter() {
755 self.validate_override(handle, module.to_ctx(), &mod_info)
756 .map_err(|source| {
757 ValidationError::Override {
758 handle,
759 name: r#override.name.clone().unwrap_or_default(),
760 source,
761 }
762 .with_span_handle(handle, &module.overrides)
763 })?;
764 }
765 }
766
767 for (var_handle, var) in module.global_variables.iter() {
768 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
769 .map_err(|source| {
770 ValidationError::GlobalVariable {
771 handle: var_handle,
772 name: var.name.clone().unwrap_or_default(),
773 source,
774 }
775 .with_span_handle(var_handle, &module.global_variables)
776 })?;
777 }
778
779 for (handle, fun) in module.functions.iter() {
780 match self.validate_function(fun, module, &mod_info, false) {
781 Ok(info) => mod_info.functions.push(info),
782 Err(error) => {
783 return Err(error.and_then(|source| {
784 ValidationError::Function {
785 handle,
786 name: fun.name.clone().unwrap_or_default(),
787 source,
788 }
789 .with_span_handle(handle, &module.functions)
790 }))
791 }
792 }
793 }
794
795 let mut ep_map = FastHashSet::default();
796 for ep in module.entry_points.iter() {
797 if !ep_map.insert((ep.stage, &ep.name)) {
798 return Err(ValidationError::EntryPoint {
799 stage: ep.stage,
800 name: ep.name.clone(),
801 source: EntryPointError::Conflict,
802 }
803 .with_span()); }
805
806 match self.validate_entry_point(ep, module, &mod_info) {
807 Ok(info) => mod_info.entry_points.push(info),
808 Err(error) => {
809 return Err(error.and_then(|source| {
810 ValidationError::EntryPoint {
811 stage: ep.stage,
812 name: ep.name.clone(),
813 source,
814 }
815 .with_span()
816 }));
817 }
818 }
819 }
820
821 Ok(mod_info)
822 }
823}
824
825fn validate_atomic_compare_exchange_struct(
826 types: &crate::UniqueArena<crate::Type>,
827 members: &[crate::StructMember],
828 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
829) -> bool {
830 members.len() == 2
831 && members[0].name.as_deref() == Some("old_value")
832 && scalar_predicate(&types[members[0].ty].inner)
833 && members[1].name.as_deref() == Some("exchanged")
834 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
835}