1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
38pub const MAX_TYPE_SIZE: u32 = 0x4000_0000; bitflags::bitflags! {
42 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58 pub struct ValidationFlags: u8 {
59 const EXPRESSIONS = 0x1;
61 const BLOCKS = 0x2;
63 const CONTROL_FLOW_UNIFORMITY = 0x4;
65 const STRUCT_LAYOUTS = 0x8;
67 const CONSTANTS = 0x10;
69 const BINDINGS = 0x20;
71 }
72}
73
74impl Default for ValidationFlags {
75 fn default() -> Self {
76 Self::all()
77 }
78}
79
80bitflags::bitflags! {
81 #[must_use]
83 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
84 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
85 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
86 pub struct Capabilities: u32 {
87 const PUSH_CONSTANT = 1 << 0;
91 const FLOAT64 = 1 << 1;
93 const PRIMITIVE_INDEX = 1 << 2;
97 const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
99 const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
101 const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
103 const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
105 const CLIP_DISTANCE = 1 << 7;
109 const CULL_DISTANCE = 1 << 8;
113 const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
115 const MULTIVIEW = 1 << 10;
119 const EARLY_DEPTH_TEST = 1 << 11;
121 const MULTISAMPLED_SHADING = 1 << 12;
126 const RAY_QUERY = 1 << 13;
128 const DUAL_SOURCE_BLENDING = 1 << 14;
130 const CUBE_ARRAY_TEXTURES = 1 << 15;
132 const SHADER_INT64 = 1 << 16;
134 const SUBGROUP = 1 << 17;
145 const SUBGROUP_BARRIER = 1 << 18;
149 const SUBGROUP_VERTEX_STAGE = 1 << 19;
155 const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
165 const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
167 const SHADER_FLOAT32_ATOMIC = 1 << 22;
176 const TEXTURE_ATOMIC = 1 << 23;
178 const TEXTURE_INT64_ATOMIC = 1 << 24;
180 const RAY_HIT_VERTEX_POSITION = 1 << 25;
182 const SHADER_FLOAT16 = 1 << 26;
184 const TEXTURE_EXTERNAL = 1 << 27;
186 const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
189 const SHADER_BARYCENTRICS = 1 << 29;
191 const MESH_SHADER = 1 << 30;
193 const MESH_SHADER_POINT_TOPOLOGY = 1 << 30;
195 }
196}
197
198impl Capabilities {
199 #[cfg(feature = "wgsl-in")]
203 #[doc(hidden)]
204 pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
205 use crate::front::wgsl::ImplementedEnableExtension as Ext;
206 match *self {
207 Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
208 Self::SHADER_FLOAT16 => Some(Ext::F16),
210 Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
211 Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
212 Self::RAY_HIT_VERTEX_POSITION => Some(Ext::WgpuRayQueryVertexReturn),
213 _ => None,
214 }
215 }
216}
217
218impl Default for Capabilities {
219 fn default() -> Self {
220 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
221 }
222}
223
224bitflags::bitflags! {
225 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
227 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
228 #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
229 pub struct SubgroupOperationSet: u8 {
230 const BASIC = 1 << 0;
236 const VOTE = 1 << 1;
238 const ARITHMETIC = 1 << 2;
240 const BALLOT = 1 << 3;
242 const SHUFFLE = 1 << 4;
244 const SHUFFLE_RELATIVE = 1 << 5;
246 const QUAD_FRAGMENT_COMPUTE = 1 << 7;
251 }
254}
255
256impl super::SubgroupOperation {
257 const fn required_operations(&self) -> SubgroupOperationSet {
258 use SubgroupOperationSet as S;
259 match *self {
260 Self::All | Self::Any => S::VOTE,
261 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
262 S::ARITHMETIC
263 }
264 }
265 }
266}
267
268impl super::GatherMode {
269 const fn required_operations(&self) -> SubgroupOperationSet {
270 use SubgroupOperationSet as S;
271 match *self {
272 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
273 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
274 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
275 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
276 }
277 }
278}
279
280bitflags::bitflags! {
281 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
283 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
284 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
285 pub struct ShaderStages: u8 {
286 const VERTEX = 0x1;
287 const FRAGMENT = 0x2;
288 const COMPUTE = 0x4;
289 const MESH = 0x8;
290 const TASK = 0x10;
291 }
292}
293
294#[derive(Debug, Clone, Default)]
295#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
296#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
297pub struct ModuleInfo {
298 type_flags: Vec<TypeFlags>,
299 functions: Vec<FunctionInfo>,
300 entry_points: Vec<FunctionInfo>,
301 const_expression_types: Box<[TypeResolution]>,
302}
303
304impl ops::Index<Handle<crate::Type>> for ModuleInfo {
305 type Output = TypeFlags;
306 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
307 &self.type_flags[handle.index()]
308 }
309}
310
311impl ops::Index<Handle<crate::Function>> for ModuleInfo {
312 type Output = FunctionInfo;
313 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
314 &self.functions[handle.index()]
315 }
316}
317
318impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
319 type Output = TypeResolution;
320 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
321 &self.const_expression_types[handle.index()]
322 }
323}
324
325#[derive(Debug)]
326pub struct Validator {
327 flags: ValidationFlags,
328 capabilities: Capabilities,
329 subgroup_stages: ShaderStages,
330 subgroup_operations: SubgroupOperationSet,
331 types: Vec<r#type::TypeInfo>,
332 layouter: Layouter,
333 location_mask: BitSet,
334 blend_src_mask: BitSet,
335 ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
336 #[allow(dead_code)]
337 switch_values: FastHashSet<crate::SwitchValue>,
338 valid_expression_list: Vec<Handle<crate::Expression>>,
339 valid_expression_set: HandleSet<crate::Expression>,
340 override_ids: FastHashSet<u16>,
341
342 overrides_resolved: bool,
345
346 needs_visit: HandleSet<crate::Expression>,
365}
366
367#[derive(Clone, Debug, thiserror::Error)]
368#[cfg_attr(test, derive(PartialEq))]
369pub enum ConstantError {
370 #[error("Initializer must be a const-expression")]
371 InitializerExprType,
372 #[error("The type doesn't match the constant")]
373 InvalidType,
374 #[error("The type is not constructible")]
375 NonConstructibleType,
376}
377
378#[derive(Clone, Debug, thiserror::Error)]
379#[cfg_attr(test, derive(PartialEq))]
380pub enum OverrideError {
381 #[error("Override name and ID are missing")]
382 MissingNameAndID,
383 #[error("Override ID must be unique")]
384 DuplicateID,
385 #[error("Initializer must be a const-expression or override-expression")]
386 InitializerExprType,
387 #[error("The type doesn't match the override")]
388 InvalidType,
389 #[error("The type is not constructible")]
390 NonConstructibleType,
391 #[error("The type is not a scalar")]
392 TypeNotScalar,
393 #[error("Override declarations are not allowed")]
394 NotAllowed,
395 #[error("Override is uninitialized")]
396 UninitializedOverride,
397 #[error("Constant expression {handle:?} is invalid")]
398 ConstExpression {
399 handle: Handle<crate::Expression>,
400 source: ConstExpressionError,
401 },
402}
403
404#[derive(Clone, Debug, thiserror::Error)]
405#[cfg_attr(test, derive(PartialEq))]
406pub enum ValidationError {
407 #[error(transparent)]
408 InvalidHandle(#[from] InvalidHandleError),
409 #[error(transparent)]
410 Layouter(#[from] LayoutError),
411 #[error("Type {handle:?} '{name}' is invalid")]
412 Type {
413 handle: Handle<crate::Type>,
414 name: String,
415 source: TypeError,
416 },
417 #[error("Constant expression {handle:?} is invalid")]
418 ConstExpression {
419 handle: Handle<crate::Expression>,
420 source: ConstExpressionError,
421 },
422 #[error("Array size expression {handle:?} is not strictly positive")]
423 ArraySizeError { handle: Handle<crate::Expression> },
424 #[error("Constant {handle:?} '{name}' is invalid")]
425 Constant {
426 handle: Handle<crate::Constant>,
427 name: String,
428 source: ConstantError,
429 },
430 #[error("Override {handle:?} '{name}' is invalid")]
431 Override {
432 handle: Handle<crate::Override>,
433 name: String,
434 source: OverrideError,
435 },
436 #[error("Global variable {handle:?} '{name}' is invalid")]
437 GlobalVariable {
438 handle: Handle<crate::GlobalVariable>,
439 name: String,
440 source: GlobalVariableError,
441 },
442 #[error("Function {handle:?} '{name}' is invalid")]
443 Function {
444 handle: Handle<crate::Function>,
445 name: String,
446 source: FunctionError,
447 },
448 #[error("Entry point {name} at {stage:?} is invalid")]
449 EntryPoint {
450 stage: crate::ShaderStage,
451 name: String,
452 source: EntryPointError,
453 },
454 #[error("Module is corrupted")]
455 Corrupted,
456}
457
458impl crate::TypeInner {
459 const fn is_sized(&self) -> bool {
460 match *self {
461 Self::Scalar { .. }
462 | Self::Vector { .. }
463 | Self::Matrix { .. }
464 | Self::Array {
465 size: crate::ArraySize::Constant(_),
466 ..
467 }
468 | Self::Atomic { .. }
469 | Self::Pointer { .. }
470 | Self::ValuePointer { .. }
471 | Self::Struct { .. } => true,
472 Self::Array { .. }
473 | Self::Image { .. }
474 | Self::Sampler { .. }
475 | Self::AccelerationStructure { .. }
476 | Self::RayQuery { .. }
477 | Self::BindingArray { .. } => false,
478 }
479 }
480
481 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
483 match *self {
484 Self::Scalar(crate::Scalar {
485 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
486 ..
487 }) => Some(crate::ImageDimension::D1),
488 Self::Vector {
489 size: crate::VectorSize::Bi,
490 scalar:
491 crate::Scalar {
492 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
493 ..
494 },
495 } => Some(crate::ImageDimension::D2),
496 Self::Vector {
497 size: crate::VectorSize::Tri,
498 scalar:
499 crate::Scalar {
500 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
501 ..
502 },
503 } => Some(crate::ImageDimension::D3),
504 _ => None,
505 }
506 }
507}
508
509impl Validator {
510 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
524 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
525 use SubgroupOperationSet as S;
526 S::BASIC
527 | S::VOTE
528 | S::ARITHMETIC
529 | S::BALLOT
530 | S::SHUFFLE
531 | S::SHUFFLE_RELATIVE
532 | S::QUAD_FRAGMENT_COMPUTE
533 } else {
534 SubgroupOperationSet::empty()
535 };
536 let subgroup_stages = {
537 let mut stages = ShaderStages::empty();
538 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
539 stages |= ShaderStages::VERTEX;
540 }
541 if capabilities.contains(Capabilities::SUBGROUP) {
542 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
543 }
544 stages
545 };
546
547 Validator {
548 flags,
549 capabilities,
550 subgroup_stages,
551 subgroup_operations,
552 types: Vec::new(),
553 layouter: Layouter::default(),
554 location_mask: BitSet::new(),
555 blend_src_mask: BitSet::new(),
556 ep_resource_bindings: FastHashSet::default(),
557 switch_values: FastHashSet::default(),
558 valid_expression_list: Vec::new(),
559 valid_expression_set: HandleSet::new(),
560 override_ids: FastHashSet::default(),
561 overrides_resolved: false,
562 needs_visit: HandleSet::new(),
563 }
564 }
565
566 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
567 self.subgroup_stages = stages;
568 self
569 }
570
571 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
572 self.subgroup_operations = operations;
573 self
574 }
575
576 pub fn reset(&mut self) {
578 self.types.clear();
579 self.layouter.clear();
580 self.location_mask.clear();
581 self.blend_src_mask.clear();
582 self.ep_resource_bindings.clear();
583 self.switch_values.clear();
584 self.valid_expression_list.clear();
585 self.valid_expression_set.clear();
586 self.override_ids.clear();
587 }
588
589 fn validate_constant(
590 &self,
591 handle: Handle<crate::Constant>,
592 gctx: crate::proc::GlobalCtx,
593 mod_info: &ModuleInfo,
594 global_expr_kind: &ExpressionKindTracker,
595 ) -> Result<(), ConstantError> {
596 let con = &gctx.constants[handle];
597
598 let type_info = &self.types[con.ty.index()];
599 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
600 return Err(ConstantError::NonConstructibleType);
601 }
602
603 if !global_expr_kind.is_const(con.init) {
604 return Err(ConstantError::InitializerExprType);
605 }
606
607 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
608 return Err(ConstantError::InvalidType);
609 }
610
611 Ok(())
612 }
613
614 fn validate_override(
615 &mut self,
616 handle: Handle<crate::Override>,
617 gctx: crate::proc::GlobalCtx,
618 mod_info: &ModuleInfo,
619 ) -> Result<(), OverrideError> {
620 let o = &gctx.overrides[handle];
621
622 if let Some(id) = o.id {
623 if !self.override_ids.insert(id) {
624 return Err(OverrideError::DuplicateID);
625 }
626 }
627
628 let type_info = &self.types[o.ty.index()];
629 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
630 return Err(OverrideError::NonConstructibleType);
631 }
632
633 match gctx.types[o.ty].inner {
634 crate::TypeInner::Scalar(
635 crate::Scalar::BOOL
636 | crate::Scalar::I32
637 | crate::Scalar::U32
638 | crate::Scalar::F16
639 | crate::Scalar::F32
640 | crate::Scalar::F64,
641 ) => {}
642 _ => return Err(OverrideError::TypeNotScalar),
643 }
644
645 if let Some(init) = o.init {
646 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
647 return Err(OverrideError::InvalidType);
648 }
649 } else if self.overrides_resolved {
650 return Err(OverrideError::UninitializedOverride);
651 }
652
653 Ok(())
654 }
655
656 pub fn validate(
658 &mut self,
659 module: &crate::Module,
660 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
661 self.overrides_resolved = false;
662 self.validate_impl(module)
663 }
664
665 pub fn validate_resolved_overrides(
673 &mut self,
674 module: &crate::Module,
675 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
676 self.overrides_resolved = true;
677 self.validate_impl(module)
678 }
679
680 fn validate_impl(
681 &mut self,
682 module: &crate::Module,
683 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
684 self.reset();
685 self.reset_types(module.types.len());
686
687 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
688
689 self.layouter.update(module.to_ctx()).map_err(|e| {
690 let handle = e.ty;
691 ValidationError::from(e).with_span_handle(handle, &module.types)
692 })?;
693
694 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
696 kind: crate::ScalarKind::Bool,
697 width: 0,
698 }));
699
700 let mut mod_info = ModuleInfo {
701 type_flags: Vec::with_capacity(module.types.len()),
702 functions: Vec::with_capacity(module.functions.len()),
703 entry_points: Vec::with_capacity(module.entry_points.len()),
704 const_expression_types: vec![placeholder; module.global_expressions.len()]
705 .into_boxed_slice(),
706 };
707
708 for (handle, ty) in module.types.iter() {
709 let ty_info = self
710 .validate_type(handle, module.to_ctx())
711 .map_err(|source| {
712 ValidationError::Type {
713 handle,
714 name: ty.name.clone().unwrap_or_default(),
715 source,
716 }
717 .with_span_handle(handle, &module.types)
718 })?;
719 mod_info.type_flags.push(ty_info.flags);
720 self.types[handle.index()] = ty_info;
721 }
722
723 {
724 let t = crate::Arena::new();
725 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
726 for (handle, _) in module.global_expressions.iter() {
727 mod_info
728 .process_const_expression(handle, &resolve_context, module.to_ctx())
729 .map_err(|source| {
730 ValidationError::ConstExpression { handle, source }
731 .with_span_handle(handle, &module.global_expressions)
732 })?
733 }
734 }
735
736 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
737
738 if self.flags.contains(ValidationFlags::CONSTANTS) {
739 for (handle, _) in module.global_expressions.iter() {
740 self.validate_const_expression(
741 handle,
742 module.to_ctx(),
743 &mod_info,
744 &global_expr_kind,
745 )
746 .map_err(|source| {
747 ValidationError::ConstExpression { handle, source }
748 .with_span_handle(handle, &module.global_expressions)
749 })?
750 }
751
752 for (handle, constant) in module.constants.iter() {
753 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
754 .map_err(|source| {
755 ValidationError::Constant {
756 handle,
757 name: constant.name.clone().unwrap_or_default(),
758 source,
759 }
760 .with_span_handle(handle, &module.constants)
761 })?
762 }
763
764 for (handle, r#override) in module.overrides.iter() {
765 self.validate_override(handle, module.to_ctx(), &mod_info)
766 .map_err(|source| {
767 ValidationError::Override {
768 handle,
769 name: r#override.name.clone().unwrap_or_default(),
770 source,
771 }
772 .with_span_handle(handle, &module.overrides)
773 })?;
774 }
775 }
776
777 for (var_handle, var) in module.global_variables.iter() {
778 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
779 .map_err(|source| {
780 ValidationError::GlobalVariable {
781 handle: var_handle,
782 name: var.name.clone().unwrap_or_default(),
783 source,
784 }
785 .with_span_handle(var_handle, &module.global_variables)
786 })?;
787 }
788
789 for (handle, fun) in module.functions.iter() {
790 match self.validate_function(fun, module, &mod_info, false) {
791 Ok(info) => mod_info.functions.push(info),
792 Err(error) => {
793 return Err(error.and_then(|source| {
794 ValidationError::Function {
795 handle,
796 name: fun.name.clone().unwrap_or_default(),
797 source,
798 }
799 .with_span_handle(handle, &module.functions)
800 }))
801 }
802 }
803 }
804
805 let mut ep_map = FastHashSet::default();
806 for ep in module.entry_points.iter() {
807 if !ep_map.insert((ep.stage, &ep.name)) {
808 return Err(ValidationError::EntryPoint {
809 stage: ep.stage,
810 name: ep.name.clone(),
811 source: EntryPointError::Conflict,
812 }
813 .with_span()); }
815
816 match self.validate_entry_point(ep, module, &mod_info) {
817 Ok(info) => mod_info.entry_points.push(info),
818 Err(error) => {
819 return Err(error.and_then(|source| {
820 ValidationError::EntryPoint {
821 stage: ep.stage,
822 name: ep.name.clone(),
823 source,
824 }
825 .with_span()
826 }));
827 }
828 }
829 }
830
831 Ok(mod_info)
832 }
833}
834
835fn validate_atomic_compare_exchange_struct(
836 types: &crate::UniqueArena<crate::Type>,
837 members: &[crate::StructMember],
838 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
839) -> bool {
840 members.len() == 2
841 && members[0].name.as_deref() == Some("old_value")
842 && scalar_predicate(&types[members[0].ty].inner)
843 && members[1].name.as_deref() == Some("exchanged")
844 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
845}