1use alloc::{format, string::String};
2
3use super::{
4 analyzer::{UniformityDisruptor, UniformityRequirements},
5 ExpressionError, FunctionInfo, ModuleInfo,
6};
7use crate::arena::{Arena, UniqueArena};
8use crate::arena::{Handle, HandleSet};
9use crate::proc::TypeResolution;
10use crate::span::WithSpan;
11use crate::span::{AddSpan as _, MapErrWithSpan as _};
12
/// Reasons a `Call` statement can be rejected by `Validator::validate_call`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    /// The type of the argument expression at `index` could not be resolved;
    /// `source` carries the underlying expression error.
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        source: ExpressionError,
    },
    /// The call's result expression was already present in the set of
    /// in-scope expressions when the call tried to introduce it.
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    /// The result expression was no longer pending a visit, i.e. an earlier
    /// `Call` statement had already claimed it.
    #[error("Result expression {0:?} is populated by multiple `Call` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
    /// The number of supplied arguments differs from the callee's parameter count.
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    /// The resolved type of an argument expression does not compare equal to
    /// the callee's declared parameter type.
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    /// The result expression is not a `CallResult` for this callee, or the
    /// presence of a result expression disagrees with whether the callee
    /// returns a value. Carries the result expression given to the call, if any.
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}
36
/// Reasons an `Atomic` statement can be rejected by `Validator::validate_atomic`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    /// The pointer operand's type is not a pointer, or it does not point to
    /// an atomic type.
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    /// The pointer lives in an address space the operation does not accept
    /// (raised for 32-bit float atomics outside of `Storage`).
    #[error("Address space {0:?} is not supported.")]
    InvalidAddressSpace(crate::AddressSpace),
    /// An operand is not a scalar matching the atomic's scalar type, or the
    /// comparison value of a compare-exchange does not match the value's type.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The atomic function is not accepted for the atomic's scalar type.
    #[error("Operator {0:?} is not supported.")]
    InvalidOperator(crate::AtomicFunction),
    /// The statement's result handle does not refer to an `AtomicResult` expression.
    #[error("Result expression {0:?} is not an `AtomicResult` expression")]
    InvalidResultExpression(Handle<crate::Expression>),
    /// The result expression has its `comparison` flag set, but the statement
    /// is not a compare-exchange.
    #[error("Result expression {0:?} is marked as an `exchange`")]
    ResultExpressionExchange(Handle<crate::Expression>),
    /// The statement is a compare-exchange, but the result expression does
    /// not have its `comparison` flag set.
    #[error("Result expression {0:?} is not marked as an `exchange`")]
    ResultExpressionNotExchange(Handle<crate::Expression>),
    /// The result expression's type does not fit the operation.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// A plain exchange (`Exchange { compare: None }`) was issued without a
    /// result expression.
    #[error("Exchange operations must return a value")]
    MissingReturnValue,
    /// The validator was not configured with the capability this operation needs.
    #[error("Capability {0:?} is required")]
    MissingCapability(super::Capabilities),
    /// The result expression was no longer pending a visit, i.e. an earlier
    /// `Atomic` statement had already claimed it.
    #[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
}
63
/// Reasons a subgroup statement can be rejected during function validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SubgroupError {
    /// An operand's type is not accepted by the subgroup statement.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression is not a `SubgroupOperationResult` whose type
    /// equals the argument's type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// The validator's enabled subgroup operation set lacks the given set.
    #[error("Support for subgroup operation {0:?} is required")]
    UnsupportedOperation(super::SubgroupOperationSet),
    /// The combination of collective operation and subgroup operation is not
    /// one the validator accepts.
    #[error("Unknown operation")]
    UnknownOperation,
    /// The invocation index of a `Broadcast`/`QuadBroadcast` gather is not a
    /// const-expression.
    #[error("Invocation ID must be a const-expression")]
    InvalidInvocationIdExprType(Handle<crate::Expression>),
}
78
/// Reasons a function-local variable declaration can be rejected.
///
/// Wrapped by `FunctionError::LocalVariable`, which adds the variable's
/// handle and name.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    /// The variable's declared type is not allowed for local variables.
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    /// The initializer expression's type differs from the variable's type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    /// The initializer is neither a const-expression nor an override-expression.
    #[error("Initializer is not a const or override expression")]
    NonConstOrOverrideInitializer,
}
89
/// Errors produced while validating a function's signature and body.
///
/// Most variants carry the handle of the offending expression or type so the
/// error can be reported with a source span.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
    /// An expression used by a statement failed to resolve; `source` holds
    /// the underlying expression error.
    #[error("Expression {handle:?} is invalid")]
    Expression {
        handle: Handle<crate::Expression>,
        source: ExpressionError,
    },
    /// An expression was introduced into scope a second time.
    #[error("Expression {0:?} can't be introduced - it's already in scope")]
    ExpressionAlreadyInScope(Handle<crate::Expression>),
    #[error("Local variable {handle:?} '{name}' is invalid")]
    LocalVariable {
        handle: Handle<crate::LocalVariable>,
        name: String,
        source: LocalVariableError,
    },
    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
    InvalidArgumentType { index: usize, name: String },
    #[error("The function's given return type cannot be returned from functions")]
    NonConstructibleReturnType,
    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
    InvalidArgumentPointerSpace {
        index: usize,
        name: String,
        space: crate::AddressSpace,
    },
    /// A `Break` statement appeared where the block context does not grant
    /// the ability to break.
    #[error("The `break` is used outside of a `loop` or `switch` context")]
    BreakOutsideOfLoopOrSwitch,
    /// A `Continue` statement appeared where the block context does not grant
    /// the ability to continue.
    #[error("The `continue` is used outside of a `loop` context")]
    ContinueOutsideOfLoop,
    /// A `Return` statement appeared where the block context does not grant
    /// the ability to return.
    #[error("The `return` is called within a `continuing` block")]
    InvalidReturnSpot,
    /// The returned value's type (or its absence) disagrees with the
    /// function's declared return type.
    #[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
    InvalidReturnType {
        expression: Option<Handle<crate::Expression>>,
        expected_ty: Option<Handle<crate::Type>>,
    },
    /// A condition is not a boolean scalar. Raised for `If` conditions and
    /// also for a loop's `break_if` condition.
    #[error("The `if` condition {0:?} is not a boolean scalar")]
    InvalidIfType(Handle<crate::Expression>),
    /// The `Switch` selector's scalar kind is neither signed nor unsigned integer.
    #[error("The `switch` value {0:?} is not an integer scalar")]
    InvalidSwitchType(Handle<crate::Expression>),
    /// The same non-default case value appears more than once in a `Switch`.
    #[error("Multiple `switch` cases for {0:?} are present")]
    ConflictingSwitchCase(crate::SwitchValue),
    /// A case value's signedness differs from the selector's.
    #[error("The `switch` contains cases with conflicting types")]
    ConflictingCaseType,
    #[error("The `switch` is missing a `default` case")]
    MissingDefaultCase,
    #[error("Multiple `default` cases are present")]
    MultipleDefaultCases,
    /// The final case of a `Switch` has `fall_through` set.
    #[error("The last `switch` case contains a `fallthrough`")]
    LastCaseFallTrough,
    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
    InvalidStorePointer(Handle<crate::Expression>),
    #[error("Image store texture parameter type mismatch")]
    InvalidStoreTexture {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
    },
    #[error("Image store value parameter type mismatch")]
    InvalidStoreValue {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
        expected_ty: crate::TypeInner,
    },
    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
    InvalidStoreTypes {
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
    },
    #[error("Image store parameters are invalid")]
    InvalidImageStore(#[source] ExpressionError),
    #[error("Image atomic parameters are invalid")]
    InvalidImageAtomic(#[source] ExpressionError),
    #[error("Image atomic function is invalid")]
    InvalidImageAtomicFunction(crate::AtomicFunction),
    #[error("Image atomic value is invalid")]
    InvalidImageAtomicValue(Handle<crate::Expression>),
    /// A `Call` statement failed validation; `error` says why.
    #[error("Call to {function:?} is invalid")]
    InvalidCall {
        function: Handle<crate::Function>,
        #[source]
        error: CallError,
    },
    /// An `Atomic` statement failed validation. Convertible from `AtomicError`.
    #[error("Atomic operation is invalid")]
    InvalidAtomic(#[from] AtomicError),
    #[error("Ray Query {0:?} is not a local variable")]
    InvalidRayQueryExpression(Handle<crate::Expression>),
    #[error("Acceleration structure {0:?} is not a matching expression")]
    InvalidAccelerationStructure(Handle<crate::Expression>),
    #[error(
        "Acceleration structure {0:?} is missing flag vertex_return while Ray Query {1:?} does"
    )]
    MissingAccelerationStructureVertexReturn(Handle<crate::Expression>, Handle<crate::Expression>),
    #[error("Ray Query {0:?} is missing flag vertex_return")]
    MissingRayQueryVertexReturn(Handle<crate::Expression>),
    #[error("Ray descriptor {0:?} is not a matching expression")]
    InvalidRayDescriptor(Handle<crate::Expression>),
    #[error("Ray Query {0:?} does not have a matching type")]
    InvalidRayQueryType(Handle<crate::Type>),
    #[error("Hit distance {0:?} must be an f32")]
    InvalidHitDistanceType(Handle<crate::Expression>),
    /// The validator was not configured with a capability that a statement needs.
    #[error("Shader requires capability {0:?}")]
    MissingCapability(super::Capabilities),
    #[error(
        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
    )]
    NonUniformControlFlow(
        UniformityRequirements,
        Handle<crate::Expression>,
        UniformityDisruptor,
    ),
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
    PipelineInputRegularFunction { name: String },
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
    PipelineOutputRegularFunction,
    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
    /// A subgroup statement failed validation. Convertible from `SubgroupError`.
    #[error("Subgroup operation is invalid")]
    InvalidSubgroup(#[from] SubgroupError),
    #[error("Invalid target type for a cooperative store")]
    InvalidCooperativeStoreTarget(Handle<crate::Expression>),
    #[error("Cooperative load/store data pointer has invalid type")]
    InvalidCooperativeDataPointer(Handle<crate::Expression>),
    /// An `Emit` range covers a "result" expression (such as `CallResult` or
    /// `AtomicResult`); those are rejected when found inside an `Emit`.
    #[error("Emit statement should not cover \"result\" expressions like {0:?}")]
    EmitResult(Handle<crate::Expression>),
    #[error("Expression not visited by the appropriate statement")]
    UnvisitedExpression(Handle<crate::Expression>),
    #[error("Expression {0:?} in mesh shader intrinsic call should be `u32` (is the expression a signed integer?)")]
    InvalidMeshFunctionCall(Handle<crate::Expression>),
    #[error("Mesh output types differ from {0:?} to {1:?}")]
    ConflictingMeshOutputTypes(Handle<crate::Expression>, Handle<crate::Expression>),
    #[error("Task payload variables differ from {0:?} to {1:?}")]
    ConflictingTaskPayloadVariables(Handle<crate::Expression>, Handle<crate::Expression>),
    #[error("Mesh shader output at {0:?} is not a user-defined struct")]
    InvalidMeshShaderOutputType(Handle<crate::Expression>),
}
232
bitflags::bitflags! {
    /// Control-flow statements permitted in the block currently being validated.
    ///
    /// A function body starts with only `RETURN`; nested `loop` and `switch`
    /// contexts derive their own sets from the enclosing one.
    #[repr(transparent)]
    #[derive(Clone, Copy)]
    struct ControlFlowAbility: u8 {
        /// A `Return` statement is allowed. Cleared inside a loop's
        /// `continuing` block, which is validated with no abilities at all.
        const RETURN = 0x1;
        /// A `Break` statement is allowed. Granted inside loop bodies and
        /// `switch` cases.
        const BREAK = 0x2;
        /// A `Continue` statement is allowed. Granted inside loop bodies and
        /// passed through `switch` cases when the enclosing context has it.
        const CONTINUE = 0x4;
    }
}
245
/// Summary produced by validating a block of statements.
struct BlockInfo {
    /// Shader stages in which the block may be used. Block validation starts
    /// from all stages and intersects with the stages each statement allows.
    stages: super::ShaderStages,
}
249
/// Read-only state shared by the statement validators while walking one
/// function's body.
struct BlockContext<'a> {
    /// Control-flow statements allowed in the block being validated.
    abilities: ControlFlowAbility,
    /// Analysis results for the function being validated; indexed by
    /// expression handle to obtain resolved types.
    info: &'a FunctionInfo,
    /// The function's expression arena.
    expressions: &'a Arena<crate::Expression>,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// The function's local variables.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// The module's global variables.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// The module's functions, used to look up callees.
    functions: &'a Arena<crate::Function>,
    /// The module's special types.
    special_types: &'a crate::SpecialTypes,
    /// Analysis results of other functions, indexed by function handle index
    /// when validating calls.
    prev_infos: &'a [FunctionInfo],
    /// Declared return type of the function, or `None` if it returns nothing.
    return_type: Option<Handle<crate::Type>>,
    /// Tracker queried to decide whether an expression is a const-expression.
    local_expr_kind: &'a crate::proc::ExpressionKindTracker,
}
263
264impl<'a> BlockContext<'a> {
265 fn new(
266 fun: &'a crate::Function,
267 module: &'a crate::Module,
268 info: &'a FunctionInfo,
269 prev_infos: &'a [FunctionInfo],
270 local_expr_kind: &'a crate::proc::ExpressionKindTracker,
271 ) -> Self {
272 Self {
273 abilities: ControlFlowAbility::RETURN,
274 info,
275 expressions: &fun.expressions,
276 types: &module.types,
277 local_vars: &fun.local_variables,
278 global_vars: &module.global_variables,
279 functions: &module.functions,
280 special_types: &module.special_types,
281 prev_infos,
282 return_type: fun.result.as_ref().map(|fr| fr.ty),
283 local_expr_kind,
284 }
285 }
286
287 const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
288 BlockContext { abilities, ..*self }
289 }
290
291 fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
292 &self.expressions[handle]
293 }
294
295 fn resolve_type_impl(
296 &self,
297 handle: Handle<crate::Expression>,
298 valid_expressions: &HandleSet<crate::Expression>,
299 ) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
300 if !valid_expressions.contains(handle) {
301 Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
302 } else {
303 Ok(&self.info[handle].ty)
304 }
305 }
306
307 fn resolve_type(
308 &self,
309 handle: Handle<crate::Expression>,
310 valid_expressions: &HandleSet<crate::Expression>,
311 ) -> Result<&TypeResolution, WithSpan<FunctionError>> {
312 self.resolve_type_impl(handle, valid_expressions)
313 .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
314 }
315
316 fn resolve_type_inner(
317 &self,
318 handle: Handle<crate::Expression>,
319 valid_expressions: &HandleSet<crate::Expression>,
320 ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
321 self.resolve_type(handle, valid_expressions)
322 .map(|tr| tr.inner_with(self.types))
323 }
324
325 fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
326 self.info[handle].ty.inner_with(self.types)
327 }
328
329 fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
330 crate::proc::compare_types(lhs, rhs, self.types)
331 }
332}
333
334impl super::Validator {
335 fn validate_call(
336 &mut self,
337 function: Handle<crate::Function>,
338 arguments: &[Handle<crate::Expression>],
339 result: Option<Handle<crate::Expression>>,
340 context: &BlockContext,
341 ) -> Result<super::ShaderStages, WithSpan<CallError>> {
342 let fun = &context.functions[function];
343 if fun.arguments.len() != arguments.len() {
344 return Err(CallError::ArgumentCount {
345 required: fun.arguments.len(),
346 seen: arguments.len(),
347 }
348 .with_span());
349 }
350 for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
351 let ty = context
352 .resolve_type_impl(expr, &self.valid_expression_set)
353 .map_err_inner(|source| {
354 CallError::Argument { index, source }
355 .with_span_handle(expr, context.expressions)
356 })?;
357 if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
358 return Err(CallError::ArgumentType {
359 index,
360 required: arg.ty,
361 seen_expression: expr,
362 }
363 .with_span_handle(expr, context.expressions));
364 }
365 }
366
367 if let Some(expr) = result {
368 if self.valid_expression_set.insert(expr) {
369 self.valid_expression_list.push(expr);
370 } else {
371 return Err(CallError::ResultAlreadyInScope(expr)
372 .with_span_handle(expr, context.expressions));
373 }
374 match context.expressions[expr] {
375 crate::Expression::CallResult(callee)
376 if fun.result.is_some() && callee == function =>
377 {
378 if !self.needs_visit.remove(expr) {
379 return Err(CallError::ResultAlreadyPopulated(expr)
380 .with_span_handle(expr, context.expressions));
381 }
382 }
383 _ => {
384 return Err(CallError::ExpressionMismatch(result)
385 .with_span_handle(expr, context.expressions))
386 }
387 }
388 } else if fun.result.is_some() {
389 return Err(CallError::ExpressionMismatch(result).with_span());
390 }
391
392 let callee_info = &context.prev_infos[function.index()];
393 Ok(callee_info.available_stages)
394 }
395
396 fn emit_expression(
397 &mut self,
398 handle: Handle<crate::Expression>,
399 context: &BlockContext,
400 ) -> Result<(), WithSpan<FunctionError>> {
401 if self.valid_expression_set.insert(handle) {
402 self.valid_expression_list.push(handle);
403 Ok(())
404 } else {
405 Err(FunctionError::ExpressionAlreadyInScope(handle)
406 .with_span_handle(handle, context.expressions))
407 }
408 }
409
    /// Validates an `Atomic` statement.
    ///
    /// Checks, in order: that `pointer` is a pointer to an atomic, that
    /// `value` is a scalar of the atomic's scalar type, that the validator's
    /// capabilities allow the operation for that scalar type, and that the
    /// result expression (if any) is appropriate for `fun`.
    ///
    /// On success a provided result expression is brought into scope.
    /// `span` is only used to locate the "missing return value" error, which
    /// has no expression to point at.
    fn validate_atomic(
        &mut self,
        pointer: Handle<crate::Expression>,
        fun: &crate::AtomicFunction,
        value: Handle<crate::Expression>,
        result: Option<Handle<crate::Expression>>,
        span: crate::Span,
        context: &BlockContext,
    ) -> Result<(), WithSpan<FunctionError>> {
        // The `pointer` operand must be a pointer to an atomic value.
        let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
        let crate::TypeInner::Pointer {
            base: pointer_base,
            space: pointer_space,
        } = *pointer_inner
        else {
            log::error!("Atomic operation on type {:?}", *pointer_inner);
            return Err(AtomicError::InvalidPointer(pointer)
                .with_span_handle(pointer, context.expressions)
                .into_other());
        };
        let crate::TypeInner::Atomic(pointer_scalar) = context.types[pointer_base].inner else {
            log::error!(
                "Atomic pointer to type {:?}",
                context.types[pointer_base].inner
            );
            return Err(AtomicError::InvalidPointer(pointer)
                .with_span_handle(pointer, context.expressions)
                .into_other());
        };

        // The `value` operand must be a scalar of the same type as the atomic.
        let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
        let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
            log::error!("Atomic operand type {:?}", *value_inner);
            return Err(AtomicError::InvalidOperand(value)
                .with_span_handle(value, context.expressions)
                .into_other());
        };
        if pointer_scalar != value_scalar {
            log::error!("Atomic operand type {:?}", *value_inner);
            return Err(AtomicError::InvalidOperand(value)
                .with_span_handle(value, context.expressions)
                .into_other());
        }

        // Capability checks that depend on the atomic's scalar type.
        match pointer_scalar {
            crate::Scalar::I64 | crate::Scalar::U64 => {
                // `SHADER_INT64_ATOMIC_ALL_OPS` permits every 64-bit atomic
                // operation, so nothing further needs checking when it is set.
                if self
                    .capabilities
                    .contains(super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS)
                {
                } else {
                    // Without it, only `Min`/`Max` on `Storage` pointers with
                    // no result are acceptable, and only when
                    // `SHADER_INT64_ATOMIC_MIN_MAX` is present.
                    if matches!(
                        *fun,
                        crate::AtomicFunction::Min | crate::AtomicFunction::Max
                    ) && matches!(pointer_space, crate::AddressSpace::Storage { .. })
                        && result.is_none()
                    {
                        if !self
                            .capabilities
                            .contains(super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX)
                        {
                            log::error!("Int64 min-max atomic operations are not supported");
                            return Err(AtomicError::MissingCapability(
                                super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
                            )
                            .with_span_handle(value, context.expressions)
                            .into_other());
                        }
                    } else {
                        // Anything else needs the full 64-bit atomic capability.
                        log::error!("Int64 atomic operations are not supported");
                        return Err(AtomicError::MissingCapability(
                            super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
                        )
                        .with_span_handle(value, context.expressions)
                        .into_other());
                    }
                }
            }
            crate::Scalar::F32 => {
                // 32-bit float atomics need `SHADER_FLOAT32_ATOMIC`, and are
                // limited to `Add`, `Subtract` and plain `Exchange` on
                // pointers in the `Storage` address space.
                if !self
                    .capabilities
                    .contains(super::Capabilities::SHADER_FLOAT32_ATOMIC)
                {
                    log::error!("Float32 atomic operations are not supported");
                    return Err(AtomicError::MissingCapability(
                        super::Capabilities::SHADER_FLOAT32_ATOMIC,
                    )
                    .with_span_handle(value, context.expressions)
                    .into_other());
                }
                if !matches!(
                    *fun,
                    crate::AtomicFunction::Add
                        | crate::AtomicFunction::Subtract
                        | crate::AtomicFunction::Exchange { compare: None }
                ) {
                    log::error!("Float32 atomic operation {fun:?} is not supported");
                    return Err(AtomicError::InvalidOperator(*fun)
                        .with_span_handle(value, context.expressions)
                        .into_other());
                }
                if !matches!(pointer_space, crate::AddressSpace::Storage { .. }) {
                    log::error!(
                        "Float32 atomic operations are only supported in the Storage address space"
                    );
                    return Err(AtomicError::InvalidAddressSpace(pointer_space)
                        .with_span_handle(value, context.expressions)
                        .into_other());
                }
            }
            _ => {}
        }

        // The result expression must be appropriate to the operation.
        match result {
            Some(result) => {
                // The `result` handle must refer to an `AtomicResult` expression.
                let crate::Expression::AtomicResult {
                    ty: result_ty,
                    comparison,
                } = context.expressions[result]
                else {
                    return Err(AtomicError::InvalidResultExpression(result)
                        .with_span_handle(result, context.expressions)
                        .into_other());
                };

                // Each `AtomicResult` may be claimed by one `Atomic`
                // statement only; a failed removal means it was already taken.
                if !self.needs_visit.remove(result) {
                    return Err(AtomicError::ResultAlreadyPopulated(result)
                        .with_span_handle(result, context.expressions)
                        .into_other());
                }

                // Compare-exchange and all other operations have different
                // result shapes.
                if let crate::AtomicFunction::Exchange {
                    compare: Some(compare),
                } = *fun
                {
                    // The comparison value must have a type equivalent to
                    // that of `value`.
                    let compare_inner =
                        context.resolve_type_inner(compare, &self.valid_expression_set)?;
                    if !compare_inner.non_struct_equivalent(value_inner, context.types) {
                        log::error!(
                            "Atomic exchange comparison has a different type from the value"
                        );
                        return Err(AtomicError::InvalidOperand(compare)
                            .with_span_handle(compare, context.expressions)
                            .into_other());
                    }

                    // The result type must be a struct accepted by
                    // `validate_atomic_compare_exchange_struct` for the
                    // atomic's scalar type.
                    let crate::TypeInner::Struct { ref members, .. } =
                        context.types[result_ty].inner
                    else {
                        return Err(AtomicError::ResultTypeMismatch(result)
                            .with_span_handle(result, context.expressions)
                            .into_other());
                    };
                    if !super::validate_atomic_compare_exchange_struct(
                        context.types,
                        members,
                        |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar),
                    ) {
                        return Err(AtomicError::ResultTypeMismatch(result)
                            .with_span_handle(result, context.expressions)
                            .into_other());
                    }

                    // The result expression must be flagged as a comparison.
                    if !comparison {
                        return Err(AtomicError::ResultExpressionNotExchange(result)
                            .with_span_handle(result, context.expressions)
                            .into_other());
                    }
                } else {
                    // The result type must be equivalent to the type of `value`.
                    let result_inner = &context.types[result_ty].inner;
                    if !result_inner.non_struct_equivalent(value_inner, context.types) {
                        return Err(AtomicError::ResultTypeMismatch(result)
                            .with_span_handle(result, context.expressions)
                            .into_other());
                    }

                    // The result expression must not be flagged as a comparison.
                    if comparison {
                        return Err(AtomicError::ResultExpressionExchange(result)
                            .with_span_handle(result, context.expressions)
                            .into_other());
                    }
                }
                // All checks passed: the result becomes usable by later statements.
                self.emit_expression(result, context)?;
            }

            None => {
                // A plain exchange without a result is rejected.
                if let crate::AtomicFunction::Exchange { compare: None } = *fun {
                    log::error!("Atomic exchange's value is unused");
                    return Err(AtomicError::MissingReturnValue
                        .with_span_static(span, "atomic exchange operation")
                        .into_other());
                }
            }
        }

        Ok(())
    }
