1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
/// Upper limit on the size of a single type; the value is 2^30 (`0x4000_0000`).
///
/// NOTE(review): presumably measured in bytes — confirm against the
/// type-layout checks that consume it (they are outside this section).
pub const MAX_TYPE_SIZE: u32 = 0x4000_0000;

bitflags::bitflags! {
    /// Selects which categories of checks the [`Validator`] performs.
    ///
    /// `ValidationFlags::default()` enables every flag.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ValidationFlags: u8 {
        /// Expression checks. NOTE(review): consumed outside this section — confirm scope.
        const EXPRESSIONS = 0x1;
        /// Block/statement checks. NOTE(review): consumed outside this section — confirm scope.
        const BLOCKS = 0x2;
        /// Control-flow uniformity analysis. NOTE(review): consumed outside this section.
        const CONTROL_FLOW_UNIFORMITY = 0x4;
        /// Struct layout checks. NOTE(review): consumed outside this section.
        const STRUCT_LAYOUTS = 0x8;
        /// Gates validation of global const-expressions, constants and
        /// overrides in `Validator::validate_impl`.
        const CONSTANTS = 0x10;
        /// Resource binding checks. NOTE(review): consumed outside this section.
        const BINDINGS = 0x20;
    }
}
73
74impl Default for ValidationFlags {
75 fn default() -> Self {
76 Self::all()
77 }
78}
79
bitflags::bitflags! {
    /// Optional feature capabilities passed to [`Validator::new`].
    ///
    /// Each flag is named after the shader feature it gates. The subgroup
    /// flags are consumed directly by `Validator::new`; the checks that consult
    /// the remaining bits live outside this section of the file.
    /// `Capabilities::default()` enables only `MULTISAMPLED_SHADING` and
    /// `CUBE_ARRAY_TEXTURES`.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        const PUSH_CONSTANT = 1 << 0;
        const FLOAT64 = 1 << 1;
        const PRIMITIVE_INDEX = 1 << 2;
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        const CLIP_DISTANCE = 1 << 7;
        const CULL_DISTANCE = 1 << 8;
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        const MULTIVIEW = 1 << 10;
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Enabled by default.
        const MULTISAMPLED_SHADING = 1 << 12;
        const RAY_QUERY = 1 << 13;
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Enabled by default.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        const SHADER_INT64 = 1 << 16;
        /// Enables subgroup operations: `Validator::new` then allows every
        /// `SubgroupOperationSet` category in the fragment and compute stages.
        const SUBGROUP = 1 << 17;
        const SUBGROUP_BARRIER = 1 << 18;
        /// Adds the vertex stage to the subgroup stages chosen by `Validator::new`.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        const TEXTURE_ATOMIC = 1 << 23;
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        const SHADER_FLOAT16 = 1 << 26;
        const TEXTURE_EXTERNAL = 1 << 27;
    }
}
175
176impl Default for Capabilities {
177 fn default() -> Self {
178 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
179 }
180}
181
bitflags::bitflags! {
    /// Categories of subgroup operations the validator permits.
    ///
    /// The default is the empty set. Bit 6 is unused; `QUAD_FRAGMENT_COMPUTE`
    /// occupies bit 7.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Basic subgroup operations. NOTE(review): the operations requiring
        /// this bit are decided outside this section — confirm.
        const BASIC = 1 << 0;
        /// Required by `SubgroupOperation::All` and `Any`.
        const VOTE = 1 << 1;
        /// Required by the `Add`, `Mul`, `Min`, `Max`, `And`, `Or` and `Xor`
        /// subgroup operations.
        const ARITHMETIC = 1 << 2;
        /// Required by `GatherMode::BroadcastFirst` and `Broadcast`.
        const BALLOT = 1 << 3;
        /// Required by `GatherMode::Shuffle` and `ShuffleXor`.
        const SHUFFLE = 1 << 4;
        /// Required by `GatherMode::ShuffleUp` and `ShuffleDown`.
        const SHUFFLE_RELATIVE = 1 << 5;
        /// Required by `GatherMode::QuadBroadcast` and `QuadSwap`.
        const QUAD_FRAGMENT_COMPUTE = 1 << 7;
    }
}
209
210impl super::SubgroupOperation {
211 const fn required_operations(&self) -> SubgroupOperationSet {
212 use SubgroupOperationSet as S;
213 match *self {
214 Self::All | Self::Any => S::VOTE,
215 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
216 S::ARITHMETIC
217 }
218 }
219 }
220}
221
222impl super::GatherMode {
223 const fn required_operations(&self) -> SubgroupOperationSet {
224 use SubgroupOperationSet as S;
225 match *self {
226 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
227 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
228 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
229 Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
230 }
231 }
232}
233
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// Within this section it records the stages in which subgroup operations
    /// are allowed (see `Validator::new` and `Validator::subgroup_stages`).
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        const VERTEX = 0x1;
        const FRAGMENT = 0x2;
        const COMPUTE = 0x4;
    }
}
245
/// Analysis results returned by a successful validation of a module.
///
/// Indexable by type, function and global-expression handles through the
/// `ops::Index` impls.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags of each type, indexed by `Handle<Type>::index()`.
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function, indexed by `Handle<Function>::index()`.
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, in `Module::entry_points` order.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each `Module::global_expressions` entry, indexed by
    /// `Handle<Expression>::index()`.
    const_expression_types: Box<[TypeResolution]>,
}
255
256impl ops::Index<Handle<crate::Type>> for ModuleInfo {
257 type Output = TypeFlags;
258 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
259 &self.type_flags[handle.index()]
260 }
261}
262
263impl ops::Index<Handle<crate::Function>> for ModuleInfo {
264 type Output = FunctionInfo;
265 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
266 &self.functions[handle.index()]
267 }
268}
269
270impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
271 type Output = TypeResolution;
272 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
273 &self.const_expression_types[handle.index()]
274 }
275}
276
/// Checks a [`crate::Module`] for validity and produces its [`ModuleInfo`].
///
/// A single validator can be reused for many modules; each validation run
/// starts by calling [`Validator::reset`].
#[derive(Debug)]
pub struct Validator {
    /// Which validation passes to run.
    flags: ValidationFlags,
    /// Feature capabilities supplied at construction.
    capabilities: Capabilities,
    /// Subgroup stage set; derived from `capabilities` in `new` unless
    /// replaced via `subgroup_stages`.
    subgroup_stages: ShaderStages,
    /// Subgroup operation categories; derived from `capabilities` in `new`
    /// unless replaced via `subgroup_operations`.
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation info, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    /// Type layouts, recomputed for each module.
    layouter: Layouter,
    /// Scratch bit set. NOTE(review): presumably tracks I/O locations used by
    /// an entry point — confirm against the interface checks.
    location_mask: BitSet,
    /// Scratch bit set. NOTE(review): presumably tracks blend-source indices —
    /// confirm against the interface checks.
    blend_src_mask: BitSet,
    /// Resource bindings seen so far. NOTE(review): presumably used to detect
    /// binding conflicts within an entry point — confirm.
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // Only constructed and cleared within this section; the `allow` silences
    // the unused-field lint. NOTE(review): confirm whether any reader remains.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    /// Scratch list of expression handles, cleared on `reset`.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    /// Scratch set of expression handles, cleared on `reset`.
    valid_expression_set: HandleSet<crate::Expression>,
    /// Override IDs seen so far; a repeat is reported as
    /// `OverrideError::DuplicateID`.
    override_ids: FastHashSet<u16>,

