1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
38pub const MAX_TYPE_SIZE: u32 = 0x4000_0000; bitflags::bitflags! {
42 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
56 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
57 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
58 pub struct ValidationFlags: u8 {
59 const EXPRESSIONS = 0x1;
61 const BLOCKS = 0x2;
63 const CONTROL_FLOW_UNIFORMITY = 0x4;
65 const STRUCT_LAYOUTS = 0x8;
67 const CONSTANTS = 0x10;
69 const BINDINGS = 0x20;
71 }
72}
73
74impl Default for ValidationFlags {
75 fn default() -> Self {
76 Self::all()
77 }
78}
79
bitflags::bitflags! {
    /// Optional features that a module is allowed to use.
    ///
    /// Each flag presumably unlocks the shader feature it is named after.
    /// NOTE(review): the checks that consult these flags live outside this
    /// section — confirm each flag's exact effect there.
    ///
    /// `Capabilities::default()` enables only `MULTISAMPLED_SHADING` and
    /// `CUBE_ARRAY_TEXTURES`.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        const PUSH_CONSTANT = 1 << 0;
        const FLOAT64 = 1 << 1;
        const PRIMITIVE_INDEX = 1 << 2;
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        /// Corresponds to the WGSL `ClipDistances` enable extension.
        const CLIP_DISTANCE = 1 << 7;
        const CULL_DISTANCE = 1 << 8;
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        const MULTIVIEW = 1 << 10;
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Enabled by default.
        const MULTISAMPLED_SHADING = 1 << 12;
        const RAY_QUERY = 1 << 13;
        /// Corresponds to the WGSL `DualSourceBlending` enable extension.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Enabled by default.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        const SHADER_INT64 = 1 << 16;
        /// Lets `Validator::new` allow every subgroup operation set in the
        /// fragment and compute stages.
        const SUBGROUP = 1 << 17;
        const SUBGROUP_BARRIER = 1 << 18;
        /// Lets `Validator::new` add the vertex stage to the stages in which
        /// subgroup operations are allowed.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        const TEXTURE_ATOMIC = 1 << 23;
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// Corresponds to the WGSL `F16` enable extension.
        const SHADER_FLOAT16 = 1 << 26;
        const TEXTURE_EXTERNAL = 1 << 27;
        // Unlike `SHADER_FLOAT16`, this flag is not mapped to a WGSL
        // extension by `Capabilities::extension`.
        const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
    }
}
178
179impl Capabilities {
180 #[cfg(feature = "wgsl-in")]
184 #[doc(hidden)]
185 pub const fn extension(&self) -> Option<crate::front::wgsl::ImplementedEnableExtension> {
186 use crate::front::wgsl::ImplementedEnableExtension as Ext;
187 match *self {
188 Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
189 Self::SHADER_FLOAT16 => Some(Ext::F16),
191 Self::CLIP_DISTANCE => Some(Ext::ClipDistances),
192 _ => None,
193 }
194 }
195}
196
197impl Default for Capabilities {
198 fn default() -> Self {
199 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
200 }
201}
202
bitflags::bitflags! {
    /// Groups of subgroup operations that a [`Validator`] may permit.
    ///
    /// Each subgroup operation and gather mode maps to exactly one of these
    /// groups through its `required_operations` method. The derived `Default`
    /// is the empty set.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Basic subgroup support.
        const BASIC = 1 << 0;
        /// Required by the `All` and `Any` subgroup operations.
        const VOTE = 1 << 1;
        /// Required by `Add`, `Mul`, `Min`, `Max`, `And`, `Or`, and `Xor`.
        const ARITHMETIC = 1 << 2;
        /// Required by the `BroadcastFirst` and `Broadcast` gather modes.
        const BALLOT = 1 << 3;
        /// Required by the `Shuffle` and `ShuffleXor` gather modes.
        const SHUFFLE = 1 << 4;
        /// Required by the `ShuffleUp` and `ShuffleDown` gather modes.
        const SHUFFLE_RELATIVE = 1 << 5;
        // Bit 6 is not assigned in this set.
        /// Required by the `QuadBroadcast` and `QuadSwap` gather modes.
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
    }
}
230
231impl super::SubgroupOperation {
232 const fn required_operations(&self) -> SubgroupOperationSet {
233 use SubgroupOperationSet as S;
234 match *self {
235 Self::All | Self::Any => S::VOTE,
236 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
237 S::ARITHMETIC
238 }
239 }
240 }
241}
242
243impl super::GatherMode {
244 const fn required_operations(&self) -> SubgroupOperationSet {
245 use SubgroupOperationSet as S;
246 match *self {
247 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
248 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
249 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
250 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
251 }
252 }
253}
254
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// A [`Validator`] uses this to record the stages in which subgroup
    /// operations are permitted.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        const VERTEX = 0x1;
        const FRAGMENT = 0x2;
        const COMPUTE = 0x4;
    }
}
266
/// Analysis results returned by a successful validation run.
///
/// Index it with a `Handle<Type>` to get that type's [`TypeFlags`], with a
/// `Handle<Function>` to get its [`FunctionInfo`], or with a global-expression
/// `Handle<Expression>` to get that expression's resolved type.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, in type-arena order.
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function, in function-arena order.
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, in the module's entry-point order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each global expression, in arena order.
    const_expression_types: Box<[TypeResolution]>,
}
276
277impl ops::Index<Handle<crate::Type>> for ModuleInfo {
278 type Output = TypeFlags;
279 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
280 &self.type_flags[handle.index()]
281 }
282}
283
284impl ops::Index<Handle<crate::Function>> for ModuleInfo {
285 type Output = FunctionInfo;
286 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
287 &self.functions[handle.index()]
288 }
289}
290
291impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
292 type Output = TypeResolution;
293 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
294 &self.const_expression_types[handle.index()]
295 }
296}
297
/// Validates a [`crate::Module`] and produces a [`ModuleInfo`] for it.
///
/// A validator can be reused: every validation run first clears the
/// per-module scratch state via [`Validator::reset`].
#[derive(Debug)]
pub struct Validator {
    /// Which categories of checks to perform.
    flags: ValidationFlags,
    /// Features the module is allowed to use.
    capabilities: Capabilities,
    /// Stages in which subgroup operations are permitted.
    subgroup_stages: ShaderStages,
    /// Subgroup operation groups that are permitted.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation results, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    /// Type layouts, recomputed from the module at the start of each run.
    layouter: Layouter,
    // NOTE(review): the next three scratch fields are cleared by `reset` but
    // used only outside this section — presumably by entry-point interface
    // validation; confirm.
    location_mask: BitSet,
    blend_src_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): `dead_code` is allowed here without a stated reason; this
    // field is only cleared in this section. Confirm whether it is read
    // anywhere, or remove it.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // Scratch state for expression validation, cleared by `reset`.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far in the current run; a repeat is reported as
    /// [`OverrideError::DuplicateID`].
    override_ids: FastHashSet<u16>,

    /// Whether every override must have an initializer. Set to `false` by
    /// [`Validator::validate`] and to `true` by
    /// [`Validator::validate_resolved_overrides`].
    overrides_resolved: bool,

    // NOTE(review): expression handles still awaiting a visit; maintained
    // outside this section — confirm semantics in function validation.
    needs_visit: HandleSet<crate::Expression>,
}
339
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
}
350
/// Reasons a pipeline-overridable constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    // NOTE(review): not constructed in this section; presumably raised when
    // an override has neither a name nor an ID — confirm.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    // NOTE(review): not constructed in this section.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not a bool, i32, u32, f16, f32, or f64 scalar.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    // NOTE(review): not constructed in this section.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer, but overrides were required to be
    /// resolved.
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// A constant expression belonging to the override is invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
376
/// Top-level error returned by [`Validator::validate`] and
/// [`Validator::validate_resolved_overrides`].
///
/// Most variants identify the offending module item by handle and name, and
/// wrap the item-specific error as `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// A handle in the module is out of range or otherwise malformed.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed type resolution or const validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    // NOTE(review): not constructed in this section; confirm where array
    // size expressions are checked.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point is invalid. This variant is also used, with
    /// `EntryPointError::Conflict`, when two entry points share both stage
    /// and name.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    // NOTE(review): not constructed in this section.
    #[error("Module is corrupted")]
    Corrupted,
}
430
431impl crate::TypeInner {
432 const fn is_sized(&self) -> bool {
433 match *self {
434 Self::Scalar { .. }
435 | Self::Vector { .. }
436 | Self::Matrix { .. }
437 | Self::Array {
438 size: crate::ArraySize::Constant(_),
439 ..
440 }
441 | Self::Atomic { .. }
442 | Self::Pointer { .. }
443 | Self::ValuePointer { .. }
444 | Self::Struct { .. } => true,
445 Self::Array { .. }
446 | Self::Image { .. }
447 | Self::Sampler { .. }
448 | Self::AccelerationStructure { .. }
449 | Self::RayQuery { .. }
450 | Self::BindingArray { .. } => false,
451 }
452 }
453
454 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
456 match *self {
457 Self::Scalar(crate::Scalar {
458 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
459 ..
460 }) => Some(crate::ImageDimension::D1),
461 Self::Vector {
462 size: crate::VectorSize::Bi,
463 scalar:
464 crate::Scalar {
465 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
466 ..
467 },
468 } => Some(crate::ImageDimension::D2),
469 Self::Vector {
470 size: crate::VectorSize::Tri,
471 scalar:
472 crate::Scalar {
473 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
474 ..
475 },
476 } => Some(crate::ImageDimension::D3),
477 _ => None,
478 }
479 }
480}
481
482impl Validator {
483 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
497 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
498 use SubgroupOperationSet as S;
499 S::BASIC
500 | S::VOTE
501 | S::ARITHMETIC
502 | S::BALLOT
503 | S::SHUFFLE
504 | S::SHUFFLE_RELATIVE
505 | S::QUAD_FRAGMENT_COMPUTE
506 } else {
507 SubgroupOperationSet::empty()
508 };
509 let subgroup_stages = {
510 let mut stages = ShaderStages::empty();
511 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
512 stages |= ShaderStages::VERTEX;
513 }
514 if capabilities.contains(Capabilities::SUBGROUP) {
515 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
516 }
517 stages
518 };
519
520 Validator {
521 flags,
522 capabilities,
523 subgroup_stages,
524 subgroup_operations,
525 types: Vec::new(),
526 layouter: Layouter::default(),
527 location_mask: BitSet::new(),
528 blend_src_mask: BitSet::new(),
529 ep_resource_bindings: FastHashSet::default(),
530 switch_values: FastHashSet::default(),
531 valid_expression_list: Vec::new(),
532 valid_expression_set: HandleSet::new(),
533 override_ids: FastHashSet::default(),
534 overrides_resolved: false,
535 needs_visit: HandleSet::new(),
536 }
537 }
538
539 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
540 self.subgroup_stages = stages;
541 self
542 }
543
544 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
545 self.subgroup_operations = operations;
546 self
547 }
548
549 pub fn reset(&mut self) {
551 self.types.clear();
552 self.layouter.clear();
553 self.location_mask.clear();
554 self.blend_src_mask.clear();
555 self.ep_resource_bindings.clear();
556 self.switch_values.clear();
557 self.valid_expression_list.clear();
558 self.valid_expression_set.clear();
559 self.override_ids.clear();
560 }
561
562 fn validate_constant(
563 &self,
564 handle: Handle<crate::Constant>,
565 gctx: crate::proc::GlobalCtx,
566 mod_info: &ModuleInfo,
567 global_expr_kind: &ExpressionKindTracker,
568 ) -> Result<(), ConstantError> {
569 let con = &gctx.constants[handle];
570
571 let type_info = &self.types[con.ty.index()];
572 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
573 return Err(ConstantError::NonConstructibleType);
574 }
575
576 if !global_expr_kind.is_const(con.init) {
577 return Err(ConstantError::InitializerExprType);
578 }
579
580 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
581 return Err(ConstantError::InvalidType);
582 }
583
584 Ok(())
585 }
586
587 fn validate_override(
588 &mut self,
589 handle: Handle<crate::Override>,
590 gctx: crate::proc::GlobalCtx,
591 mod_info: &ModuleInfo,
592 ) -> Result<(), OverrideError> {
593 let o = &gctx.overrides[handle];
594
595 if let Some(id) = o.id {
596 if !self.override_ids.insert(id) {
597 return Err(OverrideError::DuplicateID);
598 }
599 }
600
601 let type_info = &self.types[o.ty.index()];
602 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
603 return Err(OverrideError::NonConstructibleType);
604 }
605
606 match gctx.types[o.ty].inner {
607 crate::TypeInner::Scalar(
608 crate::Scalar::BOOL
609 | crate::Scalar::I32
610 | crate::Scalar::U32
611 | crate::Scalar::F16
612 | crate::Scalar::F32
613 | crate::Scalar::F64,
614 ) => {}
615 _ => return Err(OverrideError::TypeNotScalar),
616 }
617
618 if let Some(init) = o.init {
619 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
620 return Err(OverrideError::InvalidType);
621 }
622 } else if self.overrides_resolved {
623 return Err(OverrideError::UninitializedOverride);
624 }
625
626 Ok(())
627 }
628
629 pub fn validate(
631 &mut self,
632 module: &crate::Module,
633 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
634 self.overrides_resolved = false;
635 self.validate_impl(module)
636 }
637
638 pub fn validate_resolved_overrides(
646 &mut self,
647 module: &crate::Module,
648 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
649 self.overrides_resolved = true;
650 self.validate_impl(module)
651 }
652
653 fn validate_impl(
654 &mut self,
655 module: &crate::Module,
656 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
657 self.reset();
658 self.reset_types(module.types.len());
659
660 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
661
662 self.layouter.update(module.to_ctx()).map_err(|e| {
663 let handle = e.ty;
664 ValidationError::from(e).with_span_handle(handle, &module.types)
665 })?;
666
667 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
669 kind: crate::ScalarKind::Bool,
670 width: 0,
671 }));
672
673 let mut mod_info = ModuleInfo {
674 type_flags: Vec::with_capacity(module.types.len()),
675 functions: Vec::with_capacity(module.functions.len()),
676 entry_points: Vec::with_capacity(module.entry_points.len()),
677 const_expression_types: vec![placeholder; module.global_expressions.len()]
678 .into_boxed_slice(),
679 };
680
681 for (handle, ty) in module.types.iter() {
682 let ty_info = self
683 .validate_type(handle, module.to_ctx())
684 .map_err(|source| {
685 ValidationError::Type {
686 handle,
687 name: ty.name.clone().unwrap_or_default(),
688 source,
689 }
690 .with_span_handle(handle, &module.types)
691 })?;
692 mod_info.type_flags.push(ty_info.flags);
693 self.types[handle.index()] = ty_info;
694 }
695
696 {
697 let t = crate::Arena::new();
698 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
699 for (handle, _) in module.global_expressions.iter() {
700 mod_info
701 .process_const_expression(handle, &resolve_context, module.to_ctx())
702 .map_err(|source| {
703 ValidationError::ConstExpression { handle, source }
704 .with_span_handle(handle, &module.global_expressions)
705 })?
706 }
707 }
708
709 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
710
711 if self.flags.contains(ValidationFlags::CONSTANTS) {
712 for (handle, _) in module.global_expressions.iter() {
713 self.validate_const_expression(
714 handle,
715 module.to_ctx(),
716 &mod_info,
717 &global_expr_kind,
718 )
719 .map_err(|source| {
720 ValidationError::ConstExpression { handle, source }
721 .with_span_handle(handle, &module.global_expressions)
722 })?
723 }
724
725 for (handle, constant) in module.constants.iter() {
726 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
727 .map_err(|source| {
728 ValidationError::Constant {
729 handle,
730 name: constant.name.clone().unwrap_or_default(),
731 source,
732 }
733 .with_span_handle(handle, &module.constants)
734 })?
735 }
736
737 for (handle, r#override) in module.overrides.iter() {
738 self.validate_override(handle, module.to_ctx(), &mod_info)
739 .map_err(|source| {
740 ValidationError::Override {
741 handle,
742 name: r#override.name.clone().unwrap_or_default(),
743 source,
744 }
745 .with_span_handle(handle, &module.overrides)
746 })?;
747 }
748 }
749
750 for (var_handle, var) in module.global_variables.iter() {
751 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
752 .map_err(|source| {
753 ValidationError::GlobalVariable {
754 handle: var_handle,
755 name: var.name.clone().unwrap_or_default(),
756 source,
757 }
758 .with_span_handle(var_handle, &module.global_variables)
759 })?;
760 }
761
762 for (handle, fun) in module.functions.iter() {
763 match self.validate_function(fun, module, &mod_info, false) {
764 Ok(info) => mod_info.functions.push(info),
765 Err(error) => {
766 return Err(error.and_then(|source| {
767 ValidationError::Function {
768 handle,
769 name: fun.name.clone().unwrap_or_default(),
770 source,
771 }
772 .with_span_handle(handle, &module.functions)
773 }))
774 }
775 }
776 }
777
778 let mut ep_map = FastHashSet::default();
779 for ep in module.entry_points.iter() {
780 if !ep_map.insert((ep.stage, &ep.name)) {
781 return Err(ValidationError::EntryPoint {
782 stage: ep.stage,
783 name: ep.name.clone(),
784 source: EntryPointError::Conflict,
785 }
786 .with_span()); }
788
789 match self.validate_entry_point(ep, module, &mod_info) {
790 Ok(info) => mod_info.entry_points.push(info),
791 Err(error) => {
792 return Err(error.and_then(|source| {
793 ValidationError::EntryPoint {
794 stage: ep.stage,
795 name: ep.name.clone(),
796 source,
797 }
798 .with_span()
799 }));
800 }
801 }
802 }
803
804 Ok(mod_info)
805 }
806}
807
808fn validate_atomic_compare_exchange_struct(
809 types: &crate::UniqueArena<crate::Type>,
810 members: &[crate::StructMember],
811 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
812) -> bool {
813 members.len() == 2
814 && members[0].name.as_deref() == Some("old_value")
815 && scalar_predicate(&types[members[0].ty].inner)
816 && members[1].name.as_deref() == Some("exchanged")
817 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
818}