1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Outcome of a uniformity query on a value.
///
/// `None` means the value is uniform; `Some(handle)` names an expression
/// that is responsible for the value being non-uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// When `true`, [`UniformityRequirements::DERIVATIVE`] and
/// [`UniformityRequirements::IMPLICIT_LEVEL`] are defined with no bits set,
/// so those two requirements are never recorded.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
23bitflags::bitflags! {
24 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
26 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
27 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
28 pub struct UniformityRequirements: u8 {
29 const WORK_GROUP_BARRIER = 0x1;
30 const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
31 const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
32 const COOP_OPS = 0x8;
33 }
34}
35
/// Uniformity characteristics of an expression, a block, or a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// The expression that makes the value non-uniform, or `None` when the
    /// value is uniform.
    pub non_uniform_result: NonUniformResult,
    /// Uniformity requirements collected for this item.
    pub requirements: UniformityRequirements,
}
57
58impl Uniformity {
59 const fn new() -> Self {
60 Uniformity {
61 non_uniform_result: None,
62 requirements: UniformityRequirements::empty(),
63 }
64 }
65}
66
67bitflags::bitflags! {
68 #[derive(Clone, Copy, Debug, PartialEq)]
69 struct ExitFlags: u8 {
70 const MAY_RETURN = 0x1;
74 const MAY_KILL = 0x2;
79 }
80}
81
/// Uniformity summary of a statement or block, together with its early-exit
/// behaviour.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the produced result and the requirements collected.
    result: Uniformity,
    /// Early exits (return / kill) that may happen inside.
    exit: ExitFlags,
}
88
89impl ops::BitOr for FunctionUniformity {
90 type Output = Self;
91 fn bitor(self, other: Self) -> Self {
92 FunctionUniformity {
93 result: Uniformity {
94 non_uniform_result: self
95 .result
96 .non_uniform_result
97 .or(other.result.non_uniform_result),
98 requirements: self.result.requirements | other.result.requirements,
99 },
100 exit: self.exit | other.exit,
101 }
102 }
103}
104
105impl FunctionUniformity {
106 const fn new() -> Self {
107 FunctionUniformity {
108 result: Uniformity::new(),
109 exit: ExitFlags::empty(),
110 }
111 }
112
113 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
115 if self.exit.contains(ExitFlags::MAY_RETURN) {
116 Some(UniformityDisruptor::Return)
117 } else if self.exit.contains(ExitFlags::MAY_KILL) {
118 Some(UniformityDisruptor::Discard)
119 } else {
120 None
121 }
122 }
123}
124
125bitflags::bitflags! {
126 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
128 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
129 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
130 pub struct GlobalUse: u8 {
131 const READ = 0x1;
133 const WRITE = 0x2;
135 const QUERY = 0x4;
137 const ATOMIC = 0x8;
139 }
140}
141
/// A pair of global variables, an image and a sampler, that are used
/// together in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the sampled image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
149
/// Analysis results recorded for a single expression.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's value and the requirements it imposes.
    pub uniformity: Uniformity,

    /// Number of times the expression has been referenced, counted by the
    /// `add_ref*` helpers.
    pub ref_count: usize,

    /// The global variable this expression refers to, if any. It is set for
    /// global-variable expressions and carried through `Access` /
    /// `AccessIndex` chains, so that uses of the expression can be recorded
    /// as uses of the global.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The resolved type of the expression.
    pub ty: TypeResolution,
}
187
188impl ExpressionInfo {
189 const fn new() -> Self {
190 ExpressionInfo {
191 uniformity: Uniformity::new(),
192 ref_count: 0,
193 assignable_global: None,
194 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
196 kind: crate::ScalarKind::Bool,
197 width: 0,
198 })),
199 }
200 }
201}
202
/// Where an image or sampler operand comes from: a global variable, or a
/// function argument identified by its index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The operand is (an element of) this global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The operand is the function argument with this index.
    Argument(u32),
}
210
211impl GlobalOrArgument {
212 fn from_expression(
213 expression_arena: &Arena<crate::Expression>,
214 expression: Handle<crate::Expression>,
215 ) -> Result<GlobalOrArgument, ExpressionError> {
216 Ok(match expression_arena[expression] {
217 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
218 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
219 crate::Expression::Access { base, .. }
220 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
221 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
222 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
223 },
224 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
225 })
226 }
227}
228
/// An image/sampler pair used together in a sampling operation, where each
/// side may be either a global variable or a function argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Source of the sampled image.
    image: GlobalOrArgument,
    /// Source of the sampler.
    sampler: GlobalOrArgument,
}
236
/// Analysis results gathered for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags that were in effect while the function was analyzed.
    flags: ValidationFlags,
    /// Shader stages associated with this function. The analysis in this
    /// file initializes it to all stages.
    // NOTE(review): presumably narrowed elsewhere to the stages in which the
    // function may be used — confirm against the rest of the validator.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result, taken from analyzing its body.
    pub uniformity: Uniformity,
    /// Whether analyzing the body recorded that the function may kill the
    /// invocation.
    pub may_kill: bool,

    /// Image/sampler pairs sampled together where both sides are global
    /// variables, including the pairs inherited from called functions.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How each global variable is used, indexed by the variable's handle
    /// index.
    pub global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by the expression's handle
    /// index.
    expressions: Box<[ExpressionInfo]>,

    /// Image/sampler pairs sampled together where at least one side is a
    /// function argument, so the pair cannot yet be expressed as a
    /// [`SamplingKey`].
    sampling: crate::FastHashSet<Sampling>,

    /// The analysis in this file initializes this to `false`.
    // NOTE(review): presumably set elsewhere when the function uses
    // dual-source blending — confirm.
    pub dual_source_blending: bool,

    /// Starting point for diagnostic-filter lookups made while reporting
    /// diagnostics for this function; copied from the function itself.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