    /// When `true`, an override without an initializer is an error. Set to
    /// `false` by `validate` and to `true` by `validate_resolved_overrides`.
    overrides_resolved: bool,

    /// Set of expression handles populated and checked outside this section.
    /// NOTE(review): the name suggests expressions that still need to be
    /// visited by some statement — confirm.
    needs_visit: HandleSet<crate::Expression>,
}
318
/// Reasons a module-scope [`crate::Constant`] can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The constant's initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
}
329
/// Reasons a pipeline-overridable constant ([`crate::Override`]) can fail
/// validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// Not produced in this section. NOTE(review): raised elsewhere — confirm.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// Not produced in this section. NOTE(review): raised elsewhere — confirm.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's resolved type differs from the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of the bool, i32, u32, f16, f32 or f64
    /// scalars.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Not produced in this section. NOTE(review): raised elsewhere — confirm.
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer, but overrides were required to be
    /// resolved (`Validator::validate_resolved_overrides`).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    /// Not produced in this section. NOTE(review): raised elsewhere — confirm.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
355
/// Top-level error returned by [`Validator::validate`] and
/// [`Validator::validate_resolved_overrides`].
///
/// Most variants identify the failing module item and wrap the item-specific
/// error as `source`. A `name` field is the item's name, or empty if it has
/// none.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// The module failed the up-front handle check.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed type resolution or const-expression checks.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// Not produced in this section. NOTE(review): raised by array-size checks
    /// elsewhere — confirm.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or repeats the stage and name of an
    /// earlier entry point (`EntryPointError::Conflict`).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// Not produced in this section. NOTE(review): raised elsewhere — confirm.
    #[error("Module is corrupted")]
    Corrupted,
}
409
410impl crate::TypeInner {
411 const fn is_sized(&self) -> bool {
412 match *self {
413 Self::Scalar { .. }
414 | Self::Vector { .. }
415 | Self::Matrix { .. }
416 | Self::Array {
417 size: crate::ArraySize::Constant(_),
418 ..
419 }
420 | Self::Atomic { .. }
421 | Self::Pointer { .. }
422 | Self::ValuePointer { .. }
423 | Self::Struct { .. } => true,
424 Self::Array { .. }
425 | Self::Image { .. }
426 | Self::Sampler { .. }
427 | Self::AccelerationStructure { .. }
428 | Self::RayQuery { .. }
429 | Self::BindingArray { .. } => false,
430 }
431 }
432
433 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
435 match *self {
436 Self::Scalar(crate::Scalar {
437 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
438 ..
439 }) => Some(crate::ImageDimension::D1),
440 Self::Vector {
441 size: crate::VectorSize::Bi,
442 scalar:
443 crate::Scalar {
444 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
445 ..
446 },
447 } => Some(crate::ImageDimension::D2),
448 Self::Vector {
449 size: crate::VectorSize::Tri,
450 scalar:
451 crate::Scalar {
452 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
453 ..
454 },
455 } => Some(crate::ImageDimension::D3),
456 _ => None,
457 }
458 }
459}
460
461impl Validator {
462 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
476 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
477 use SubgroupOperationSet as S;
478 S::BASIC
479 | S::VOTE
480 | S::ARITHMETIC
481 | S::BALLOT
482 | S::SHUFFLE
483 | S::SHUFFLE_RELATIVE
484 | S::QUAD_FRAGMENT_COMPUTE
485 } else {
486 SubgroupOperationSet::empty()
487 };
488 let subgroup_stages = {
489 let mut stages = ShaderStages::empty();
490 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
491 stages |= ShaderStages::VERTEX;
492 }
493 if capabilities.contains(Capabilities::SUBGROUP) {
494 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
495 }
496 stages
497 };
498
499 Validator {
500 flags,
501 capabilities,
502 subgroup_stages,
503 subgroup_operations,
504 types: Vec::new(),
505 layouter: Layouter::default(),
506 location_mask: BitSet::new(),
507 blend_src_mask: BitSet::new(),
508 ep_resource_bindings: FastHashSet::default(),
509 switch_values: FastHashSet::default(),
510 valid_expression_list: Vec::new(),
511 valid_expression_set: HandleSet::new(),
512 override_ids: FastHashSet::default(),
513 overrides_resolved: false,
514 needs_visit: HandleSet::new(),
515 }
516 }
517
518 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
519 self.subgroup_stages = stages;
520 self
521 }
522
523 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
524 self.subgroup_operations = operations;
525 self
526 }
527
528 pub fn reset(&mut self) {
530 self.types.clear();
531 self.layouter.clear();
532 self.location_mask.clear();
533 self.blend_src_mask.clear();
534 self.ep_resource_bindings.clear();
535 self.switch_values.clear();
536 self.valid_expression_list.clear();
537 self.valid_expression_set.clear();
538 self.override_ids.clear();
539 }
540
541 fn validate_constant(
542 &self,
543 handle: Handle<crate::Constant>,
544 gctx: crate::proc::GlobalCtx,
545 mod_info: &ModuleInfo,
546 global_expr_kind: &ExpressionKindTracker,
547 ) -> Result<(), ConstantError> {
548 let con = &gctx.constants[handle];
549
550 let type_info = &self.types[con.ty.index()];
551 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
552 return Err(ConstantError::NonConstructibleType);
553 }
554
555 if !global_expr_kind.is_const(con.init) {
556 return Err(ConstantError::InitializerExprType);
557 }
558
559 if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
560 return Err(ConstantError::InvalidType);
561 }
562
563 Ok(())
564 }
565
566 fn validate_override(
567 &mut self,
568 handle: Handle<crate::Override>,
569 gctx: crate::proc::GlobalCtx,
570 mod_info: &ModuleInfo,
571 ) -> Result<(), OverrideError> {
572 let o = &gctx.overrides[handle];
573
574 if let Some(id) = o.id {
575 if !self.override_ids.insert(id) {
576 return Err(OverrideError::DuplicateID);
577 }
578 }
579
580 let type_info = &self.types[o.ty.index()];
581 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
582 return Err(OverrideError::NonConstructibleType);
583 }
584
585 match gctx.types[o.ty].inner {
586 crate::TypeInner::Scalar(
587 crate::Scalar::BOOL
588 | crate::Scalar::I32
589 | crate::Scalar::U32
590 | crate::Scalar::F16
591 | crate::Scalar::F32
592 | crate::Scalar::F64,
593 ) => {}
594 _ => return Err(OverrideError::TypeNotScalar),
595 }
596
597 if let Some(init) = o.init {
598 if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
599 return Err(OverrideError::InvalidType);
600 }
601 } else if self.overrides_resolved {
602 return Err(OverrideError::UninitializedOverride);
603 }
604
605 Ok(())
606 }
607
608 pub fn validate(
610 &mut self,
611 module: &crate::Module,
612 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
613 self.overrides_resolved = false;
614 self.validate_impl(module)
615 }
616
617 pub fn validate_resolved_overrides(
625 &mut self,
626 module: &crate::Module,
627 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
628 self.overrides_resolved = true;
629 self.validate_impl(module)
630 }
631
632 fn validate_impl(
633 &mut self,
634 module: &crate::Module,
635 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
636 self.reset();
637 self.reset_types(module.types.len());
638
639 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
640
641 self.layouter.update(module.to_ctx()).map_err(|e| {
642 let handle = e.ty;
643 ValidationError::from(e).with_span_handle(handle, &module.types)
644 })?;
645
646 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
648 kind: crate::ScalarKind::Bool,
649 width: 0,
650 }));
651
652 let mut mod_info = ModuleInfo {
653 type_flags: Vec::with_capacity(module.types.len()),
654 functions: Vec::with_capacity(module.functions.len()),
655 entry_points: Vec::with_capacity(module.entry_points.len()),
656 const_expression_types: vec![placeholder; module.global_expressions.len()]
657 .into_boxed_slice(),
658 };
659
660 for (handle, ty) in module.types.iter() {
661 let ty_info = self
662 .validate_type(handle, module.to_ctx())
663 .map_err(|source| {
664 ValidationError::Type {
665 handle,
666 name: ty.name.clone().unwrap_or_default(),
667 source,
668 }
669 .with_span_handle(handle, &module.types)
670 })?;
671 mod_info.type_flags.push(ty_info.flags);
672 self.types[handle.index()] = ty_info;
673 }
674
675 {
676 let t = crate::Arena::new();
677 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
678 for (handle, _) in module.global_expressions.iter() {
679 mod_info
680 .process_const_expression(handle, &resolve_context, module.to_ctx())
681 .map_err(|source| {
682 ValidationError::ConstExpression { handle, source }
683 .with_span_handle(handle, &module.global_expressions)
684 })?
685 }
686 }
687
688 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
689
690 if self.flags.contains(ValidationFlags::CONSTANTS) {
691 for (handle, _) in module.global_expressions.iter() {
692 self.validate_const_expression(
693 handle,
694 module.to_ctx(),
695 &mod_info,
696 &global_expr_kind,
697 )
698 .map_err(|source| {
699 ValidationError::ConstExpression { handle, source }
700 .with_span_handle(handle, &module.global_expressions)
701 })?
702 }
703
704 for (handle, constant) in module.constants.iter() {
705 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
706 .map_err(|source| {
707 ValidationError::Constant {
708 handle,
709 name: constant.name.clone().unwrap_or_default(),
710 source,
711 }
712 .with_span_handle(handle, &module.constants)
713 })?
714 }
715
716 for (handle, r#override) in module.overrides.iter() {
717 self.validate_override(handle, module.to_ctx(), &mod_info)
718 .map_err(|source| {
719 ValidationError::Override {
720 handle,
721 name: r#override.name.clone().unwrap_or_default(),
722 source,
723 }
724 .with_span_handle(handle, &module.overrides)
725 })?;
726 }
727 }
728
729 for (var_handle, var) in module.global_variables.iter() {
730 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
731 .map_err(|source| {
732 ValidationError::GlobalVariable {
733 handle: var_handle,
734 name: var.name.clone().unwrap_or_default(),
735 source,
736 }
737 .with_span_handle(var_handle, &module.global_variables)
738 })?;
739 }
740
741 for (handle, fun) in module.functions.iter() {
742 match self.validate_function(fun, module, &mod_info, false) {
743 Ok(info) => mod_info.functions.push(info),
744 Err(error) => {
745 return Err(error.and_then(|source| {
746 ValidationError::Function {
747 handle,
748 name: fun.name.clone().unwrap_or_default(),
749 source,
750 }
751 .with_span_handle(handle, &module.functions)
752 }))
753 }
754 }
755 }
756
757 let mut ep_map = FastHashSet::default();
758 for ep in module.entry_points.iter() {
759 if !ep_map.insert((ep.stage, &ep.name)) {
760 return Err(ValidationError::EntryPoint {
761 stage: ep.stage,
762 name: ep.name.clone(),
763 source: EntryPointError::Conflict,
764 }
765 .with_span()); }
767
768 match self.validate_entry_point(ep, module, &mod_info) {
769 Ok(info) => mod_info.entry_points.push(info),
770 Err(error) => {
771 return Err(error.and_then(|source| {
772 ValidationError::EntryPoint {
773 stage: ep.stage,
774 name: ep.name.clone(),
775 source,
776 }
777 .with_span()
778 }));
779 }
780 }
781 }
782
783 Ok(mod_info)
784 }
785}
786
787fn validate_atomic_compare_exchange_struct(
788 types: &crate::UniqueArena<crate::Type>,
789 members: &[crate::StructMember],
790 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
791) -> bool {
792 members.len() == 2
793 && members[0].name.as_deref() == Some("old_value")
794 && scalar_predicate(&types[members[0].ty].inner)
795 && members[1].name.as_deref() == Some("exchanged")
796 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
797}