1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
38pub const MAX_TYPE_SIZE: u32 = 0x4000_0000; bitflags::bitflags! {
42 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58 pub struct ValidationFlags: u8 {
59 const EXPRESSIONS = 0x1;
61 const BLOCKS = 0x2;
63 const CONTROL_FLOW_UNIFORMITY = 0x4;
65 const STRUCT_LAYOUTS = 0x8;
67 const CONSTANTS = 0x10;
69 const BINDINGS = 0x20;
71 }
72}
73
74impl Default for ValidationFlags {
75 fn default() -> Self {
76 Self::all()
77 }
78}
79
bitflags::bitflags! {
    /// Optional shader features that a [`Validator`] can be configured to permit.
    ///
    /// The `Default` impl enables only `MULTISAMPLED_SHADING` and
    /// `CUBE_ARRAY_TEXTURES`. Bits 0 through 30 are assigned; bit 31 is free.
    /// The bit values are part of the (optionally serialized) public interface,
    /// so existing flags must never be renumbered.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        /// Push constants.
        const PUSH_CONSTANT = 1 << 0;
        /// 64-bit floating point.
        const FLOAT64 = 1 << 1;
        /// Primitive index built-in.
        const PRIMITIVE_INDEX = 1 << 2;
        /// Non-uniform indexing of sampled-texture and storage-buffer arrays.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        /// Non-uniform indexing of storage-texture arrays.
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        /// Non-uniform indexing of uniform-buffer arrays.
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        /// Non-uniform indexing of sampler arrays.
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        /// Clip distances. Maps to the WGSL `ClipDistances` enable extension
        /// (see `Capabilities::extension`).
        const CLIP_DISTANCE = 1 << 7;
        /// Cull distances.
        const CULL_DISTANCE = 1 << 8;
        /// 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Multiview rendering.
        const MULTIVIEW = 1 << 10;
        /// Early depth test.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Multisampled shading. Enabled by default.
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Ray queries.
        const RAY_QUERY = 1 << 13;
        /// Dual-source blending. Maps to the WGSL `DualSourceBlending` enable
        /// extension (see `Capabilities::extension`).
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Cube array textures. Enabled by default.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// 64-bit integers.
        const SHADER_INT64 = 1 << 16;
        /// Subgroup operations. `Validator::new` derives the permitted subgroup
        /// operations and the fragment/compute subgroup stages from this flag.
        const SUBGROUP = 1 << 17;
        /// Subgroup barriers.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Subgroup operations in the vertex stage. `Validator::new` adds
        /// `ShaderStages::VERTEX` to the subgroup stages when this is set.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// 64-bit atomic min/max.
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// All 64-bit atomic operations.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// 32-bit float atomics.
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Texture atomics.
        const TEXTURE_ATOMIC = 1 << 23;
        /// 64-bit integer texture atomics.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Vertex positions of a ray hit.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// 16-bit floating point. Maps to the WGSL `F16` enable extension
        /// (see `Capabilities::extension`).
        const SHADER_FLOAT16 = 1 << 26;
        /// External textures.
        const TEXTURE_EXTERNAL = 1 << 27;
        /// 16-bit floats in 32-bit-float contexts.
        // NOTE(review): exact semantics not visible in this chunk — confirm.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
        /// Barycentric coordinates.
        const SHADER_BARYCENTRICS = 1 << 29;
        /// Mesh shaders.
        const MESH_SHADER = 1 << 30;
    }
}
195
196impl Capabilities {
197 #[cfg(feature = "wgsl-in")]
201 #[doc(hidden)]
202 pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
203 use crate::front::wgsl::ImplementedEnableExtension as Ext;
204 match *self {
205 Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
206 Self::SHADER_FLOAT16 => Some(Ext::F16),
208 Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
209 _ => None,
210 }
211 }
212}
213
214impl Default for Capabilities {
215 fn default() -> Self {
216 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
217 }
218}
219
bitflags::bitflags! {
    /// Families of subgroup operations a [`Validator`] may permit.
    ///
    /// `Validator::new` enables every flag below when `Capabilities::SUBGROUP`
    /// is set, and none otherwise. The derived `Default` is the empty set.
    /// Bit 6 is intentionally unassigned; do not renumber existing flags.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Basic subgroup support.
        // NOTE(review): not required by any `SubgroupOperation`/`GatherMode`
        // mapping in this file; presumably covers barriers — confirm.
        const BASIC = 1 << 0;
        /// Votes: required by `SubgroupOperation::All` and `Any`.
        const VOTE = 1 << 1;
        /// Arithmetic/bitwise reductions: required by `Add`, `Mul`, `Min`,
        /// `Max`, `And`, `Or`, `Xor`.
        const ARITHMETIC = 1 << 2;
        /// Required by `GatherMode::BroadcastFirst` and `Broadcast`.
        const BALLOT = 1 << 3;
        /// Required by `GatherMode::Shuffle` and `ShuffleXor`.
        const SHUFFLE = 1 << 4;
        /// Required by `GatherMode::ShuffleUp` and `ShuffleDown`.
        const SHUFFLE_RELATIVE = 1 << 5;
        /// Required by `GatherMode::QuadBroadcast` and `QuadSwap`.
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
    }
}
251
252impl super::SubgroupOperation {
253 const fn required_operations(&self) -> SubgroupOperationSet {
254 use SubgroupOperationSet as S;
255 match *self {
256 Self::All | Self::Any => S::VOTE,
257 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
258 S::ARITHMETIC
259 }
260 }
261 }
262}
263
264impl super::GatherMode {
265 const fn required_operations(&self) -> SubgroupOperationSet {
266 use SubgroupOperationSet as S;
267 match *self {
268 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
269 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
270 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
271 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
272 }
273 }
274}
275
276bitflags::bitflags! {
277 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
279 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
280 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
281 pub struct ShaderStages: u8 {
282 const VERTEX = 0x1;
283 const FRAGMENT = 0x2;
284 const COMPUTE = 0x4;
285 const MESH = 0x8;
286 const TASK = 0x10;
287 }
288}
289
/// Analysis results produced by a successful [`Validator`] run.
///
/// Indexable by `Handle<Type>` (type flags), `Handle<Function>` (function
/// info) and `Handle<Expression>` (resolved types of global expressions).
// Field order is kept stable: it is visible to the optional serde derives.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    // One entry per module type, pushed in `module.types` iteration order.
    type_flags: Vec<TypeFlags>,
    // One entry per module function, in `module.functions` order.
    functions: Vec<FunctionInfo>,
    // One entry per entry point, in `module.entry_points` order.
    entry_points: Vec<FunctionInfo>,
    // Resolved type of each global expression, indexed by handle index.
    const_expression_types: Box<[TypeResolution]>,
}
299
300impl ops::Index<Handle<crate::Type>> for ModuleInfo {
301 type Output = TypeFlags;
302 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
303 &self.type_flags[handle.index()]
304 }
305}
306
307impl ops::Index<Handle<crate::Function>> for ModuleInfo {
308 type Output = FunctionInfo;
309 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
310 &self.functions[handle.index()]
311 }
312}
313
314impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
315 type Output = TypeResolution;
316 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
317 &self.const_expression_types[handle.index()]
318 }
319}
320
/// Validates a [`crate::Module`] and, on success, produces a [`ModuleInfo`].
///
/// Holds configuration (validation flags, capabilities, permitted subgroup
/// stages and operations) plus scratch state reused across validation runs.
#[derive(Debug)]
pub struct Validator {
    // Which validation passes to run.
    flags: ValidationFlags,
    // Optional features the module is allowed to use.
    capabilities: Capabilities,
    // Stages in which subgroup operations are permitted.
    subgroup_stages: ShaderStages,
    // Subgroup operation families that are permitted.
    subgroup_operations: SubgroupOperationSet,
    // Per-type validation results, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    // Type layouts, recomputed from the module at the start of each run.
    layouter: Layouter,
    // NOTE(review): per-entry-point scratch; used outside this chunk.
    location_mask: BitSet,
    // NOTE(review): per-entry-point scratch; used outside this chunk.
    blend_src_mask: BitSet,
    // NOTE(review): per-entry-point scratch; used outside this chunk.
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): in this chunk it is only constructed and cleared, never
    // read, hence the `dead_code` allowance — confirm it is still needed.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // NOTE(review): expression-validation scratch; used outside this chunk.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    // Override IDs seen so far, used to reject duplicate IDs.
    override_ids: FastHashSet<u16>,

    // When `true`, every override must have an initializer.
    overrides_resolved: bool,

    // NOTE(review): expression handles awaiting a visit; populated and
    // checked outside this chunk.
    needs_visit: HandleSet<crate::Expression>,
}
362
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
373
/// Reasons a pipeline-overridable constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    // NOTE(review): not produced by `Validator::validate_override` in this chunk.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override already uses the same numeric ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    // NOTE(review): not produced by `Validator::validate_override` in this chunk.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks `TypeFlags::CONSTRUCTIBLE`.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of bool, i32, u32, f16, f32 or f64.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    // NOTE(review): not produced in this chunk.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// No initializer, while validating with overrides required to be resolved.
    #[error("Override is uninitialized")]
    UninitializedOverride,
    // NOTE(review): not produced in this chunk.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