306
307impl FunctionInfo {
308 pub const fn global_variable_count(&self) -> usize {
309 self.global_uses.len()
310 }
311 pub const fn expression_count(&self) -> usize {
312 self.expressions.len()
313 }
314 pub fn dominates_global_use(&self, other: &Self) -> bool {
315 for (self_global_uses, other_global_uses) in
316 self.global_uses.iter().zip(other.global_uses.iter())
317 {
318 if !self_global_uses.contains(*other_global_uses) {
319 return false;
320 }
321 }
322 true
323 }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327 type Output = GlobalUse;
328 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329 &self.global_uses[handle.index()]
330 }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334 type Output = ExpressionInfo;
335 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336 &self.expressions[handle.index()]
337 }
338}
339
/// The reason why control flow is no longer considered uniform at some point
/// in a function.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on this expression, whose value is non-uniform.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier `Return` may have been taken under disrupted control flow.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier `Kill` (here or in a called function) may have been taken
    /// under disrupted control flow.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
351
352impl FunctionInfo {
353 #[must_use]
361 fn add_ref_impl(
362 &mut self,
363 expr: Handle<crate::Expression>,
364 global_use: GlobalUse,
365 ) -> NonUniformResult {
366 let info = &mut self.expressions[expr.index()];
367 info.ref_count += 1;
368 if let Some(global) = info.assignable_global {
370 self.global_uses[global.index()] |= global_use;
371 }
372 info.uniformity.non_uniform_result
373 }
374
375 pub(super) fn insert_global_use(
384 &mut self,
385 global_use: GlobalUse,
386 global: Handle<crate::GlobalVariable>,
387 ) {
388 self.global_uses[global.index()] |= global_use;
389 }
390
391 #[must_use]
398 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
399 self.add_ref_impl(expr, GlobalUse::READ)
400 }
401
402 #[must_use]
421 fn add_assignable_ref(
422 &mut self,
423 expr: Handle<crate::Expression>,
424 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
425 ) -> NonUniformResult {
426 let info = &mut self.expressions[expr.index()];
427 info.ref_count += 1;
428 if let Some(global) = info.assignable_global {
431 if let Some(_old) = assignable_global.replace(global) {
432 unreachable!()
433 }
434 }
435 info.uniformity.non_uniform_result
436 }
437
438 fn process_call(
440 &mut self,
441 callee: &Self,
442 arguments: &[Handle<crate::Expression>],
443 expression_arena: &Arena<crate::Expression>,
444 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
445 self.sampling_set
446 .extend(callee.sampling_set.iter().cloned());
447 for sampling in callee.sampling.iter() {
448 let image_storage = match sampling.image {
451 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
452 GlobalOrArgument::Argument(i) => {
453 let Some(handle) = arguments.get(i as usize).cloned() else {
454 break;
456 };
457 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
458 |source| {
459 FunctionError::Expression { handle, source }
460 .with_span_handle(handle, expression_arena)
461 },
462 )?
463 }
464 };
465
466 let sampler_storage = match sampling.sampler {
467 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
468 GlobalOrArgument::Argument(i) => {
469 let Some(handle) = arguments.get(i as usize).cloned() else {
470 break;
472 };
473 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
474 |source| {
475 FunctionError::Expression { handle, source }
476 .with_span_handle(handle, expression_arena)
477 },
478 )?
479 }
480 };
481
482 match (image_storage, sampler_storage) {
487 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
488 self.sampling_set.insert(SamplingKey { image, sampler });
489 }
490 (image, sampler) => {
491 self.sampling.insert(Sampling { image, sampler });
492 }
493 }
494 }
495
496 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
498 *mine |= *other;
499 }
500
501 Ok(FunctionUniformity {
502 result: callee.uniformity.clone(),
503 exit: if callee.may_kill {
504 ExitFlags::MAY_KILL
505 } else {
506 ExitFlags::empty()
507 },
508 })
509 }
510
511 #[allow(clippy::or_fun_call)]
531 fn process_expression(
532 &mut self,
533 handle: Handle<crate::Expression>,
534 expression_arena: &Arena<crate::Expression>,
535 other_functions: &[FunctionInfo],
536 resolve_context: &ResolveContext,
537 capabilities: super::Capabilities,
538 ) -> Result<(), ExpressionError> {
539 use crate::{Expression as E, SampleLevel as Sl};
540
541 let expression = &expression_arena[handle];
542 let mut assignable_global = None;
543 let uniformity = match *expression {
544 E::Access { base, index } => {
545 let base_ty = self[base].ty.inner_with(resolve_context.types);
546
547 let mut needed_caps = super::Capabilities::empty();
549 let is_binding_array = match *base_ty {
550 crate::TypeInner::BindingArray {
551 base: array_element_ty_handle,
552 ..
553 } => {
554 let array_element_ty =
556 &resolve_context.types[array_element_ty_handle].inner;
557
558 needed_caps |= match *array_element_ty {
559 crate::TypeInner::Image { class, .. } => match class {
561 crate::ImageClass::Storage { .. } => {
562 super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
563 }
564 _ => {
565 super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
566 }
567 },
568 crate::TypeInner::Sampler { .. } => {
569 super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
570 }
571 _ => {
573 if let E::GlobalVariable(global_handle) = expression_arena[base] {
574 let global = &resolve_context.global_vars[global_handle];
575 match global.space {
576 crate::AddressSpace::Uniform => {
577 super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
578 }
579 crate::AddressSpace::Storage { .. } => {
580 super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
581 }
582 _ => unreachable!(),
583 }
584 } else {
585 unreachable!()
586 }
587 }
588 };
589
590 true
591 }
592 _ => false,
593 };
594
595 if self[index].uniformity.non_uniform_result.is_some()
596 && !capabilities.contains(needed_caps)
597 && is_binding_array
598 {
599 return Err(ExpressionError::MissingCapabilities(needed_caps));
600 }
601
602 Uniformity {
603 non_uniform_result: self
604 .add_assignable_ref(base, &mut assignable_global)
605 .or(self.add_ref(index)),
606 requirements: UniformityRequirements::empty(),
607 }
608 }
609 E::AccessIndex { base, .. } => Uniformity {
610 non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
611 requirements: UniformityRequirements::empty(),
612 },
613 E::Splat { size: _, value } => Uniformity {
615 non_uniform_result: self.add_ref(value),
616 requirements: UniformityRequirements::empty(),
617 },
618 E::Swizzle { vector, .. } => Uniformity {
619 non_uniform_result: self.add_ref(vector),
620 requirements: UniformityRequirements::empty(),
621 },
622 E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
623 E::Compose { ref components, .. } => {
624 let non_uniform_result = components
625 .iter()
626 .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
627 Uniformity {
628 non_uniform_result,
629 requirements: UniformityRequirements::empty(),
630 }
631 }
632 E::FunctionArgument(index) => {
634 let arg = &resolve_context.arguments[index as usize];
635 let uniform = match arg.binding {
636 Some(crate::Binding::BuiltIn(
637 crate::BuiltIn::WorkGroupId
639 | crate::BuiltIn::WorkGroupSize
640 | crate::BuiltIn::NumWorkGroups,
641 )) => true,
642 _ => false,
643 };
644 Uniformity {
645 non_uniform_result: if uniform { None } else { Some(handle) },
646 requirements: UniformityRequirements::empty(),
647 }
648 }
649 E::GlobalVariable(gh) => {
651 use crate::AddressSpace as As;
652 assignable_global = Some(gh);
653 let var = &resolve_context.global_vars[gh];
654 let uniform = match var.space {
655 As::Function | As::Private | As::RayPayload | As::IncomingRayPayload => false,
657 As::WorkGroup | As::TaskPayload => true,
660 As::Uniform | As::Immediate => true,
662 As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
664 As::Handle => false,
665 };
666 Uniformity {
667 non_uniform_result: if uniform { None } else { Some(handle) },
668 requirements: UniformityRequirements::empty(),
669 }
670 }
671 E::LocalVariable(_) => Uniformity {
672 non_uniform_result: Some(handle),
673 requirements: UniformityRequirements::empty(),
674 },
675 E::Load { pointer } => Uniformity {
676 non_uniform_result: self.add_ref(pointer),
677 requirements: UniformityRequirements::empty(),
678 },
679 E::ImageSample {
680 image,
681 sampler,
682 gather: _,
683 coordinate,
684 array_index,
685 offset,
686 level,
687 depth_ref,
688 clamp_to_edge: _,
689 } => {
690 let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
691 let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
692
693 match (image_storage, sampler_storage) {
694 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
695 self.sampling_set.insert(SamplingKey { image, sampler });
696 }
697 _ => {
698 self.sampling.insert(Sampling {
699 image: image_storage,
700 sampler: sampler_storage,
701 });
702 }
703 }
704
705 let array_nur = array_index.and_then(|h| self.add_ref(h));
707 let level_nur = match level {
708 Sl::Auto | Sl::Zero => None,
709 Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
710 Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
711 };
712 let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
713 let offset_nur = offset.and_then(|h| self.add_ref(h));
714 Uniformity {
715 non_uniform_result: self
716 .add_ref(image)
717 .or(self.add_ref(sampler))
718 .or(self.add_ref(coordinate))
719 .or(array_nur)
720 .or(level_nur)
721 .or(dref_nur)
722 .or(offset_nur),
723 requirements: if level.implicit_derivatives() {
724 UniformityRequirements::IMPLICIT_LEVEL
725 } else {
726 UniformityRequirements::empty()
727 },
728 }
729 }
730 E::ImageLoad {
731 image,
732 coordinate,
733 array_index,
734 sample,
735 level,
736 } => {
737 let array_nur = array_index.and_then(|h| self.add_ref(h));
738 let sample_nur = sample.and_then(|h| self.add_ref(h));
739 let level_nur = level.and_then(|h| self.add_ref(h));
740 Uniformity {
741 non_uniform_result: self
742 .add_ref(image)
743 .or(self.add_ref(coordinate))
744 .or(array_nur)
745 .or(sample_nur)
746 .or(level_nur),
747 requirements: UniformityRequirements::empty(),
748 }
749 }
750 E::ImageQuery { image, query } => {
751 let query_nur = match query {
752 crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
753 _ => None,
754 };
755 Uniformity {
756 non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
757 requirements: UniformityRequirements::empty(),
758 }
759 }
760 E::Unary { expr, .. } => Uniformity {
761 non_uniform_result: self.add_ref(expr),
762 requirements: UniformityRequirements::empty(),
763 },
764 E::Binary { left, right, .. } => Uniformity {
765 non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
766 requirements: UniformityRequirements::empty(),
767 },
768 E::Select {
769 condition,
770 accept,
771 reject,
772 } => Uniformity {
773 non_uniform_result: self
774 .add_ref(condition)
775 .or(self.add_ref(accept))
776 .or(self.add_ref(reject)),
777 requirements: UniformityRequirements::empty(),
778 },
779 E::Derivative { expr, .. } => Uniformity {
781 non_uniform_result: self.add_ref(expr),
783 requirements: UniformityRequirements::DERIVATIVE,
784 },
785 E::Relational { argument, .. } => Uniformity {
786 non_uniform_result: self.add_ref(argument),
787 requirements: UniformityRequirements::empty(),
788 },
789 E::Math {
790 fun: _,
791 arg,
792 arg1,
793 arg2,
794 arg3,
795 } => {
796 let arg1_nur = arg1.and_then(|h| self.add_ref(h));
797 let arg2_nur = arg2.and_then(|h| self.add_ref(h));
798 let arg3_nur = arg3.and_then(|h| self.add_ref(h));
799 Uniformity {
800 non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
801 requirements: UniformityRequirements::empty(),
802 }
803 }
804 E::As { expr, .. } => Uniformity {
805 non_uniform_result: self.add_ref(expr),
806 requirements: UniformityRequirements::empty(),
807 },
808 E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
809 E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
810 non_uniform_result: Some(handle),
811 requirements: UniformityRequirements::empty(),
812 },
813 E::WorkGroupUniformLoadResult { .. } => Uniformity {
814 non_uniform_result: None,
816 requirements: UniformityRequirements::empty(),
819 },
820 E::ArrayLength(expr) => Uniformity {
821 non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
822 requirements: UniformityRequirements::empty(),
823 },
824 E::RayQueryGetIntersection {
825 query,
826 committed: _,
827 } => Uniformity {
828 non_uniform_result: self.add_ref(query),
829 requirements: UniformityRequirements::empty(),
830 },
831 E::SubgroupBallotResult => Uniformity {
832 non_uniform_result: Some(handle),
833 requirements: UniformityRequirements::empty(),
834 },
835 E::SubgroupOperationResult { .. } => Uniformity {
836 non_uniform_result: Some(handle),
837 requirements: UniformityRequirements::empty(),
838 },
839 E::RayQueryVertexPositions {
840 query,
841 committed: _,
842 } => Uniformity {
843 non_uniform_result: self.add_ref(query),
844 requirements: UniformityRequirements::empty(),
845 },
846 E::CooperativeLoad { ref data, .. } => Uniformity {
847 non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)),
848 requirements: UniformityRequirements::COOP_OPS,
849 },
850 E::CooperativeMultiplyAdd { a, b, c } => Uniformity {
851 non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))),
852 requirements: UniformityRequirements::COOP_OPS,
853 },
854 };
855
856 let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
857 self.expressions[handle.index()] = ExpressionInfo {
858 uniformity,
859 ref_count: 0,
860 assignable_global,
861 ty,
862 };
863 Ok(())
864 }
865
866 #[allow(clippy::or_fun_call)]
876 fn process_block(
877 &mut self,
878 statements: &crate::Block,
879 other_functions: &[FunctionInfo],
880 mut disruptor: Option<UniformityDisruptor>,
881 expression_arena: &Arena<crate::Expression>,
882 diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
883 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
884 use crate::Statement as S;
885
886 let mut combined_uniformity = FunctionUniformity::new();
887 for statement in statements {
888 let uniformity = match *statement {
889 S::Emit(ref range) => {
890 let mut requirements = UniformityRequirements::empty();
891 for expr in range.clone() {
892 let req = self.expressions[expr.index()].uniformity.requirements;
893 if self
894 .flags
895 .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
896 && !req.is_empty()
897 {
898 if let Some(cause) = disruptor {
899 let severity = DiagnosticFilterNode::search(
900 self.diagnostic_filter_leaf,
901 diagnostic_filter_arena,
902 StandardFilterableTriggeringRule::DerivativeUniformity,
903 );
904 severity.report_diag(
905 FunctionError::NonUniformControlFlow(req, expr, cause)
906 .with_span_handle(expr, expression_arena),
907 |e, level| log::log!(level, "{e}"),
913 )?;
914 }
915 }
916 requirements |= req;
917 }
918 FunctionUniformity {
919 result: Uniformity {
920 non_uniform_result: None,
921 requirements,
922 },
923 exit: ExitFlags::empty(),
924 }
925 }
926 S::Break | S::Continue => FunctionUniformity::new(),
927 S::Kill => FunctionUniformity {
928 result: Uniformity::new(),
929 exit: if disruptor.is_some() {
930 ExitFlags::MAY_KILL
931 } else {
932 ExitFlags::empty()
933 },
934 },
935 S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
936 result: Uniformity {
937 non_uniform_result: None,
938 requirements: UniformityRequirements::WORK_GROUP_BARRIER,
939 },
940 exit: ExitFlags::empty(),
941 },
942 S::WorkGroupUniformLoad { pointer, .. } => {
943 let _condition_nur = self.add_ref(pointer);
944
945 FunctionUniformity {
964 result: Uniformity {
965 non_uniform_result: None,
966 requirements: UniformityRequirements::WORK_GROUP_BARRIER,
967 },
968 exit: ExitFlags::empty(),
969 }
970 }
971 S::Block(ref b) => self.process_block(
972 b,
973 other_functions,
974 disruptor,
975 expression_arena,
976 diagnostic_filter_arena,
977 )?,
978 S::If {
979 condition,
980 ref accept,
981 ref reject,
982 } => {
983 let condition_nur = self.add_ref(condition);
984 let branch_disruptor =
985 disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
986 let accept_uniformity = self.process_block(
987 accept,
988 other_functions,
989 branch_disruptor,
990 expression_arena,
991 diagnostic_filter_arena,
992 )?;
993 let reject_uniformity = self.process_block(
994 reject,
995 other_functions,
996 branch_disruptor,
997 expression_arena,
998 diagnostic_filter_arena,
999 )?;
1000 accept_uniformity | reject_uniformity
1001 }
1002 S::Switch {
1003 selector,
1004 ref cases,
1005 } => {
1006 let selector_nur = self.add_ref(selector);
1007 let branch_disruptor =
1008 disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
1009 let mut uniformity = FunctionUniformity::new();
1010 let mut case_disruptor = branch_disruptor;
1011 for case in cases.iter() {
1012 let case_uniformity = self.process_block(
1013 &case.body,
1014 other_functions,
1015 case_disruptor,
1016 expression_arena,
1017 diagnostic_filter_arena,
1018 )?;
1019 case_disruptor = if case.fall_through {
1020 case_disruptor.or(case_uniformity.exit_disruptor())
1021 } else {
1022 branch_disruptor
1023 };
1024 uniformity = uniformity | case_uniformity;
1025 }
1026 uniformity
1027 }
1028 S::Loop {
1029 ref body,
1030 ref continuing,
1031 break_if,
1032 } => {
1033 let body_uniformity = self.process_block(
1034 body,
1035 other_functions,
1036 disruptor,
1037 expression_arena,
1038 diagnostic_filter_arena,
1039 )?;
1040 let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
1041 let continuing_uniformity = self.process_block(
1042 continuing,
1043 other_functions,
1044 continuing_disruptor,
1045 expression_arena,
1046 diagnostic_filter_arena,
1047 )?;
1048 if let Some(expr) = break_if {
1049 let _ = self.add_ref(expr);
1050 }
1051 body_uniformity | continuing_uniformity
1052 }
1053 S::Return { value } => FunctionUniformity {
1054 result: Uniformity {
1055 non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
1056 requirements: UniformityRequirements::empty(),
1057 },
1058 exit: if disruptor.is_some() {
1059 ExitFlags::MAY_RETURN
1060 } else {
1061 ExitFlags::empty()
1062 },
1063 },
1064 S::Store { pointer, value } => {
1068 let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
1069 let _ = self.add_ref(value);
1070 FunctionUniformity::new()
1071 }
1072 S::ImageStore {
1073 image,
1074 coordinate,
1075 array_index,
1076 value,
1077 } => {
1078 let _ = self.add_ref_impl(image, GlobalUse::WRITE);
1079 if let Some(expr) = array_index {
1080 let _ = self.add_ref(expr);
1081 }
1082 let _ = self.add_ref(coordinate);
1083 let _ = self.add_ref(value);
1084 FunctionUniformity::new()
1085 }
1086 S::Call {
1087 function,
1088 ref arguments,
1089 result: _,
1090 } => {
1091 for &argument in arguments {
1092 let _ = self.add_ref(argument);
1093 }
1094 let info = &other_functions[function.index()];
1095 self.process_call(info, arguments, expression_arena)?
1097 }
1098 S::Atomic {
1099 pointer,
1100 ref fun,
1101 value,
1102 result: _,
1103 } => {
1104 let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
1105 let _ = self.add_ref(value);
1106 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
1107 let _ = self.add_ref(cmp);
1108 }
1109 FunctionUniformity::new()
1110 }
1111 S::ImageAtomic {
1112 image,
1113 coordinate,
1114 array_index,
1115 fun: _,
1116 value,
1117 } => {
1118 let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
1119 let _ = self.add_ref(coordinate);
1120 if let Some(expr) = array_index {
1121 let _ = self.add_ref(expr);
1122 }
1123 let _ = self.add_ref(value);
1124 FunctionUniformity::new()
1125 }
1126 S::RayQuery { query, ref fun } => {
1127 let _ = self.add_ref(query);
1128 match *fun {
1129 crate::RayQueryFunction::Initialize {
1130 acceleration_structure,
1131 descriptor,
1132 } => {
1133 let _ = self.add_ref(acceleration_structure);
1134 let _ = self.add_ref(descriptor);
1135 }
1136 crate::RayQueryFunction::Proceed { result: _ } => {}
1137 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1138 let _ = self.add_ref(hit_t);
1139 }
1140 crate::RayQueryFunction::ConfirmIntersection => {}
1141 crate::RayQueryFunction::Terminate => {}
1142 }
1143 FunctionUniformity::new()
1144 }
1145 S::SubgroupBallot {
1146 result: _,
1147 predicate,
1148 } => {
1149 if let Some(predicate) = predicate {
1150 let _ = self.add_ref(predicate);
1151 }
1152 FunctionUniformity::new()
1153 }
1154 S::SubgroupCollectiveOperation {
1155 op: _,
1156 collective_op: _,
1157 argument,
1158 result: _,
1159 } => {
1160 let _ = self.add_ref(argument);
1161 FunctionUniformity::new()
1162 }
1163 S::SubgroupGather {
1164 mode,
1165 argument,
1166 result: _,
1167 } => {
1168 let _ = self.add_ref(argument);
1169 match mode {
1170 crate::GatherMode::BroadcastFirst => {}
1171 crate::GatherMode::Broadcast(index)
1172 | crate::GatherMode::Shuffle(index)
1173 | crate::GatherMode::ShuffleDown(index)
1174 | crate::GatherMode::ShuffleUp(index)
1175 | crate::GatherMode::ShuffleXor(index)
1176 | crate::GatherMode::QuadBroadcast(index) => {
1177 let _ = self.add_ref(index);
1178 }
1179 crate::GatherMode::QuadSwap(_) => {}
1180 }
1181 FunctionUniformity::new()
1182 }
1183 S::CooperativeStore { target, ref data } => FunctionUniformity {
1184 result: Uniformity {
1185 non_uniform_result: self
1186 .add_ref(target)
1187 .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE))
1188 .or(self.add_ref(data.stride)),
1189 requirements: UniformityRequirements::COOP_OPS,
1190 },
1191 exit: ExitFlags::empty(),
1192 },
1193 S::RayPipelineFunction(ref fun) => {
1194 match *fun {
1195 crate::RayPipelineFunction::TraceRay {
1196 acceleration_structure,
1197 descriptor,
1198 payload,
1199 } => {
1200 let _ = self.add_ref(acceleration_structure);
1201 let _ = self.add_ref(descriptor);
1202 let _ = self.add_ref(payload);
1203 }
1204 }
1205 FunctionUniformity::new()
1206 }
1207 };
1208
1209 disruptor = disruptor.or(uniformity.exit_disruptor());
1210 combined_uniformity = combined_uniformity | uniformity;
1211 }
1212 Ok(combined_uniformity)
1213 }
1214}
1215
1216impl ModuleInfo {
1217 pub(super) fn process_const_expression(
1219 &mut self,
1220 handle: Handle<crate::Expression>,
1221 resolve_context: &ResolveContext,
1222 gctx: crate::proc::GlobalCtx,
1223 ) -> Result<(), super::ConstExpressionError> {
1224 self.const_expression_types[handle.index()] =
1225 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1226 Ok(())
1227 }
1228
1229 pub(super) fn process_function(
1232 &self,
1233 fun: &crate::Function,
1234 module: &crate::Module,
1235 flags: ValidationFlags,
1236 capabilities: super::Capabilities,
1237 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1238 let mut info = FunctionInfo {
1239 flags,
1240 available_stages: ShaderStages::all(),
1241 uniformity: Uniformity::new(),
1242 may_kill: false,
1243 sampling_set: crate::FastHashSet::default(),
1244 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1245 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1246 sampling: crate::FastHashSet::default(),
1247 dual_source_blending: false,
1248 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1249 };
1250 let resolve_context =
1251 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1252
1253 for (handle, _) in fun.expressions.iter() {
1254 if let Err(source) = info.process_expression(
1255 handle,
1256 &fun.expressions,
1257 &self.functions,
1258 &resolve_context,
1259 capabilities,
1260 ) {
1261 return Err(FunctionError::Expression { handle, source }
1262 .with_span_handle(handle, &fun.expressions));
1263 }
1264 }
1265
1266 for (_, expr) in fun.local_variables.iter() {
1267 if let Some(init) = expr.init {
1268 let _ = info.add_ref(init);
1269 }
1270 }
1271
1272 let uniformity = info.process_block(
1273 &fun.body,
1274 &self.functions,
1275 None,
1276 &fun.expressions,
1277 &module.diagnostic_filters,
1278 )?;
1279 info.uniformity = uniformity.result;
1280 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1281
1282 for &handle in fun.named_expressions.keys() {
1288 if let Some(global) = info[handle].assignable_global {
1289 if info.global_uses[global.index()].is_empty() {
1290 info.global_uses[global.index()] = GlobalUse::QUERY;
1291 }
1292 }
1293 }
1294
1295 Ok(info)
1296 }
1297
1298 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1299 &self.entry_points[index]
1300 }
1301}
1302
/// Drives `process_expression` and `process_block` over a small hand-built
/// set of expressions and statements, checking after each step the reported
/// uniformity result, the exit flags, the expression reference counts and the
/// per-global usage flags.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    /// Builds the `FunctionUniformity` value that a successful
    /// `process_block` call is compared against.
    fn outcome(
        non_uniform_result: NonUniformResult,
        requirements: UniformityRequirements,
        exit: ExitFlags,
    ) -> FunctionUniformity {
        FunctionUniformity {
            result: Uniformity {
                non_uniform_result,
                requirements,
            },
            exit,
        }
    }

    // A single two-component `f32` vector type, shared by both globals.
    let mut types = crate::UniqueArena::new();
    let vec2_ty = types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Two globals that differ only in address space. The assertions further
    // down expect the `Handle` one to be reported as a non-uniform source.
    let mut globals = Arena::new();
    let gv_non_uniform = globals.append(
        crate::GlobalVariable {
            name: None,
            space: crate::AddressSpace::Handle,
            binding: None,
            ty: vec2_ty,
            init: None,
        },
        Default::default(),
    );
    let gv_uniform = globals.append(
        crate::GlobalVariable {
            name: None,
            space: crate::AddressSpace::Uniform,
            binding: None,
            ty: vec2_ty,
            init: None,
        },
        Default::default(),
    );

    // The expression arena is filled in a fixed order: the emit ranges taken
    // below depend on how many expressions exist at the time they are taken.
    let mut exprs = Arena::new();
    // A literal, and a derivative of that literal (the operation carrying
    // the `DERIVATIVE` requirement in the expectations below).
    let zero_lit = exprs.append(E::Literal(crate::Literal::U32(0)), Default::default());
    let ddx_of_zero = exprs.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: zero_lit,
        },
        Default::default(),
    );
    let lit_and_ddx_range = exprs.range_from(0);

    // One expression per global.
    let gv_non_uniform_expr =
        exprs.append(E::GlobalVariable(gv_non_uniform), Default::default());
    let gv_uniform_expr = exprs.append(E::GlobalVariable(gv_uniform), Default::default());
    let globals_range = exprs.range_from(2);

    // `ArrayLength` on the uniform global: expected to yield the QUERY usage.
    let array_len_expr = exprs.append(E::ArrayLength(gv_uniform_expr), Default::default());
    // A member access on the non-uniform global; stored through later on to
    // check that the WRITE usage reaches the underlying global.
    let member_expr = exprs.append(
        E::AccessIndex {
            base: gv_non_uniform_expr,
            index: 1,
        },
        Default::default(),
    );
    let globals_to_end_range = exprs.range_from(2);

    // Fresh analysis state, sized for the arenas built above.
    let mut analysis = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
        global_uses: vec![GlobalUse::empty(); globals.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); exprs.len()].into_boxed_slice(),
    };

    // Everything the resolver does not need for this test stays empty.
    let no_constants = Arena::new();
    let no_overrides = Arena::new();
    let no_locals = Arena::new();
    let no_functions = Arena::new();
    let special_types = crate::SpecialTypes::default();
    let resolver = ResolveContext {
        constants: &no_constants,
        overrides: &no_overrides,
        types: &types,
        special_types: &special_types,
        global_vars: &globals,
        local_vars: &no_locals,
        functions: &no_functions,
        arguments: &[],
    };

    // Step 1: analyze every expression on its own, in arena order.
    for (expr_handle, _) in exprs.iter() {
        analysis
            .process_expression(
                expr_handle,
                &exprs,
                &[],
                &resolver,
                super::Capabilities::empty(),
            )
            .unwrap();
    }
    // Each global expression is referenced once (by the query / the access);
    // nothing refers to the query or the access themselves yet.
    assert_eq!(analysis[gv_non_uniform_expr].ref_count, 1);
    assert_eq!(analysis[gv_uniform_expr].ref_count, 1);
    assert_eq!(analysis[array_len_expr].ref_count, 0);
    assert_eq!(analysis[member_expr].ref_count, 0);
    assert_eq!(analysis[gv_non_uniform], GlobalUse::empty());
    assert_eq!(analysis[gv_uniform], GlobalUse::QUERY);

    // Empty diagnostic-filter arena, shared by every call that applies no filter.
    let no_filters = Arena::new();

    // Step 2: a derivative under a condition read from the uniform global.
    let uniform_if_block: crate::Block = vec![
        S::Emit(globals_range.clone()),
        S::If {
            condition: gv_uniform_expr,
            accept: crate::Block::new(),
            reject: vec![
                S::Emit(lit_and_ddx_range.clone()),
                S::Store {
                    pointer: zero_lit,
                    value: ddx_of_zero,
                },
            ]
            .into(),
        },
    ]
    .into();
    assert_eq!(
        analysis.process_block(&uniform_if_block, &[], None, &exprs, &no_filters),
        Ok(outcome(
            None,
            UniformityRequirements::DERIVATIVE,
            ExitFlags::empty(),
        )),
    );
    assert_eq!(analysis[zero_lit].ref_count, 2);
    assert_eq!(analysis[gv_uniform], GlobalUse::READ | GlobalUse::QUERY);

    // Step 3: the same derivative, now under a condition read from the
    // non-uniform global. The same block is analyzed once without a
    // diagnostic filter and, when the requirement is enabled, once more with
    // the derivative-uniformity rule switched off.
    let non_uniform_if_block: crate::Block = vec![
        S::Emit(globals_range.clone()),
        S::If {
            condition: gv_non_uniform_expr,
            accept: vec![
                S::Emit(lit_and_ddx_range),
                S::Store {
                    pointer: zero_lit,
                    value: ddx_of_zero,
                },
            ]
            .into(),
            reject: crate::Block::new(),
        },
    ]
    .into();
    {
        let unfiltered_result =
            analysis.process_block(&non_uniform_if_block, &[], None, &exprs, &no_filters);
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the requirement compiled out, only the ref count is checked.
            assert_eq!(analysis[ddx_of_zero].ref_count, 2);
        } else {
            assert_eq!(
                unfiltered_result,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    ddx_of_zero,
                    UniformityDisruptor::Expression(gv_non_uniform_expr)
                )
                .with_span()),
            );
            assert_eq!(analysis[ddx_of_zero].ref_count, 1);

            // Re-run on a copy of the analysis state whose filter leaf turns
            // the derivative-uniformity rule off; the block must now pass.
            let mut diagnostic_filters = Arena::new();
            let derivative_uniformity_off = diagnostic_filters.append(
                DiagnosticFilterNode {
                    parent: None,
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                        new_severity: crate::diagnostic_filter::Severity::Off,
                    },
                },
                crate::Span::default(),
            );
            let mut filtered = FunctionInfo {
                diagnostic_filter_leaf: Some(derivative_uniformity_off),
                ..analysis.clone()
            };

            let filtered_result = filtered.process_block(
                &non_uniform_if_block,
                &[],
                None,
                &exprs,
                &diagnostic_filters,
            );
            assert_eq!(
                filtered_result,
                Ok(outcome(
                    None,
                    UniformityRequirements::DERIVATIVE,
                    ExitFlags::empty(),
                )),
            );
            assert_eq!(filtered[ddx_of_zero].ref_count, 2);
        }
    }
    assert_eq!(analysis[gv_non_uniform], GlobalUse::READ);

    // Step 4: returning the non-uniform global makes the block result
    // non-uniform and sets the MAY_RETURN exit flag.
    let return_block: crate::Block = vec![
        S::Emit(globals_range),
        S::Return {
            value: Some(gv_non_uniform_expr),
        },
    ]
    .into();
    assert_eq!(
        analysis.process_block(
            &return_block,
            &[],
            Some(UniformityDisruptor::Return),
            &exprs,
            &no_filters,
        ),
        Ok(outcome(
            Some(gv_non_uniform_expr),
            UniformityRequirements::empty(),
            ExitFlags::MAY_RETURN,
        )),
    );
    assert_eq!(analysis[gv_non_uniform_expr].ref_count, 3);

    // Step 5: store through the access expression, kill, then return the
    // access. The non-uniform source must still be traced back to the global
    // expression, both exit flags are set, and the global gains WRITE usage.
    let store_kill_return_block: crate::Block = vec![
        S::Emit(globals_to_end_range),
        S::Store {
            pointer: member_expr,
            value: array_len_expr,
        },
        S::Kill,
        S::Return {
            value: Some(member_expr),
        },
    ]
    .into();
    assert_eq!(
        analysis.process_block(
            &store_kill_return_block,
            &[],
            Some(UniformityDisruptor::Discard),
            &exprs,
            &no_filters,
        ),
        Ok(outcome(
            Some(gv_non_uniform_expr),
            UniformityRequirements::empty(),
            ExitFlags::all(),
        )),
    );
    assert_eq!(analysis[gv_non_uniform], GlobalUse::READ | GlobalUse::WRITE);
}