639 fn validate_subgroup_operation(
640 &mut self,
641 op: &crate::SubgroupOperation,
642 collective_op: &crate::CollectiveOperation,
643 argument: Handle<crate::Expression>,
644 result: Handle<crate::Expression>,
645 context: &BlockContext,
646 ) -> Result<(), WithSpan<FunctionError>> {
647 let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
648
649 let (is_scalar, scalar) = match *argument_inner {
650 crate::TypeInner::Scalar(scalar) => (true, scalar),
651 crate::TypeInner::Vector { scalar, .. } => (false, scalar),
652 _ => {
653 log::error!("Subgroup operand type {argument_inner:?}");
654 return Err(SubgroupError::InvalidOperand(argument)
655 .with_span_handle(argument, context.expressions)
656 .into_other());
657 }
658 };
659
660 use crate::ScalarKind as sk;
661 use crate::SubgroupOperation as sg;
662 match (scalar.kind, *op) {
663 (sk::Bool, sg::All | sg::Any) if is_scalar => {}
664 (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
665 (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
666
667 (_, _) => {
668 log::error!("Subgroup operand type {argument_inner:?}");
669 return Err(SubgroupError::InvalidOperand(argument)
670 .with_span_handle(argument, context.expressions)
671 .into_other());
672 }
673 };
674
675 use crate::CollectiveOperation as co;
676 match (*collective_op, *op) {
677 (
678 co::Reduce,
679 sg::All
680 | sg::Any
681 | sg::Add
682 | sg::Mul
683 | sg::Min
684 | sg::Max
685 | sg::And
686 | sg::Or
687 | sg::Xor,
688 ) => {}
689 (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
690
691 (_, _) => {
692 return Err(SubgroupError::UnknownOperation.with_span().into_other());
693 }
694 };
695
696 self.emit_expression(result, context)?;
697 match context.expressions[result] {
698 crate::Expression::SubgroupOperationResult { ty }
699 if { &context.types[ty].inner == argument_inner } => {}
700 _ => {
701 return Err(SubgroupError::ResultTypeMismatch(result)
702 .with_span_handle(result, context.expressions)
703 .into_other())
704 }
705 }
706 Ok(())
707 }
708 fn validate_subgroup_gather(
709 &mut self,
710 mode: &crate::GatherMode,
711 argument: Handle<crate::Expression>,
712 result: Handle<crate::Expression>,
713 context: &BlockContext,
714 ) -> Result<(), WithSpan<FunctionError>> {
715 match *mode {
716 crate::GatherMode::BroadcastFirst => {}
717 crate::GatherMode::Broadcast(index)
718 | crate::GatherMode::Shuffle(index)
719 | crate::GatherMode::ShuffleDown(index)
720 | crate::GatherMode::ShuffleUp(index)
721 | crate::GatherMode::ShuffleXor(index)
722 | crate::GatherMode::QuadBroadcast(index) => {
723 let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
724 match *index_ty {
725 crate::TypeInner::Scalar(crate::Scalar::U32) => {}
726 _ => {
727 log::error!(
728 "Subgroup gather index type {index_ty:?}, expected unsigned int"
729 );
730 return Err(SubgroupError::InvalidOperand(argument)
731 .with_span_handle(index, context.expressions)
732 .into_other());
733 }
734 }
735 }
736 crate::GatherMode::QuadSwap(_) => {}
737 }
738 match *mode {
739 crate::GatherMode::Broadcast(index) | crate::GatherMode::QuadBroadcast(index) => {
740 if !context.local_expr_kind.is_const(index) {
741 return Err(SubgroupError::InvalidInvocationIdExprType(index)
742 .with_span_handle(index, context.expressions)
743 .into_other());
744 }
745 }
746 _ => {}
747 }
748 let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
749 if !matches!(*argument_inner,
750 crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
751 if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
752 ) {
753 log::error!("Subgroup gather operand type {argument_inner:?}");
754 return Err(SubgroupError::InvalidOperand(argument)
755 .with_span_handle(argument, context.expressions)
756 .into_other());
757 }
758
759 self.emit_expression(result, context)?;
760 match context.expressions[result] {
761 crate::Expression::SubgroupOperationResult { ty }
762 if { &context.types[ty].inner == argument_inner } => {}
763 _ => {
764 return Err(SubgroupError::ResultTypeMismatch(result)
765 .with_span_handle(result, context.expressions)
766 .into_other())
767 }
768 }
769 Ok(())
770 }
771
772 fn validate_block_impl(
773 &mut self,
774 statements: &crate::Block,
775 context: &BlockContext,
776 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
777 use crate::{AddressSpace, Statement as S, TypeInner as Ti};
778 let mut stages = super::ShaderStages::all();
779 for (statement, &span) in statements.span_iter() {
780 match *statement {
781 S::Emit(ref range) => {
782 for handle in range.clone() {
783 use crate::Expression as Ex;
784 match context.expressions[handle] {
785 Ex::Literal(_)
786 | Ex::Constant(_)
787 | Ex::Override(_)
788 | Ex::ZeroValue(_)
789 | Ex::Compose { .. }
790 | Ex::Access { .. }
791 | Ex::AccessIndex { .. }
792 | Ex::Splat { .. }
793 | Ex::Swizzle { .. }
794 | Ex::FunctionArgument(_)
795 | Ex::GlobalVariable(_)
796 | Ex::LocalVariable(_)
797 | Ex::Load { .. }
798 | Ex::ImageSample { .. }
799 | Ex::ImageLoad { .. }
800 | Ex::ImageQuery { .. }
801 | Ex::Unary { .. }
802 | Ex::Binary { .. }
803 | Ex::Select { .. }
804 | Ex::Derivative { .. }
805 | Ex::Relational { .. }
806 | Ex::Math { .. }
807 | Ex::As { .. }
808 | Ex::ArrayLength(_)
809 | Ex::RayQueryGetIntersection { .. }
810 | Ex::RayQueryVertexPositions { .. }
811 | Ex::CooperativeLoad { .. }
812 | Ex::CooperativeMultiplyAdd { .. } => {
813 self.emit_expression(handle, context)?
814 }
815 Ex::CallResult(_)
816 | Ex::AtomicResult { .. }
817 | Ex::WorkGroupUniformLoadResult { .. }
818 | Ex::RayQueryProceedResult
819 | Ex::SubgroupBallotResult
820 | Ex::SubgroupOperationResult { .. } => {
821 return Err(FunctionError::EmitResult(handle)
822 .with_span_handle(handle, context.expressions));
823 }
824 }
825 }
826 }
827 S::Block(ref block) => {
828 let info = self.validate_block(block, context)?;
829 stages &= info.stages;
830 }
831 S::If {
832 condition,
833 ref accept,
834 ref reject,
835 } => {
836 match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
837 Ti::Scalar(crate::Scalar {
838 kind: crate::ScalarKind::Bool,
839 width: _,
840 }) => {}
841 _ => {
842 return Err(FunctionError::InvalidIfType(condition)
843 .with_span_handle(condition, context.expressions))
844 }
845 }
846 stages &= self.validate_block(accept, context)?.stages;
847 stages &= self.validate_block(reject, context)?.stages;
848 }
849 S::Switch {
850 selector,
851 ref cases,
852 } => {
853 let uint = match context
854 .resolve_type_inner(selector, &self.valid_expression_set)?
855 .scalar_kind()
856 {
857 Some(crate::ScalarKind::Uint) => true,
858 Some(crate::ScalarKind::Sint) => false,
859 _ => {
860 return Err(FunctionError::InvalidSwitchType(selector)
861 .with_span_handle(selector, context.expressions))
862 }
863 };
864 self.switch_values.clear();
865 for case in cases {
866 match case.value {
867 crate::SwitchValue::I32(_) if !uint => {}
868 crate::SwitchValue::U32(_) if uint => {}
869 crate::SwitchValue::Default => {}
870 _ => {
871 return Err(FunctionError::ConflictingCaseType.with_span_static(
872 case.body
873 .span_iter()
874 .next()
875 .map_or(Default::default(), |(_, s)| *s),
876 "conflicting switch arm here",
877 ));
878 }
879 };
880 if !self.switch_values.insert(case.value) {
881 return Err(match case.value {
882 crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
883 .with_span_static(
884 case.body
885 .span_iter()
886 .next()
887 .map_or(Default::default(), |(_, s)| *s),
888 "duplicated switch arm here",
889 ),
890 _ => FunctionError::ConflictingSwitchCase(case.value)
891 .with_span_static(
892 case.body
893 .span_iter()
894 .next()
895 .map_or(Default::default(), |(_, s)| *s),
896 "conflicting switch arm here",
897 ),
898 });
899 }
900 }
901 if !self.switch_values.contains(&crate::SwitchValue::Default) {
902 return Err(FunctionError::MissingDefaultCase
903 .with_span_static(span, "missing default case"));
904 }
905 if let Some(case) = cases.last() {
906 if case.fall_through {
907 return Err(FunctionError::LastCaseFallTrough.with_span_static(
908 case.body
909 .span_iter()
910 .next()
911 .map_or(Default::default(), |(_, s)| *s),
912 "bad switch arm here",
913 ));
914 }
915 }
916 let pass_through_abilities = context.abilities
917 & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
918 let sub_context =
919 context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
920 for case in cases {
921 stages &= self.validate_block(&case.body, &sub_context)?.stages;
922 }
923 }
924 S::Loop {
925 ref body,
926 ref continuing,
927 break_if,
928 } => {
929 let base_expression_count = self.valid_expression_list.len();
932 let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
933 stages &= self
934 .validate_block_impl(
935 body,
936 &context.with_abilities(
937 pass_through_abilities
938 | ControlFlowAbility::BREAK
939 | ControlFlowAbility::CONTINUE,
940 ),
941 )?
942 .stages;
943 stages &= self
944 .validate_block_impl(
945 continuing,
946 &context.with_abilities(ControlFlowAbility::empty()),
947 )?
948 .stages;
949
950 if let Some(condition) = break_if {
951 match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
952 Ti::Scalar(crate::Scalar {
953 kind: crate::ScalarKind::Bool,
954 width: _,
955 }) => {}
956 _ => {
957 return Err(FunctionError::InvalidIfType(condition)
958 .with_span_handle(condition, context.expressions))
959 }
960 }
961 }
962
963 for handle in self.valid_expression_list.drain(base_expression_count..) {
964 self.valid_expression_set.remove(handle);
965 }
966 }
967 S::Break => {
968 if !context.abilities.contains(ControlFlowAbility::BREAK) {
969 return Err(FunctionError::BreakOutsideOfLoopOrSwitch
970 .with_span_static(span, "invalid break"));
971 }
972 }
973 S::Continue => {
974 if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
975 return Err(FunctionError::ContinueOutsideOfLoop
976 .with_span_static(span, "invalid continue"));
977 }
978 }
979 S::Return { value } => {
980 if !context.abilities.contains(ControlFlowAbility::RETURN) {
981 return Err(FunctionError::InvalidReturnSpot
982 .with_span_static(span, "invalid return"));
983 }
984 let value_ty = value
985 .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
986 .transpose()?;
987 let okay = match (value_ty, context.return_type) {
990 (None, None) => true,
991 (Some(value_inner), Some(expected_ty)) => {
992 context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
993 }
994 (_, _) => false,
995 };
996
997 if !okay {
998 log::error!(
999 "Returning {:?} where {:?} is expected",
1000 value_ty,
1001 context.return_type,
1002 );
1003 if let Some(handle) = value {
1004 return Err(FunctionError::InvalidReturnType {
1005 expression: value,
1006 expected_ty: context.return_type,
1007 }
1008 .with_span_handle(handle, context.expressions));
1009 } else {
1010 return Err(FunctionError::InvalidReturnType {
1011 expression: value,
1012 expected_ty: context.return_type,
1013 }
1014 .with_span_static(span, "invalid return"));
1015 }
1016 }
1017 }
1018 S::Kill => {
1019 stages &= super::ShaderStages::FRAGMENT;
1020 }
1021 S::ControlBarrier(barrier) | S::MemoryBarrier(barrier) => {
1022 stages &= super::ShaderStages::COMPUTE_LIKE;
1023 if barrier.contains(crate::Barrier::SUB_GROUP) {
1024 if !self.capabilities.contains(
1025 super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
1026 ) {
1027 return Err(FunctionError::MissingCapability(
1028 super::Capabilities::SUBGROUP
1029 | super::Capabilities::SUBGROUP_BARRIER,
1030 )
1031 .with_span_static(span, "missing capability for this operation"));
1032 }
1033 if !self
1034 .subgroup_operations
1035 .contains(super::SubgroupOperationSet::BASIC)
1036 {
1037 return Err(FunctionError::InvalidSubgroup(
1038 SubgroupError::UnsupportedOperation(
1039 super::SubgroupOperationSet::BASIC,
1040 ),
1041 )
1042 .with_span_static(span, "support for this operation is not present"));
1043 }
1044 }
1045 }
1046 S::Store { pointer, value } => {
1047 let mut current = pointer;
1048 loop {
1049 match context.expressions[current] {
1050 crate::Expression::Access { base, .. }
1051 | crate::Expression::AccessIndex { base, .. } => current = base,
1052 crate::Expression::LocalVariable(_)
1053 | crate::Expression::GlobalVariable(_)
1054 | crate::Expression::FunctionArgument(_) => break,
1055 _ => {
1056 return Err(FunctionError::InvalidStorePointer(current)
1057 .with_span_handle(pointer, context.expressions))
1058 }
1059 }
1060 }
1061
1062 let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
1063 let value_ty = value_tr.inner_with(context.types);
1064 match *value_ty {
1065 Ti::Image { .. } | Ti::Sampler { .. } => {
1066 return Err(FunctionError::InvalidStoreTexture {
1067 actual: value,
1068 actual_ty: value_ty.clone(),
1069 }
1070 .with_span_context((
1071 context.expressions.get_span(value),
1072 format!("this value is of type {value_ty:?}"),
1073 ))
1074 .with_span(span, "expects a texture argument"));
1075 }
1076 _ => {}
1077 }
1078
1079 let pointer_ty = context.resolve_pointer_type(pointer);
1080 let pointer_base_tr = pointer_ty.pointer_base_type();
1081 let pointer_base_ty = pointer_base_tr
1082 .as_ref()
1083 .map(|ty| ty.inner_with(context.types));
1084 let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
1085 *value_ty == Ti::Scalar(*scalar)
1087 } else if let Some(tr) = pointer_base_tr {
1088 context.compare_types(value_tr, &tr)
1089 } else {
1090 false
1091 };
1092
1093 if !good {
1094 return Err(FunctionError::InvalidStoreTypes { pointer, value }
1095 .with_span()
1096 .with_handle(pointer, context.expressions)
1097 .with_handle(value, context.expressions));
1098 }
1099
1100 if let Some(space) = pointer_ty.pointer_space() {
1101 if !space.access().contains(crate::StorageAccess::STORE) {
1102 return Err(FunctionError::InvalidStorePointer(pointer)
1103 .with_span_static(
1104 context.expressions.get_span(pointer),
1105 "writing to this location is not permitted",
1106 ));
1107 }
1108 }
1109 }
1110 S::ImageStore {
1111 image,
1112 coordinate,
1113 array_index,
1114 value,
1115 } => {
1116 let global_var;
1119 let image_ty;
1120 match *context.get_expression(image) {
1121 crate::Expression::GlobalVariable(var_handle) => {
1122 global_var = &context.global_vars[var_handle];
1123 image_ty = global_var.ty;
1124 }
1125 crate::Expression::Access { base, .. }
1129 | crate::Expression::AccessIndex { base, .. } => {
1130 let crate::Expression::GlobalVariable(var_handle) =
1131 *context.get_expression(base)
1132 else {
1133 return Err(FunctionError::InvalidImageStore(
1134 ExpressionError::ExpectedGlobalVariable,
1135 )
1136 .with_span_handle(image, context.expressions));
1137 };
1138 global_var = &context.global_vars[var_handle];
1139
1140 let Ti::BindingArray { base, .. } = context.types[global_var.ty].inner
1142 else {
1143 return Err(FunctionError::InvalidImageStore(
1144 ExpressionError::ExpectedBindingArrayType(global_var.ty),
1145 )
1146 .with_span_handle(global_var.ty, context.types));
1147 };
1148
1149 image_ty = base;
1150 }
1151 _ => {
1152 return Err(FunctionError::InvalidImageStore(
1153 ExpressionError::ExpectedGlobalVariable,
1154 )
1155 .with_span_handle(image, context.expressions))
1156 }
1157 };
1158
1159 let Ti::Image {
1161 class,
1162 arrayed,
1163 dim,
1164 } = context.types[image_ty].inner
1165 else {
1166 return Err(FunctionError::InvalidImageStore(
1167 ExpressionError::ExpectedImageType(global_var.ty),
1168 )
1169 .with_span()
1170 .with_handle(global_var.ty, context.types)
1171 .with_handle(image, context.expressions));
1172 };
1173
1174 let crate::ImageClass::Storage { format, .. } = class else {
1176 return Err(FunctionError::InvalidImageStore(
1177 ExpressionError::InvalidImageClass(class),
1178 )
1179 .with_span_handle(image, context.expressions));
1180 };
1181
1182 if context
1184 .resolve_type_inner(coordinate, &self.valid_expression_set)?
1185 .image_storage_coordinates()
1186 .is_none_or(|coord_dim| coord_dim != dim)
1187 {
1188 return Err(FunctionError::InvalidImageStore(
1189 ExpressionError::InvalidImageCoordinateType(dim, coordinate),
1190 )
1191 .with_span_handle(coordinate, context.expressions));
1192 }
1193
1194 if arrayed != array_index.is_some() {
1197 return Err(FunctionError::InvalidImageStore(
1198 ExpressionError::InvalidImageArrayIndex,
1199 )
1200 .with_span_handle(coordinate, context.expressions));
1201 }
1202
1203 if let Some(expr) = array_index {
1205 if !matches!(
1206 *context.resolve_type_inner(expr, &self.valid_expression_set)?,
1207 Ti::Scalar(crate::Scalar {
1208 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1209 width: _,
1210 })
1211 ) {
1212 return Err(FunctionError::InvalidImageStore(
1213 ExpressionError::InvalidImageArrayIndexType(expr),
1214 )
1215 .with_span_handle(expr, context.expressions));
1216 }
1217 }
1218
1219 let value_ty = crate::TypeInner::Vector {
1220 size: crate::VectorSize::Quad,
1221 scalar: format.into(),
1222 };
1223
1224 let actual_value_ty =
1227 context.resolve_type_inner(value, &self.valid_expression_set)?;
1228 if actual_value_ty != &value_ty {
1229 return Err(FunctionError::InvalidStoreValue {
1230 actual: value,
1231 actual_ty: actual_value_ty.clone(),
1232 expected_ty: value_ty.clone(),
1233 }
1234 .with_span_context((
1235 context.expressions.get_span(value),
1236 format!("this value is of type {actual_value_ty:?}"),
1237 ))
1238 .with_span(
1239 span,
1240 format!("expects a value argument of type {value_ty:?}"),
1241 ));
1242 }
1243 }
1244 S::Call {
1245 function,
1246 ref arguments,
1247 result,
1248 } => match self.validate_call(function, arguments, result, context) {
1249 Ok(callee_stages) => stages &= callee_stages,
1250 Err(error) => {
1251 return Err(error.and_then(|error| {
1252 FunctionError::InvalidCall { function, error }
1253 .with_span_static(span, "invalid function call")
1254 }))
1255 }
1256 },
1257 S::Atomic {
1258 pointer,
1259 ref fun,
1260 value,
1261 result,
1262 } => {
1263 self.validate_atomic(pointer, fun, value, result, span, context)?;
1264 }
1265 S::ImageAtomic {
1266 image,
1267 coordinate,
1268 array_index,
1269 fun,
1270 value,
1271 } => {
1272 let var = match *context.get_expression(image) {
1273 crate::Expression::GlobalVariable(var_handle) => {
1274 &context.global_vars[var_handle]
1275 }
1276 crate::Expression::Access { base, .. }
1278 | crate::Expression::AccessIndex { base, .. } => {
1279 match *context.get_expression(base) {
1280 crate::Expression::GlobalVariable(var_handle) => {
1281 &context.global_vars[var_handle]
1282 }
1283 _ => {
1284 return Err(FunctionError::InvalidImageAtomic(
1285 ExpressionError::ExpectedGlobalVariable,
1286 )
1287 .with_span_handle(image, context.expressions))
1288 }
1289 }
1290 }
1291 _ => {
1292 return Err(FunctionError::InvalidImageAtomic(
1293 ExpressionError::ExpectedGlobalVariable,
1294 )
1295 .with_span_handle(image, context.expressions))
1296 }
1297 };
1298
1299 let global_ty = match context.types[var.ty].inner {
1301 Ti::BindingArray { base, .. } => &context.types[base].inner,
1302 ref inner => inner,
1303 };
1304
1305 let value_ty = match *global_ty {
1306 Ti::Image {
1307 class,
1308 arrayed,
1309 dim,
1310 } => {
1311 match context
1312 .resolve_type_inner(coordinate, &self.valid_expression_set)?
1313 .image_storage_coordinates()
1314 {
1315 Some(coord_dim) if coord_dim == dim => {}
1316 _ => {
1317 return Err(FunctionError::InvalidImageAtomic(
1318 ExpressionError::InvalidImageCoordinateType(
1319 dim, coordinate,
1320 ),
1321 )
1322 .with_span_handle(coordinate, context.expressions));
1323 }
1324 };
1325 if arrayed != array_index.is_some() {
1326 return Err(FunctionError::InvalidImageAtomic(
1327 ExpressionError::InvalidImageArrayIndex,
1328 )
1329 .with_span_handle(coordinate, context.expressions));
1330 }
1331 if let Some(expr) = array_index {
1332 match *context
1333 .resolve_type_inner(expr, &self.valid_expression_set)?
1334 {
1335 Ti::Scalar(crate::Scalar {
1336 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1337 width: _,
1338 }) => {}
1339 _ => {
1340 return Err(FunctionError::InvalidImageAtomic(
1341 ExpressionError::InvalidImageArrayIndexType(expr),
1342 )
1343 .with_span_handle(expr, context.expressions));
1344 }
1345 }
1346 }
1347 match class {
1348 crate::ImageClass::Storage { format, access } => {
1349 if !access.contains(crate::StorageAccess::ATOMIC) {
1350 return Err(FunctionError::InvalidImageAtomic(
1351 ExpressionError::InvalidImageStorageAccess(access),
1352 )
1353 .with_span_handle(image, context.expressions));
1354 }
1355 match format {
1356 crate::StorageFormat::R64Uint => {
1357 if !self.capabilities.intersects(
1358 super::Capabilities::TEXTURE_INT64_ATOMIC,
1359 ) {
1360 return Err(FunctionError::MissingCapability(
1361 super::Capabilities::TEXTURE_INT64_ATOMIC,
1362 )
1363 .with_span_static(
1364 span,
1365 "missing capability for this operation",
1366 ));
1367 }
1368 match fun {
1369 crate::AtomicFunction::Min
1370 | crate::AtomicFunction::Max => {}
1371 _ => {
1372 return Err(
1373 FunctionError::InvalidImageAtomicFunction(
1374 fun,
1375 )
1376 .with_span_handle(
1377 image,
1378 context.expressions,
1379 ),
1380 );
1381 }
1382 }
1383 }
1384 crate::StorageFormat::R32Sint
1385 | crate::StorageFormat::R32Uint => {
1386 if !self
1387 .capabilities
1388 .intersects(super::Capabilities::TEXTURE_ATOMIC)
1389 {
1390 return Err(FunctionError::MissingCapability(
1391 super::Capabilities::TEXTURE_ATOMIC,
1392 )
1393 .with_span_static(
1394 span,
1395 "missing capability for this operation",
1396 ));
1397 }
1398 match fun {
1399 crate::AtomicFunction::Add
1400 | crate::AtomicFunction::And
1401 | crate::AtomicFunction::ExclusiveOr
1402 | crate::AtomicFunction::InclusiveOr
1403 | crate::AtomicFunction::Min
1404 | crate::AtomicFunction::Max => {}
1405 _ => {
1406 return Err(
1407 FunctionError::InvalidImageAtomicFunction(
1408 fun,
1409 )
1410 .with_span_handle(
1411 image,
1412 context.expressions,
1413 ),
1414 );
1415 }
1416 }
1417 }
1418 _ => {
1419 return Err(FunctionError::InvalidImageAtomic(
1420 ExpressionError::InvalidImageFormat(format),
1421 )
1422 .with_span_handle(image, context.expressions));
1423 }
1424 }
1425 crate::TypeInner::Scalar(format.into())
1426 }
1427 _ => {
1428 return Err(FunctionError::InvalidImageAtomic(
1429 ExpressionError::InvalidImageClass(class),
1430 )
1431 .with_span_handle(image, context.expressions));
1432 }
1433 }
1434 }
1435 _ => {
1436 return Err(FunctionError::InvalidImageAtomic(
1437 ExpressionError::ExpectedImageType(var.ty),
1438 )
1439 .with_span()
1440 .with_handle(var.ty, context.types)
1441 .with_handle(image, context.expressions))
1442 }
1443 };
1444
1445 if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
1446 return Err(FunctionError::InvalidImageAtomicValue(value)
1447 .with_span_handle(value, context.expressions));
1448 }
1449 }
1450 S::WorkGroupUniformLoad { pointer, result } => {
1451 stages &= super::ShaderStages::COMPUTE_LIKE;
1452 let pointer_inner =
1453 context.resolve_type_inner(pointer, &self.valid_expression_set)?;
1454 match *pointer_inner {
1455 Ti::Pointer {
1456 space: AddressSpace::WorkGroup,
1457 ..
1458 } => {}
1459 Ti::ValuePointer {
1460 space: AddressSpace::WorkGroup,
1461 ..
1462 } => {}
1463 _ => {
1464 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1465 .with_span_static(span, "WorkGroupUniformLoad"))
1466 }
1467 }
1468 self.emit_expression(result, context)?;
1469 let ty = match &context.expressions[result] {
1470 &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1471 _ => {
1472 return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1473 result,
1474 )
1475 .with_span_static(span, "WorkGroupUniformLoad"));
1476 }
1477 };
1478 let expected_pointer_inner = Ti::Pointer {
1479 base: ty,
1480 space: AddressSpace::WorkGroup,
1481 };
1482 let atomic_specialization_ok = match *pointer_inner {
1485 Ti::Pointer {
1486 base: pointer_base,
1487 space: AddressSpace::WorkGroup,
1488 } => match (&context.types[pointer_base].inner, &context.types[ty].inner) {
1489 (&Ti::Atomic(pointer_scalar), &Ti::Scalar(result_scalar)) => {
1490 pointer_scalar == result_scalar
1491 }
1492 _ => false,
1493 },
1494 _ => false,
1495 };
1496 if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types)
1497 && !atomic_specialization_ok
1498 {
1499 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1500 .with_span_static(span, "WorkGroupUniformLoad"));
1501 }
1502 }
1503 S::RayQuery { query, ref fun } => {
1504 let query_var = match *context.get_expression(query) {
1505 crate::Expression::LocalVariable(var) => &context.local_vars[var],
1506 ref other => {
1507 log::error!("Unexpected ray query expression {other:?}");
1508 return Err(FunctionError::InvalidRayQueryExpression(query)
1509 .with_span_static(span, "invalid query expression"));
1510 }
1511 };
1512 let rq_vertex_return = match context.types[query_var.ty].inner {
1513 Ti::RayQuery { vertex_return } => vertex_return,
1514 ref other => {
1515 log::error!("Unexpected ray query type {other:?}");
1516 return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1517 .with_span_static(span, "invalid query type"));
1518 }
1519 };
1520 match *fun {
1521 crate::RayQueryFunction::Initialize {
1522 acceleration_structure,
1523 descriptor,
1524 } => {
1525 match *context.resolve_type_inner(
1526 acceleration_structure,
1527 &self.valid_expression_set,
1528 )? {
1529 Ti::AccelerationStructure { vertex_return } => {
1530 if (!vertex_return) && rq_vertex_return {
1531 return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
1532 }
1533 }
1534 _ => {
1535 return Err(FunctionError::InvalidAccelerationStructure(
1536 acceleration_structure,
1537 )
1538 .with_span_static(span, "invalid acceleration structure"))
1539 }
1540 }
1541 let desc_ty_given = context
1542 .resolve_type_inner(descriptor, &self.valid_expression_set)?;
1543 let desc_ty_expected = context
1544 .special_types
1545 .ray_desc
1546 .map(|handle| &context.types[handle].inner);
1547 if Some(desc_ty_given) != desc_ty_expected {
1548 return Err(FunctionError::InvalidRayDescriptor(descriptor)
1549 .with_span_static(span, "invalid ray descriptor"));
1550 }
1551 }
1552 crate::RayQueryFunction::Proceed { result } => {
1553 self.emit_expression(result, context)?;
1554 }
1555 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1556 match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
1557 Ti::Scalar(crate::Scalar {
1558 kind: crate::ScalarKind::Float,
1559 width: _,
1560 }) => {}
1561 _ => {
1562 return Err(FunctionError::InvalidHitDistanceType(hit_t)
1563 .with_span_static(span, "invalid hit_t"))
1564 }
1565 }
1566 }
1567 crate::RayQueryFunction::ConfirmIntersection => {}
1568 crate::RayQueryFunction::Terminate => {}
1569 }
1570 }
1571 S::SubgroupBallot { result, predicate } => {
1572 stages &= self.subgroup_stages;
1573 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1574 return Err(FunctionError::MissingCapability(
1575 super::Capabilities::SUBGROUP,
1576 )
1577 .with_span_static(span, "missing capability for this operation"));
1578 }
1579 if !self
1580 .subgroup_operations
1581 .contains(super::SubgroupOperationSet::BALLOT)
1582 {
1583 return Err(FunctionError::InvalidSubgroup(
1584 SubgroupError::UnsupportedOperation(
1585 super::SubgroupOperationSet::BALLOT,
1586 ),
1587 )
1588 .with_span_static(span, "support for this operation is not present"));
1589 }
1590 if let Some(predicate) = predicate {
1591 let predicate_inner =
1592 context.resolve_type_inner(predicate, &self.valid_expression_set)?;
1593 if !matches!(
1594 *predicate_inner,
1595 crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1596 ) {
1597 log::error!(
1598 "Subgroup ballot predicate type {predicate_inner:?} expected bool"
1599 );
1600 return Err(SubgroupError::InvalidOperand(predicate)
1601 .with_span_handle(predicate, context.expressions)
1602 .into_other());
1603 }
1604 }
1605 self.emit_expression(result, context)?;
1606 }
1607 S::SubgroupCollectiveOperation {
1608 ref op,
1609 ref collective_op,
1610 argument,
1611 result,
1612 } => {
1613 stages &= self.subgroup_stages;
1614 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1615 return Err(FunctionError::MissingCapability(
1616 super::Capabilities::SUBGROUP,
1617 )
1618 .with_span_static(span, "missing capability for this operation"));
1619 }
1620 let operation = op.required_operations();
1621 if !self.subgroup_operations.contains(operation) {
1622 return Err(FunctionError::InvalidSubgroup(
1623 SubgroupError::UnsupportedOperation(operation),
1624 )
1625 .with_span_static(span, "support for this operation is not present"));
1626 }
1627 self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1628 }
1629 S::SubgroupGather {
1630 ref mode,
1631 argument,
1632 result,
1633 } => {
1634 stages &= self.subgroup_stages;
1635 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1636 return Err(FunctionError::MissingCapability(
1637 super::Capabilities::SUBGROUP,
1638 )
1639 .with_span_static(span, "missing capability for this operation"));
1640 }
1641 let operation = mode.required_operations();
1642 if !self.subgroup_operations.contains(operation) {
1643 return Err(FunctionError::InvalidSubgroup(
1644 SubgroupError::UnsupportedOperation(operation),
1645 )
1646 .with_span_static(span, "support for this operation is not present"));
1647 }
1648 self.validate_subgroup_gather(mode, argument, result, context)?;
1649 }
1650 S::CooperativeStore { target, ref data } => {
1651 stages &= super::ShaderStages::COMPUTE;
1652
1653 let target_scalar =
1654 match *context.resolve_type_inner(target, &self.valid_expression_set)? {
1655 Ti::CooperativeMatrix { scalar, .. } => scalar,
1656 ref other => {
1657 log::error!("Target operand type: {other:?}");
1658 return Err(FunctionError::InvalidCooperativeStoreTarget(target)
1659 .with_span_handle(target, context.expressions));
1660 }
1661 };
1662
1663 let ptr_ty = context.resolve_pointer_type(data.pointer);
1664 let ptr_scalar = ptr_ty
1665 .pointer_base_type()
1666 .and_then(|tr| tr.inner_with(context.types).scalar());
1667 if ptr_scalar != Some(target_scalar) {
1668 return Err(FunctionError::InvalidCooperativeDataPointer(data.pointer)
1669 .with_span_handle(data.pointer, context.expressions));
1670 }
1671
1672 let ptr_space = ptr_ty.pointer_space().unwrap_or(AddressSpace::Handle);
1673 if !ptr_space.access().contains(crate::StorageAccess::STORE) {
1674 return Err(FunctionError::InvalidStorePointer(data.pointer)
1675 .with_span_static(
1676 context.expressions.get_span(data.pointer),
1677 "writing to this location is not permitted",
1678 ));
1679 }
1680 }
1681 }
1682 }
1683 Ok(BlockInfo { stages })
1684 }
1685
1686 fn validate_block(
1687 &mut self,
1688 statements: &crate::Block,
1689 context: &BlockContext,
1690 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1691 let base_expression_count = self.valid_expression_list.len();
1692 let info = self.validate_block_impl(statements, context)?;
1693 for handle in self.valid_expression_list.drain(base_expression_count..) {
1694 self.valid_expression_set.remove(handle);
1695 }
1696 Ok(info)
1697 }
1698
1699 fn validate_local_var(
1700 &self,
1701 var: &crate::LocalVariable,
1702 gctx: crate::proc::GlobalCtx,
1703 fun_info: &FunctionInfo,
1704 local_expr_kind: &crate::proc::ExpressionKindTracker,
1705 ) -> Result<(), LocalVariableError> {
1706 log::debug!("var {var:?}");
1707 let type_info = self
1708 .types
1709 .get(var.ty.index())
1710 .ok_or(LocalVariableError::InvalidType(var.ty))?;
1711 if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1712 return Err(LocalVariableError::InvalidType(var.ty));
1713 }
1714
1715 if let Some(init) = var.init {
1716 if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
1717 return Err(LocalVariableError::InitializerType);
1718 }
1719
1720 if !local_expr_kind.is_const_or_override(init) {
1721 return Err(LocalVariableError::NonConstOrOverrideInitializer);
1722 }
1723 }
1724
1725 Ok(())
1726 }
1727
1728 pub(super) fn validate_function(
1729 &mut self,
1730 fun: &crate::Function,
1731 module: &crate::Module,
1732 mod_info: &ModuleInfo,
1733 entry_point: bool,
1734 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1735 let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1736
1737 let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1738
1739 for (var_handle, var) in fun.local_variables.iter() {
1740 self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1741 .map_err(|source| {
1742 FunctionError::LocalVariable {
1743 handle: var_handle,
1744 name: var.name.clone().unwrap_or_default(),
1745 source,
1746 }
1747 .with_span_handle(var.ty, &module.types)
1748 .with_handle(var_handle, &fun.local_variables)
1749 })?;
1750 }
1751
1752 for (index, argument) in fun.arguments.iter().enumerate() {
1753 match module.types[argument.ty].inner.pointer_space() {
1754 Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1755 Some(other) => {
1756 return Err(FunctionError::InvalidArgumentPointerSpace {
1757 index,
1758 name: argument.name.clone().unwrap_or_default(),
1759 space: other,
1760 }
1761 .with_span_handle(argument.ty, &module.types))
1762 }
1763 }
1764 if !self.types[argument.ty.index()]
1766 .flags
1767 .contains(super::TypeFlags::ARGUMENT)
1768 {
1769 return Err(FunctionError::InvalidArgumentType {
1770 index,
1771 name: argument.name.clone().unwrap_or_default(),
1772 }
1773 .with_span_handle(argument.ty, &module.types));
1774 }
1775
1776 if !entry_point && argument.binding.is_some() {
1777 return Err(FunctionError::PipelineInputRegularFunction {
1778 name: argument.name.clone().unwrap_or_default(),
1779 }
1780 .with_span_handle(argument.ty, &module.types));
1781 }
1782 }
1783
1784 if let Some(ref result) = fun.result {
1785 if !self.types[result.ty.index()]
1786 .flags
1787 .contains(super::TypeFlags::CONSTRUCTIBLE)
1788 {
1789 return Err(FunctionError::NonConstructibleReturnType
1790 .with_span_handle(result.ty, &module.types));
1791 }
1792
1793 if !entry_point && result.binding.is_some() {
1794 return Err(FunctionError::PipelineOutputRegularFunction
1795 .with_span_handle(result.ty, &module.types));
1796 }
1797 }
1798
1799 self.valid_expression_set.clear_for_arena(&fun.expressions);
1800 self.valid_expression_list.clear();
1801 self.needs_visit.clear_for_arena(&fun.expressions);
1802 for (handle, expr) in fun.expressions.iter() {
1803 if expr.needs_pre_emit() {
1804 self.valid_expression_set.insert(handle);
1805 }
1806 if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1807 if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1810 *expr
1811 {
1812 self.needs_visit.insert(handle);
1813 }
1814
1815 match self.validate_expression(
1816 handle,
1817 expr,
1818 fun,
1819 module,
1820 &info,
1821 mod_info,
1822 &local_expr_kind,
1823 ) {
1824 Ok(stages) => info.available_stages &= stages,
1825 Err(source) => {
1826 return Err(FunctionError::Expression { handle, source }
1827 .with_span_handle(handle, &fun.expressions))
1828 }
1829 }
1830 }
1831 }
1832
1833 if self.flags.contains(super::ValidationFlags::BLOCKS) {
1834 let stages = self
1835 .validate_block(
1836 &fun.body,
1837 &BlockContext::new(fun, module, &info, &mod_info.functions, &local_expr_kind),
1838 )?
1839 .stages;
1840 info.available_stages &= stages;
1841
1842 if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1843 if let Some(handle) = self.needs_visit.iter().next() {
1844 return Err(FunctionError::UnvisitedExpression(handle)
1845 .with_span_handle(handle, &fun.expressions));
1846 }
1847 }
1848 }
1849 Ok(info)
1850 }
1851}