399
/// Top-level error returned by [`Validator::validate`], identifying which
/// module item failed and why.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// A handle in the module refers outside its arena or is otherwise invalid.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed type resolution or const-expression checks.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    // NOTE(review): not produced in this chunk.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or duplicates another entry point's
    /// `(stage, name)` pair.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    // NOTE(review): not produced in this chunk.
    #[error("Module is corrupted")]
    Corrupted,
}
453
454impl crate::TypeInner {
455 const fn is_sized(&self) -> bool {
456 match *self {
457 Self::Scalar { .. }
458 | Self::Vector { .. }
459 | Self::Matrix { .. }
460 | Self::Array {
461 size: crate::ArraySize::Constant(_),
462 ..
463 }
464 | Self::Atomic { .. }
465 | Self::Pointer { .. }
466 | Self::ValuePointer { .. }
467 | Self::Struct { .. } => true,
468 Self::Array { .. }
469 | Self::Image { .. }
470 | Self::Sampler { .. }
471 | Self::AccelerationStructure { .. }
472 | Self::RayQuery { .. }
473 | Self::BindingArray { .. } => false,
474 }
475 }
476
477 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
479 match *self {
480 Self::Scalar(crate::Scalar {
481 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
482 ..
483 }) => Some(crate::ImageDimension::D1),
484 Self::Vector {
485 size: crate::VectorSize::Bi,
486 scalar:
487 crate::Scalar {
488 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
489 ..
490 },
491 } => Some(crate::ImageDimension::D2),
492 Self::Vector {
493 size: crate::VectorSize::Tri,
494 scalar:
495 crate::Scalar {
496 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
497 ..
498 },
499 } => Some(crate::ImageDimension::D3),
500 _ => None,
501 }
502 }
503}
504
505impl Validator {
506 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
520 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
521 use SubgroupOperationSet as S;
522 S::BASIC
523 | S::VOTE
524 | S::ARITHMETIC
525 | S::BALLOT
526 | S::SHUFFLE
527 | S::SHUFFLE_RELATIVE
528 | S::QUAD_FRAGMENT_COMPUTE
529 } else {
530 SubgroupOperationSet::empty()
531 };
532 let subgroup_stages = {
533 let mut stages = ShaderStages::empty();
534 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
535 stages |= ShaderStages::VERTEX;
536 }
537 if capabilities.contains(Capabilities::SUBGROUP) {
538 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
539 }
540 stages
541 };
542
543 Validator {
544 flags,
545 capabilities,
546 subgroup_stages,
547 subgroup_operations,
548 types: Vec::new(),
549 layouter: Layouter::default(),
550 location_mask: BitSet::new(),
551 blend_src_mask: BitSet::new(),
552 ep_resource_bindings: FastHashSet::default(),
553 switch_values: FastHashSet::default(),
554 valid_expression_list: Vec::new(),
555 valid_expression_set: HandleSet::new(),
556 override_ids: FastHashSet::default(),
557 overrides_resolved: false,
558 needs_visit: HandleSet::new(),
559 }
560 }
561
562 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
563 self.subgroup_stages = stages;
564 self
565 }
566
567 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
568 self.subgroup_operations = operations;
569 self
570 }
571
572 pub fn reset(&mut self) {
574 self.types.clear();
575 self.layouter.clear();
576 self.location_mask.clear();
577 self.blend_src_mask.clear();
578 self.ep_resource_bindings.clear();
579 self.switch_values.clear();
580 self.valid_expression_list.clear();
581 self.valid_expression_set.clear();
582 self.override_ids.clear();
583 }
584
585 fn validate_constant(
586 &self,
587 handle: Handle<crate::Constant>,
588 gctx: crate::proc::GlobalCtx,
589 mod_info: &ModuleInfo,
590 global_expr_kind: &ExpressionKindTracker,
591 ) -> Result<(), ConstantError> {
592 let con = &gctx.constants[handle];
593
594 let type_info = &self.types[con.ty.index()];
595 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
596 return Err(ConstantError::NonConstructibleType);
597 }
598
599 if !global_expr_kind.is_const(con.init) {
600 return Err(ConstantError::InitializerExprType);
601 }
602
603 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
604 return Err(ConstantError::InvalidType);
605 }
606
607 Ok(())
608 }
609
610 fn validate_override(
611 &mut self,
612 handle: Handle<crate::Override>,
613 gctx: crate::proc::GlobalCtx,
614 mod_info: &ModuleInfo,
615 ) -> Result<(), OverrideError> {
616 let o = &gctx.overrides[handle];
617
618 if let Some(id) = o.id {
619 if !self.override_ids.insert(id) {
620 return Err(OverrideError::DuplicateID);
621 }
622 }
623
624 let type_info = &self.types[o.ty.index()];
625 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
626 return Err(OverrideError::NonConstructibleType);
627 }
628
629 match gctx.types[o.ty].inner {
630 crate::TypeInner::Scalar(
631 crate::Scalar::BOOL
632 | crate::Scalar::I32
633 | crate::Scalar::U32
634 | crate::Scalar::F16
635 | crate::Scalar::F32
636 | crate::Scalar::F64,
637 ) => {}
638 _ => return Err(OverrideError::TypeNotScalar),
639 }
640
641 if let Some(init) = o.init {
642 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
643 return Err(OverrideError::InvalidType);
644 }
645 } else if self.overrides_resolved {
646 return Err(OverrideError::UninitializedOverride);
647 }
648
649 Ok(())
650 }
651
652 pub fn validate(
654 &mut self,
655 module: &crate::Module,
656 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
657 self.overrides_resolved = false;
658 self.validate_impl(module)
659 }
660
661 pub fn validate_resolved_overrides(
669 &mut self,
670 module: &crate::Module,
671 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
672 self.overrides_resolved = true;
673 self.validate_impl(module)
674 }
675
676 fn validate_impl(
677 &mut self,
678 module: &crate::Module,
679 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
680 self.reset();
681 self.reset_types(module.types.len());
682
683 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
684
685 self.layouter.update(module.to_ctx()).map_err(|e| {
686 let handle = e.ty;
687 ValidationError::from(e).with_span_handle(handle, &module.types)
688 })?;
689
690 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
692 kind: crate::ScalarKind::Bool,
693 width: 0,
694 }));
695
696 let mut mod_info = ModuleInfo {
697 type_flags: Vec::with_capacity(module.types.len()),
698 functions: Vec::with_capacity(module.functions.len()),
699 entry_points: Vec::with_capacity(module.entry_points.len()),
700 const_expression_types: vec![placeholder; module.global_expressions.len()]
701 .into_boxed_slice(),
702 };
703
704 for (handle, ty) in module.types.iter() {
705 let ty_info = self
706 .validate_type(handle, module.to_ctx())
707 .map_err(|source| {
708 ValidationError::Type {
709 handle,
710 name: ty.name.clone().unwrap_or_default(),
711 source,
712 }
713 .with_span_handle(handle, &module.types)
714 })?;
715 mod_info.type_flags.push(ty_info.flags);
716 self.types[handle.index()] = ty_info;
717 }
718
719 {
720 let t = crate::Arena::new();
721 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
722 for (handle, _) in module.global_expressions.iter() {
723 mod_info
724 .process_const_expression(handle, &resolve_context, module.to_ctx())
725 .map_err(|source| {
726 ValidationError::ConstExpression { handle, source }
727 .with_span_handle(handle, &module.global_expressions)
728 })?
729 }
730 }
731
732 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
733
734 if self.flags.contains(ValidationFlags::CONSTANTS) {
735 for (handle, _) in module.global_expressions.iter() {
736 self.validate_const_expression(
737 handle,
738 module.to_ctx(),
739 &mod_info,
740 &global_expr_kind,
741 )
742 .map_err(|source| {
743 ValidationError::ConstExpression { handle, source }
744 .with_span_handle(handle, &module.global_expressions)
745 })?
746 }
747
748 for (handle, constant) in module.constants.iter() {
749 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
750 .map_err(|source| {
751 ValidationError::Constant {
752 handle,
753 name: constant.name.clone().unwrap_or_default(),
754 source,
755 }
756 .with_span_handle(handle, &module.constants)
757 })?
758 }
759
760 for (handle, r#override) in module.overrides.iter() {
761 self.validate_override(handle, module.to_ctx(), &mod_info)
762 .map_err(|source| {
763 ValidationError::Override {
764 handle,
765 name: r#override.name.clone().unwrap_or_default(),
766 source,
767 }
768 .with_span_handle(handle, &module.overrides)
769 })?;
770 }
771 }
772
773 for (var_handle, var) in module.global_variables.iter() {
774 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
775 .map_err(|source| {
776 ValidationError::GlobalVariable {
777 handle: var_handle,
778 name: var.name.clone().unwrap_or_default(),
779 source,
780 }
781 .with_span_handle(var_handle, &module.global_variables)
782 })?;
783 }
784
785 for (handle, fun) in module.functions.iter() {
786 match self.validate_function(fun, module, &mod_info, false) {
787 Ok(info) => mod_info.functions.push(info),
788 Err(error) => {
789 return Err(error.and_then(|source| {
790 ValidationError::Function {
791 handle,
792 name: fun.name.clone().unwrap_or_default(),
793 source,
794 }
795 .with_span_handle(handle, &module.functions)
796 }))
797 }
798 }
799 }
800
801 let mut ep_map = FastHashSet::default();
802 for ep in module.entry_points.iter() {
803 if !ep_map.insert((ep.stage, &ep.name)) {
804 return Err(ValidationError::EntryPoint {
805 stage: ep.stage,
806 name: ep.name.clone(),
807 source: EntryPointError::Conflict,
808 }
809 .with_span()); }
811
812 match self.validate_entry_point(ep, module, &mod_info) {
813 Ok(info) => mod_info.entry_points.push(info),
814 Err(error) => {
815 return Err(error.and_then(|source| {
816 ValidationError::EntryPoint {
817 stage: ep.stage,
818 name: ep.name.clone(),
819 source,
820 }
821 .with_span()
822 }));
823 }
824 }
825 }
826
827 Ok(mod_info)
828 }
829}
830
831fn validate_atomic_compare_exchange_struct(
832 types: &crate::UniqueArena<crate::Type>,
833 members: &[crate::StructMember],
834 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
835) -> bool {
836 members.len() == 2
837 && members[0].name.as_deref() == Some("old_value")
838 && scalar_predicate(&types[members[0].ty].inner)
839 && members[1].name.as_deref() == Some("exchanged")
840 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
841}