mod analyzer;
mod compose;
mod expression;
mod function;
mod handles;
mod interface;
mod r#type;
use crate::{
arena::{Handle, HandleSet},
proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
FastHashSet,
};
use bit_set::BitSet;
use std::ops;
use crate::span::{AddSpan as _, WithSpan};
pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
pub use compose::ComposeError;
pub use expression::{check_literal_value, LiteralError};
pub use expression::{ConstExpressionError, ExpressionError};
pub use function::{CallError, FunctionError, LocalVariableError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError};
use self::handles::InvalidHandleError;
bitflags::bitflags! {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ValidationFlags: u8 {
const EXPRESSIONS = 0x1;
const BLOCKS = 0x2;
const CONTROL_FLOW_UNIFORMITY = 0x4;
const STRUCT_LAYOUTS = 0x8;
const CONSTANTS = 0x10;
const BINDINGS = 0x20;
}
}
impl Default for ValidationFlags {
fn default() -> Self {
Self::all()
}
}
bitflags::bitflags! {
#[must_use]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Capabilities: u32 {
const PUSH_CONSTANT = 0x1;
const FLOAT64 = 0x2;
const PRIMITIVE_INDEX = 0x4;
const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
const CLIP_DISTANCE = 0x40;
const CULL_DISTANCE = 0x80;
const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
const MULTIVIEW = 0x200;
const EARLY_DEPTH_TEST = 0x400;
const MULTISAMPLED_SHADING = 0x800;
const RAY_QUERY = 0x1000;
const DUAL_SOURCE_BLENDING = 0x2000;
const CUBE_ARRAY_TEXTURES = 0x4000;
const SHADER_INT64 = 0x8000;
const SUBGROUP = 0x10000;
const SUBGROUP_BARRIER = 0x20000;
const SUBGROUP_VERTEX_STAGE = 0x40000;
const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000;
const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000;
}
}
impl Default for Capabilities {
fn default() -> Self {
Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
}
}
bitflags::bitflags! {
    /// Categories of subgroup operations a [`Validator`] will accept.
    ///
    /// `required_operations` on subgroup operations and gather modes maps each
    /// operation to one of these categories.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        const BASIC = 0x01;
        /// Required by `All` / `Any`.
        const VOTE = 0x02;
        /// Required by the arithmetic and bitwise reductions.
        const ARITHMETIC = 0x04;
        /// Required by `BroadcastFirst` / `Broadcast`.
        const BALLOT = 0x08;
        /// Required by `Shuffle` / `ShuffleXor`.
        const SHUFFLE = 0x10;
        /// Required by `ShuffleUp` / `ShuffleDown`.
        const SHUFFLE_RELATIVE = 0x20;
    }
}
impl super::SubgroupOperation {
const fn required_operations(&self) -> SubgroupOperationSet {
use SubgroupOperationSet as S;
match *self {
Self::All | Self::Any => S::VOTE,
Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
S::ARITHMETIC
}
}
}
}
impl super::GatherMode {
const fn required_operations(&self) -> SubgroupOperationSet {
use SubgroupOperationSet as S;
match *self {
Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
}
}
}
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// Used by the [`Validator`] to record the stages in which subgroup
    /// operations are permitted.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        const VERTEX = 1 << 0;
        const FRAGMENT = 1 << 1;
        const COMPUTE = 1 << 2;
    }
}
/// Information about a module, produced by a successful
/// [`Validator::validate`] or [`Validator::validate_no_overrides`].
///
/// Per-item entries can be looked up through its `ops::Index` impls for
/// `Handle<Type>`, `Handle<Function>` and `Handle<Expression>`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
/// Flags for each entry of `module.types`, pushed in arena order.
type_flags: Vec<TypeFlags>,
/// Analysis results for each entry of `module.functions`, in arena order.
functions: Vec<FunctionInfo>,
/// Analysis results for each of `module.entry_points`, in order.
entry_points: Vec<FunctionInfo>,
/// Resolved type of each entry of `module.global_expressions`, indexed by
/// the expression handle's index.
const_expression_types: Box<[TypeResolution]>,
}
/// Validation flags computed for a type.
///
/// Panics if `handle` does not belong to the module this info was built from.
impl ops::Index<Handle<crate::Type>> for ModuleInfo {
    type Output = TypeFlags;
    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
        let slot = handle.index();
        &self.type_flags[slot]
    }
}
/// Analysis results for a (non-entry-point) function.
///
/// Panics if `handle` does not belong to the module this info was built from.
impl ops::Index<Handle<crate::Function>> for ModuleInfo {
    type Output = FunctionInfo;
    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
        let slot = handle.index();
        &self.functions[slot]
    }
}
/// Resolved type of a global (module-scope) expression.
///
/// Panics if `handle` does not belong to the module's global expressions.
impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
    type Output = TypeResolution;
    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
        let slot = handle.index();
        &self.const_expression_types[slot]
    }
}
/// Validates a [`crate::Module`], producing a [`ModuleInfo`] on success.
///
/// A single `Validator` can be reused: every validation begins by calling
/// [`Validator::reset`].
#[derive(Debug)]
pub struct Validator {
/// Categories of checks to perform.
flags: ValidationFlags,
/// Features the module is allowed to use.
capabilities: Capabilities,
/// Stages where subgroup operations are allowed; derived from `capabilities`
/// in `new`, replaceable via `subgroup_stages()`.
subgroup_stages: ShaderStages,
/// Subgroup operation categories that are allowed; derived from
/// `capabilities` in `new`, replaceable via `subgroup_operations()`.
subgroup_operations: SubgroupOperationSet,
/// Per-type info from `validate_type`, indexed by `Handle<Type>::index()`.
types: Vec<r#type::TypeInfo>,
/// Type layouts, recomputed from the module at the start of each validation.
layouter: Layouter,
/// NOTE(review): presumably tracks varying locations already used by the
/// entry point being validated; used outside this chunk — confirm.
location_mask: BitSet,
/// NOTE(review): presumably the resource bindings seen in the current entry
/// point, for duplicate detection; used outside this chunk — confirm.
ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
// Dead-code lint is silenced; this chunk only clears the set in `reset`.
// NOTE(review): confirm whether any submodule still reads it.
#[allow(dead_code)]
switch_values: FastHashSet<crate::SwitchValue>,
/// Scratch list of expression handles; used by submodules not shown here.
valid_expression_list: Vec<Handle<crate::Expression>>,
/// Scratch set of expression handles; used by submodules not shown here.
valid_expression_set: HandleSet<crate::Expression>,
/// Override IDs seen so far, used to report `OverrideError::DuplicateID`.
override_ids: FastHashSet<u16>,
/// Whether override declarations are accepted; set by `validate` (true) and
/// `validate_no_overrides` (false).
allow_overrides: bool,
/// Scratch set of expression handles; used by submodules not shown here.
needs_visit: HandleSet<crate::Expression>,
}
/// Reasons a module-level constant can fail validation (reported by
/// `Validator::validate_constant`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
/// The initializer is not a const-expression.
#[error("Initializer must be a const-expression")]
InitializerExprType,
/// The initializer's type is not equivalent to the declared type.
#[error("The type doesn't match the constant")]
InvalidType,
/// The declared type lacks the `CONSTRUCTIBLE` type flag.
#[error("The type is not constructible")]
NonConstructibleType,
}
/// Reasons a pipeline-overridable constant can fail validation (reported by
/// `Validator::validate_override`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
/// The override has neither a name nor a numeric ID.
#[error("Override name and ID are missing")]
MissingNameAndID,
/// Another override in the module already uses this ID.
#[error("Override ID must be unique")]
DuplicateID,
// NOTE(review): not constructed by `validate_override` in this chunk.
#[error("Initializer must be a const-expression or override-expression")]
InitializerExprType,
/// The initializer's type is not equivalent to the declared type.
#[error("The type doesn't match the override")]
InvalidType,
/// The declared type lacks the `CONSTRUCTIBLE` type flag.
#[error("The type is not constructible")]
NonConstructibleType,
/// The declared type is not one of bool, i32, u32, f32 or f64.
#[error("The type is not a scalar")]
TypeNotScalar,
/// Overrides were forbidden (`Validator::validate_no_overrides`).
#[error("Override declarations are not allowed")]
NotAllowed,
}
/// Top-level error returned (wrapped in `WithSpan`) by module validation.
///
/// Most variants identify the offending module item and carry the more
/// specific error as their `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
/// A handle in the module is invalid (checked first, before anything else).
#[error(transparent)]
InvalidHandle(#[from] InvalidHandleError),
/// Computing type layouts failed.
#[error(transparent)]
Layouter(#[from] LayoutError),
/// A type failed validation.
#[error("Type {handle:?} '{name}' is invalid")]
Type {
handle: Handle<crate::Type>,
name: String,
source: TypeError,
},
/// A global (module-scope) expression failed validation.
#[error("Constant expression {handle:?} is invalid")]
ConstExpression {
handle: Handle<crate::Expression>,
source: ConstExpressionError,
},
// NOTE(review): not constructed in this chunk.
#[error("Array size expression {handle:?} is not strictly positive")]
ArraySizeError { handle: Handle<crate::Expression> },
/// A module-level constant failed validation.
#[error("Constant {handle:?} '{name}' is invalid")]
Constant {
handle: Handle<crate::Constant>,
name: String,
source: ConstantError,
},
/// A pipeline-overridable constant failed validation.
#[error("Override {handle:?} '{name}' is invalid")]
Override {
handle: Handle<crate::Override>,
name: String,
source: OverrideError,
},
/// A global variable failed validation.
#[error("Global variable {handle:?} '{name}' is invalid")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
name: String,
source: GlobalVariableError,
},
/// A function failed validation.
#[error("Function {handle:?} '{name}' is invalid")]
Function {
handle: Handle<crate::Function>,
name: String,
source: FunctionError,
},
/// An entry point failed validation, or duplicates another one's
/// (stage, name) pair.
#[error("Entry point {name} at {stage:?} is invalid")]
EntryPoint {
stage: crate::ShaderStage,
name: String,
source: EntryPointError,
},
// NOTE(review): not constructed in this chunk.
#[error("Module is corrupted")]
Corrupted,
}
impl crate::TypeInner {
const fn is_sized(&self) -> bool {
match *self {
Self::Scalar { .. }
| Self::Vector { .. }
| Self::Matrix { .. }
| Self::Array {
size: crate::ArraySize::Constant(_),
..
}
| Self::Atomic { .. }
| Self::Pointer { .. }
| Self::ValuePointer { .. }
| Self::Struct { .. } => true,
Self::Array { .. }
| Self::Image { .. }
| Self::Sampler { .. }
| Self::AccelerationStructure
| Self::RayQuery
| Self::BindingArray { .. } => false,
}
}
const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
match *self {
Self::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
}) => Some(crate::ImageDimension::D1),
Self::Vector {
size: crate::VectorSize::Bi,
scalar:
crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
} => Some(crate::ImageDimension::D2),
Self::Vector {
size: crate::VectorSize::Tri,
scalar:
crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
} => Some(crate::ImageDimension::D3),
_ => None,
}
}
}
impl Validator {
    /// Create a validator that performs the checks selected by `flags` and
    /// permits the features listed in `capabilities`.
    ///
    /// Subgroup support is derived from `capabilities`:
    ///
    /// - [`Capabilities::SUBGROUP`] enables every [`SubgroupOperationSet`]
    ///   category and the fragment and compute stages.
    /// - [`Capabilities::SUBGROUP_VERTEX_STAGE`] enables the vertex stage
    ///   (independently of `SUBGROUP`).
    ///
    /// Use [`Self::subgroup_stages`] and [`Self::subgroup_operations`] to
    /// override the derived values.
    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
        let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
            use SubgroupOperationSet as S;
            S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
        } else {
            SubgroupOperationSet::empty()
        };
        let subgroup_stages = {
            let mut stages = ShaderStages::empty();
            if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
                stages |= ShaderStages::VERTEX;
            }
            if capabilities.contains(Capabilities::SUBGROUP) {
                stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
            }
            stages
        };

        Validator {
            flags,
            capabilities,
            subgroup_stages,
            subgroup_operations,
            types: Vec::new(),
            layouter: Layouter::default(),
            location_mask: BitSet::new(),
            ep_resource_bindings: FastHashSet::default(),
            switch_values: FastHashSet::default(),
            valid_expression_list: Vec::new(),
            valid_expression_set: HandleSet::new(),
            override_ids: FastHashSet::default(),
            allow_overrides: true,
            needs_visit: HandleSet::new(),
        }
    }

    /// Replace the set of shader stages in which subgroup operations are
    /// allowed.
    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
        self.subgroup_stages = stages;
        self
    }

    /// Replace the set of subgroup operation categories that are allowed.
    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
        self.subgroup_operations = operations;
        self
    }

    /// Clear all per-module state so the validator can be reused.
    ///
    /// This runs automatically at the start of every validation. The
    /// configuration (`flags`, `capabilities`, subgroup stages/operations) is
    /// left untouched.
    pub fn reset(&mut self) {
        self.types.clear();
        self.layouter.clear();
        self.location_mask.clear();
        self.ep_resource_bindings.clear();
        self.switch_values.clear();
        self.valid_expression_list.clear();
        self.valid_expression_set.clear();
        self.override_ids.clear();
        // `needs_visit` is per-module scratch state like the sets above;
        // clear it too so nothing carries over from a previous (possibly
        // failed) validation.
        self.needs_visit.clear();
    }

    /// Validate one module-level constant.
    ///
    /// # Errors
    ///
    /// - [`ConstantError::NonConstructibleType`] if the declared type lacks
    ///   [`TypeFlags::CONSTRUCTIBLE`].
    /// - [`ConstantError::InitializerExprType`] if the initializer is not a
    ///   const-expression.
    /// - [`ConstantError::InvalidType`] if the initializer's type is not
    ///   equivalent to the declared type.
    fn validate_constant(
        &self,
        handle: Handle<crate::Constant>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
        global_expr_kind: &ExpressionKindTracker,
    ) -> Result<(), ConstantError> {
        let con = &gctx.constants[handle];

        let type_info = &self.types[con.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(ConstantError::NonConstructibleType);
        }

        if !global_expr_kind.is_const(con.init) {
            return Err(ConstantError::InitializerExprType);
        }

        let decl_ty = &gctx.types[con.ty].inner;
        let init_ty = mod_info[con.init].inner_with(gctx.types);
        if !decl_ty.equivalent(init_ty, gctx.types) {
            return Err(ConstantError::InvalidType);
        }

        Ok(())
    }

    /// Validate one pipeline-overridable constant.
    ///
    /// As a side effect, the override's ID (if any) is recorded in
    /// `override_ids`.
    ///
    /// # Errors
    ///
    /// - [`OverrideError::NotAllowed`] when overrides are forbidden.
    /// - [`OverrideError::MissingNameAndID`] when it has neither.
    /// - [`OverrideError::DuplicateID`] when its ID was already seen.
    /// - [`OverrideError::NonConstructibleType`] when the type lacks
    ///   [`TypeFlags::CONSTRUCTIBLE`].
    /// - [`OverrideError::TypeNotScalar`] unless the type is bool, i32, u32,
    ///   f32 or f64.
    /// - [`OverrideError::InvalidType`] when the initializer's type is not
    ///   equivalent to the declared type.
    fn validate_override(
        &mut self,
        handle: Handle<crate::Override>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), OverrideError> {
        if !self.allow_overrides {
            return Err(OverrideError::NotAllowed);
        }

        let o = &gctx.overrides[handle];

        if o.name.is_none() && o.id.is_none() {
            return Err(OverrideError::MissingNameAndID);
        }

        if let Some(id) = o.id {
            if !self.override_ids.insert(id) {
                return Err(OverrideError::DuplicateID);
            }
        }

        let type_info = &self.types[o.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(OverrideError::NonConstructibleType);
        }

        let decl_ty = &gctx.types[o.ty].inner;
        match decl_ty {
            &crate::TypeInner::Scalar(
                crate::Scalar::BOOL
                | crate::Scalar::I32
                | crate::Scalar::U32
                | crate::Scalar::F32
                | crate::Scalar::F64,
            ) => {}
            _ => return Err(OverrideError::TypeNotScalar),
        }

        if let Some(init) = o.init {
            let init_ty = mod_info[init].inner_with(gctx.types);
            if !decl_ty.equivalent(init_ty, gctx.types) {
                return Err(OverrideError::InvalidType);
            }
        }

        Ok(())
    }

    /// Validate `module`, allowing override declarations.
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] found, with span information.
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.allow_overrides = true;
        self.validate_impl(module)
    }

    /// Validate `module`, requiring that all overrides have been resolved.
    ///
    /// Arrays whose size is still `ArraySize::Pending` are always rejected.
    /// Override declarations are rejected with [`OverrideError::NotAllowed`]
    /// only when [`ValidationFlags::CONSTANTS`] is enabled, because overrides
    /// are only visited during that pass.
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] found, with span information.
    pub fn validate_no_overrides(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.allow_overrides = false;
        self.validate_impl(module)
    }

    /// Shared implementation of [`Self::validate`] and
    /// [`Self::validate_no_overrides`]. It runs the following passes in
    /// order and stops at the first error:
    ///
    /// 1. handles;
    /// 2. layouts;
    /// 3. types;
    /// 4. global expression types;
    /// 5. const-expressions, constants and overrides (only with the
    ///    `CONSTANTS` flag);
    /// 6. global variables;
    /// 7. functions;
    /// 8. entry points.
    fn validate_impl(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.reset();
        self.reset_types(module.types.len());

        Self::validate_module_handles(module).map_err(|e| e.with_span())?;

        self.layouter.update(module.to_ctx()).map_err(|e| {
            let handle = e.ty;
            ValidationError::from(e).with_span_handle(handle, &module.types)
        })?;

        // Filler for `const_expression_types` until each global expression is
        // processed below. A zero-width bool is used as a stand-in value.
        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
            kind: crate::ScalarKind::Bool,
            width: 0,
        }));

        let mut mod_info = ModuleInfo {
            type_flags: Vec::with_capacity(module.types.len()),
            functions: Vec::with_capacity(module.functions.len()),
            entry_points: Vec::with_capacity(module.entry_points.len()),
            const_expression_types: vec![placeholder; module.global_expressions.len()]
                .into_boxed_slice(),
        };

        for (handle, ty) in module.types.iter() {
            let ty_info = self
                .validate_type(handle, module.to_ctx())
                .map_err(|source| {
                    ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(handle, &module.types)
                })?;
            // Without overrides, every array size must already be resolved.
            if !self.allow_overrides {
                if let crate::TypeInner::Array {
                    size: crate::ArraySize::Pending(_),
                    ..
                } = ty.inner
                {
                    return Err((ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source: TypeError::UnresolvedOverride(handle),
                    })
                    .with_span_handle(handle, &module.types));
                }
            }
            mod_info.type_flags.push(ty_info.flags);
            self.types[handle.index()] = ty_info;
        }

        {
            // Global expressions have no local variables, so resolve them
            // against an empty local arena.
            let t = crate::Arena::new();
            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
            for (handle, _) in module.global_expressions.iter() {
                mod_info
                    .process_const_expression(handle, &resolve_context, module.to_ctx())
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.global_expressions)
                    })?
            }
        }

        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);

        if self.flags.contains(ValidationFlags::CONSTANTS) {
            for (handle, _) in module.global_expressions.iter() {
                self.validate_const_expression(
                    handle,
                    module.to_ctx(),
                    &mod_info,
                    &global_expr_kind,
                )
                .map_err(|source| {
                    ValidationError::ConstExpression { handle, source }
                        .with_span_handle(handle, &module.global_expressions)
                })?
            }

            for (handle, constant) in module.constants.iter() {
                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
                    .map_err(|source| {
                        ValidationError::Constant {
                            handle,
                            name: constant.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.constants)
                    })?
            }

            for (handle, override_) in module.overrides.iter() {
                self.validate_override(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::Override {
                            handle,
                            name: override_.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.overrides)
                    })?
            }
        }

        for (var_handle, var) in module.global_variables.iter() {
            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
                .map_err(|source| {
                    ValidationError::GlobalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var_handle, &module.global_variables)
                })?;
        }

        for (handle, fun) in module.functions.iter() {
            match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
                Ok(info) => mod_info.functions.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::Function {
                            handle,
                            name: fun.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.functions)
                    }))
                }
            }
        }

        let mut ep_map = FastHashSet::default();
        for ep in module.entry_points.iter() {
            // An entry point is identified by its (stage, name) pair, which
            // must be unique within the module.
            if !ep_map.insert((ep.stage, &ep.name)) {
                return Err(ValidationError::EntryPoint {
                    stage: ep.stage,
                    name: ep.name.clone(),
                    source: EntryPointError::Conflict,
                }
                .with_span());
            }

            match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
                Ok(info) => mod_info.entry_points.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::EntryPoint {
                            stage: ep.stage,
                            name: ep.name.clone(),
                            source,
                        }
                        .with_span()
                    }));
                }
            }
        }

        Ok(mod_info)
    }
}
/// Check that `members` describe the result struct of an atomic
/// compare-exchange.
///
/// The struct must have exactly two members, in this order:
/// - `old_value`, whose type satisfies `scalar_predicate`;
/// - `exchanged`, of type `bool`.
///
/// The checks short-circuit in that order, so `scalar_predicate` is invoked
/// only when there are two members and the first is named `old_value`.
fn validate_atomic_compare_exchange_struct(
    types: &crate::UniqueArena<crate::Type>,
    members: &[crate::StructMember],
    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
    if let [old_value, exchanged] = members {
        old_value.name.as_deref() == Some("old_value")
            && scalar_predicate(&types[old_value.ty].inner)
            && exchanged.name.as_deref() == Some("exchanged")
            && types[exchanged.ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
    } else {
        false
    }
}