1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Handle of an expression recorded as the source of a non-uniform value,
/// or `None` when no such expression has been recorded.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
// `UniformityRequirements` are defined as `0`, i.e. they are empty flags and
// never contribute a requirement.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Set of reasons why an expression or statement needs uniform control flow.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Produced by control/memory barrier statements and by
        /// `WorkGroupUniformLoad` statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Produced by `Derivative` expressions. Evaluates to `0` while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Produced by `ImageSample` expressions whose level uses implicit
        /// derivatives. Evaluates to `0` while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity summary attached to an expression or to a whole function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// An expression whose value is considered non-uniform and which this
    /// result depends on; `None` when the result is treated as uniform.
    pub non_uniform_result: NonUniformResult,
    /// Uniform-control-flow requirements accumulated for this item.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    /// Ways in which a piece of control flow may leave early.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Set by a `Return` statement that is reached while a uniformity
        /// disruptor is already active.
        const MAY_RETURN = 0x1;
        /// Set by a `Kill` statement that is reached while a uniformity
        /// disruptor is already active, or by a call to a function whose
        /// `may_kill` is `true`.
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity information for a statement or a block of statements.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniform source and accumulated requirements.
    result: Uniformity,
    /// Early-exit possibilities collected so far.
    exit: ExitFlags,
}
87
88impl ops::BitOr for FunctionUniformity {
89 type Output = Self;
90 fn bitor(self, other: Self) -> Self {
91 FunctionUniformity {
92 result: Uniformity {
93 non_uniform_result: self
94 .result
95 .non_uniform_result
96 .or(other.result.non_uniform_result),
97 requirements: self.result.requirements | other.result.requirements,
98 },
99 exit: self.exit | other.exit,
100 }
101 }
102}
103
104impl FunctionUniformity {
105 const fn new() -> Self {
106 FunctionUniformity {
107 result: Uniformity::new(),
108 exit: ExitFlags::empty(),
109 }
110 }
111
112 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114 if self.exit.contains(ExitFlags::MAY_RETURN) {
115 Some(UniformityDisruptor::Return)
116 } else if self.exit.contains(ExitFlags::MAY_KILL) {
117 Some(UniformityDisruptor::Discard)
118 } else {
119 None
120 }
121 }
122}
123
bitflags::bitflags! {
    /// Ways in which a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Recorded for ordinary expression references (see `add_ref`) and,
        /// together with `WRITE`, for `Atomic` statements.
        const READ = 0x1;
        /// Recorded for the target of `Store` and `ImageStore` statements and,
        /// together with `READ`, for `Atomic` statements.
        const WRITE = 0x2;
        /// Recorded for `ImageQuery` and `ArrayLength` expressions, and for
        /// globals that are only reachable through a named expression.
        const QUERY = 0x4;
        /// Recorded for the image of an `ImageAtomic` statement.
        const ATOMIC = 0x8;
    }
}
140
/// An (image, sampler) pair of global variables that are sampled together.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
/// Per-expression analysis results stored in [`FunctionInfo`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity computed for this expression.
    pub uniformity: Uniformity,

    /// Number of references to this expression counted so far; incremented by
    /// `add_ref_impl` and `add_assignable_ref`.
    pub ref_count: usize,

    /// Global variable this expression refers to, if any. Set directly for
    /// `GlobalVariable` expressions and carried over from the base of
    /// `Access` / `AccessIndex` expressions.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
186
187impl ExpressionInfo {
188 const fn new() -> Self {
189 ExpressionInfo {
190 uniformity: Uniformity::new(),
191 ref_count: 0,
192 assignable_global: None,
193 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
195 kind: crate::ScalarKind::Bool,
196 width: 0,
197 })),
198 }
199 }
200}
201
/// Where a sampled image or sampler comes from: a module-level global, or one
/// of the enclosing function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A global variable, identified by its handle.
    Global(Handle<crate::GlobalVariable>),
    /// A function argument, identified by its index in the argument list.
    Argument(u32),
}
209
210impl GlobalOrArgument {
211 fn from_expression(
212 expression_arena: &Arena<crate::Expression>,
213 expression: Handle<crate::Expression>,
214 ) -> Result<GlobalOrArgument, ExpressionError> {
215 Ok(match expression_arena[expression] {
216 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
217 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
218 crate::Expression::Access { base, .. }
219 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
220 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
221 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
222 },
223 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
224 })
225 }
226}
227
/// An (image, sampler) pair that is sampled together where at least one side
/// is not (yet) known to be a global variable. Pairs where both sides are
/// globals are stored as [`SamplingKey`] instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Origin of the image operand.
    image: GlobalOrArgument,
    /// Origin of the sampler operand.
    sampler: GlobalOrArgument,
}
235
/// Analysis results for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used in; initialised to all stages
    /// by `ModuleInfo::process_function`.
    pub available_stages: ShaderStages,
    /// Uniformity of the function as a whole, taken from its body.
    pub uniformity: Uniformity,
    /// Whether analysing the body produced the `MAY_KILL` exit flag.
    pub may_kill: bool,

    /// All (image, sampler) pairs of globals sampled by this function,
    /// including those inherited from callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How each global variable is used; indexed by the global's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression info; indexed by the expression's handle index.
    expressions: Box<[ExpressionInfo]>,

    /// Sampling pairs that still involve at least one function argument.
    sampling: crate::FastHashSet<Sampling>,

    /// NOTE(review): initialised to `false` in the visible code; where it is
    /// set to `true` is not visible here.
    pub dual_source_blending: bool,

    /// Diagnostic filter leaf copied from the function; used to look up the
    /// severity of uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
306
307impl FunctionInfo {
308 pub const fn global_variable_count(&self) -> usize {
309 self.global_uses.len()
310 }
311 pub const fn expression_count(&self) -> usize {
312 self.expressions.len()
313 }
314 pub fn dominates_global_use(&self, other: &Self) -> bool {
315 for (self_global_uses, other_global_uses) in
316 self.global_uses.iter().zip(other.global_uses.iter())
317 {
318 if !self_global_uses.contains(*other_global_uses) {
319 return false;
320 }
321 }
322 true
323 }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327 type Output = GlobalUse;
328 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329 &self.global_uses[handle.index()]
330 }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334 type Output = ExpressionInfo;
335 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336 &self.expressions[handle.index()]
337 }
338}
339
/// Reason why control flow at some point is not considered uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch or switch was conditioned on this non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A `Return` may already have been taken (derived from `MAY_RETURN`).
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A `Kill` may already have been executed (derived from `MAY_KILL`).
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
351
352impl FunctionInfo {
353 #[must_use]
361 fn add_ref_impl(
362 &mut self,
363 expr: Handle<crate::Expression>,
364 global_use: GlobalUse,
365 ) -> NonUniformResult {
366 let info = &mut self.expressions[expr.index()];
367 info.ref_count += 1;
368 if let Some(global) = info.assignable_global {
370 self.global_uses[global.index()] |= global_use;
371 }
372 info.uniformity.non_uniform_result
373 }
374
375 #[must_use]
382 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
383 self.add_ref_impl(expr, GlobalUse::READ)
384 }
385
386 #[must_use]
405 fn add_assignable_ref(
406 &mut self,
407 expr: Handle<crate::Expression>,
408 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
409 ) -> NonUniformResult {
410 let info = &mut self.expressions[expr.index()];
411 info.ref_count += 1;
412 if let Some(global) = info.assignable_global {
415 if let Some(_old) = assignable_global.replace(global) {
416 unreachable!()
417 }
418 }
419 info.uniformity.non_uniform_result
420 }
421
422 fn process_call(
424 &mut self,
425 callee: &Self,
426 arguments: &[Handle<crate::Expression>],
427 expression_arena: &Arena<crate::Expression>,
428 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
429 self.sampling_set
430 .extend(callee.sampling_set.iter().cloned());
431 for sampling in callee.sampling.iter() {
432 let image_storage = match sampling.image {
435 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
436 GlobalOrArgument::Argument(i) => {
437 let Some(handle) = arguments.get(i as usize).cloned() else {
438 break;
440 };
441 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
442 |source| {
443 FunctionError::Expression { handle, source }
444 .with_span_handle(handle, expression_arena)
445 },
446 )?
447 }
448 };
449
450 let sampler_storage = match sampling.sampler {
451 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
452 GlobalOrArgument::Argument(i) => {
453 let Some(handle) = arguments.get(i as usize).cloned() else {
454 break;
456 };
457 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
458 |source| {
459 FunctionError::Expression { handle, source }
460 .with_span_handle(handle, expression_arena)
461 },
462 )?
463 }
464 };
465
466 match (image_storage, sampler_storage) {
471 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
472 self.sampling_set.insert(SamplingKey { image, sampler });
473 }
474 (image, sampler) => {
475 self.sampling.insert(Sampling { image, sampler });
476 }
477 }
478 }
479
480 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
482 *mine |= *other;
483 }
484
485 Ok(FunctionUniformity {
486 result: callee.uniformity.clone(),
487 exit: if callee.may_kill {
488 ExitFlags::MAY_KILL
489 } else {
490 ExitFlags::empty()
491 },
492 })
493 }
494
    /// Computes and stores the [`ExpressionInfo`] for the expression `handle`.
    ///
    /// The expression's uniformity is derived from its operands: every operand
    /// is registered through `add_ref` (or a variant), which bumps its
    /// reference count, records global usage, and yields the operand's
    /// non-uniform source. The type is resolved through `resolve_context`.
    /// The resulting entry replaces the one stored at `handle`, with a
    /// reference count of zero.
    ///
    /// `other_functions` supplies the infos used for `CallResult` expressions.
    /// `capabilities` is consulted for non-uniform indexing of binding arrays.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::MissingCapabilities` when a binding array is indexed
    ///   by a non-uniform value without the needed capability.
    /// * Any error from `GlobalOrArgument::from_expression` for the image or
    ///   sampler of an `ImageSample`.
    /// * Any error from type resolution.
    // `.or(self.add_ref(..))` is deliberately eager: the argument is always
    // evaluated, so every operand is counted even when an earlier operand
    // already produced a non-uniform source.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in by `GlobalVariable`, or inherited from the base of
        // `Access` / `AccessIndex` via `add_assignable_ref`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capability needed if a binding array is indexed non-uniformly.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => sto,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            _ => {
                                // Any other element type: decide by the address
                                // space of the global the array lives in. The
                                // code treats anything but a uniform/storage
                                // global as impossible here.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => uni,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Reject non-uniform indexing of a binding array when the
                // matching capability is not enabled.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // These have no operands and are treated as uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // First non-uniform component wins; all components are counted.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                // Only these three built-in bindings are treated as uniform;
                // every other argument is its own non-uniform source.
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided purely by the address space; storage
                // globals count as uniform only without `STORE` access.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup => true,
                    As::Uniform | As::PushConstant => true,
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always their own non-uniform source.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Record which image is sampled with which sampler. Pairs of
                // two globals are final; anything involving an argument is
                // kept in `sampling` instead.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    // Sampling with implicit derivatives needs uniform control flow.
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is only queried, not read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                // The result is only as non-uniform as its operand, but the
                // derivative itself requires uniform control flow.
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result inherits the uniformity recorded for the callee.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // Treated as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // Only the length is queried; the array contents are not read.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are looked up from the infos already stored in `self`.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // Replace the placeholder entry; references are counted from zero.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
836
    /// Analyses a block of statements and returns its combined uniformity.
    ///
    /// Walks `statements` in order, registering every expression a statement
    /// uses via `add_ref` and friends, recursing into nested blocks, and
    /// merging the per-statement results with `|`.
    ///
    /// `disruptor` is the reason, if any, why control flow is already
    /// considered non-uniform on entry. It is updated after each statement
    /// with the statement's exit disruptor, so later statements see earlier
    /// possible returns/kills.
    ///
    /// `other_functions` supplies callee infos for `Call` statements.
    /// `diagnostic_filter_arena` is used to look up the severity of
    /// uniformity diagnostics.
    ///
    /// # Errors
    ///
    /// Propagates errors from nested blocks, from `process_call`, and from
    /// reporting a `FunctionError::NonUniformControlFlow` diagnostic.
    // `.or(..)` arguments are deliberately evaluated eagerly here.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        // An emitted expression that needs uniform control flow
                        // is diagnosed when a disruptor is active and
                        // control-flow uniformity validation is enabled.
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                // NOTE(review): presumably returns `Err` only
                                // when the severity makes this a hard error and
                                // logs otherwise; confirm against `report_diag`.
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as an early exit when it happens under an
                // already active disruptor.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is counted as referenced; its non-uniformity
                    // is intentionally not propagated.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // Inside the branches, an existing disruptor is kept;
                    // otherwise a non-uniform condition becomes the disruptor.
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case passes its possible early exit on
                        // to the next case; otherwise the next case starts
                        // again from the switch-level disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block runs after the body, so it also sees
                    // any early exit the body may have taken.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        // The returned value's non-uniformity becomes the
                        // statement's (and so the function's) result.
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    // Like `Kill`: only an early exit under an active disruptor.
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    // The pointer's global, if any, is marked as written.
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // An atomic both reads and writes its target.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
            };

            // Statements after a possible early exit run under that disruptor,
            // unless an earlier disruptor is already in effect.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1161}
1162
1163impl ModuleInfo {
1164 pub(super) fn process_const_expression(
1166 &mut self,
1167 handle: Handle<crate::Expression>,
1168 resolve_context: &ResolveContext,
1169 gctx: crate::proc::GlobalCtx,
1170 ) -> Result<(), super::ConstExpressionError> {
1171 self.const_expression_types[handle.index()] =
1172 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1173 Ok(())
1174 }
1175
1176 pub(super) fn process_function(
1179 &self,
1180 fun: &crate::Function,
1181 module: &crate::Module,
1182 flags: ValidationFlags,
1183 capabilities: super::Capabilities,
1184 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1185 let mut info = FunctionInfo {
1186 flags,
1187 available_stages: ShaderStages::all(),
1188 uniformity: Uniformity::new(),
1189 may_kill: false,
1190 sampling_set: crate::FastHashSet::default(),
1191 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1192 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1193 sampling: crate::FastHashSet::default(),
1194 dual_source_blending: false,
1195 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1196 };
1197 let resolve_context =
1198 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1199
1200 for (handle, _) in fun.expressions.iter() {
1201 if let Err(source) = info.process_expression(
1202 handle,
1203 &fun.expressions,
1204 &self.functions,
1205 &resolve_context,
1206 capabilities,
1207 ) {
1208 return Err(FunctionError::Expression { handle, source }
1209 .with_span_handle(handle, &fun.expressions));
1210 }
1211 }
1212
1213 for (_, expr) in fun.local_variables.iter() {
1214 if let Some(init) = expr.init {
1215 let _ = info.add_ref(init);
1216 }
1217 }
1218
1219 let uniformity = info.process_block(
1220 &fun.body,
1221 &self.functions,
1222 None,
1223 &fun.expressions,
1224 &module.diagnostic_filters,
1225 )?;
1226 info.uniformity = uniformity.result;
1227 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1228
1229 for &handle in fun.named_expressions.keys() {
1235 if let Some(global) = info[handle].assignable_global {
1236 if info.global_uses[global.index()].is_empty() {
1237 info.global_uses[global.index()] = GlobalUse::QUERY;
1238 }
1239 }
1240 }
1241
1242 Ok(info)
1243 }
1244
1245 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1246 &self.entry_points[index]
1247 }
1248}
1249
1250#[test]
1251fn uniform_control_flow() {
1252 use crate::{Expression as E, Statement as S};
1253
1254 let mut type_arena = crate::UniqueArena::new();
1255 let ty = type_arena.insert(
1256 crate::Type {
1257 name: None,
1258 inner: crate::TypeInner::Vector {
1259 size: crate::VectorSize::Bi,
1260 scalar: crate::Scalar::F32,
1261 },
1262 },
1263 Default::default(),
1264 );
1265 let mut global_var_arena = Arena::new();
1266 let non_uniform_global = global_var_arena.append(
1267 crate::GlobalVariable {
1268 name: None,
1269 init: None,
1270 ty,
1271 space: crate::AddressSpace::Handle,
1272 binding: None,
1273 },
1274 Default::default(),
1275 );
1276 let uniform_global = global_var_arena.append(
1277 crate::GlobalVariable {
1278 name: None,
1279 init: None,
1280 ty,
1281 binding: None,
1282 space: crate::AddressSpace::Uniform,
1283 },
1284 Default::default(),
1285 );
1286
1287 let mut expressions = Arena::new();
1288 let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1290 let derivative_expr = expressions.append(
1292 E::Derivative {
1293 axis: crate::DerivativeAxis::X,
1294 ctrl: crate::DerivativeControl::None,
1295 expr: constant_expr,
1296 },
1297 Default::default(),
1298 );
1299 let emit_range_constant_derivative = expressions.range_from(0);
1300 let non_uniform_global_expr =
1301 expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1302 let uniform_global_expr =
1303 expressions.append(E::GlobalVariable(uniform_global), Default::default());
1304 let emit_range_globals = expressions.range_from(2);
1305
1306 let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1308 let access_expr = expressions.append(
1310 E::AccessIndex {
1311 base: non_uniform_global_expr,
1312 index: 1,
1313 },
1314 Default::default(),
1315 );
1316 let emit_range_query_access_globals = expressions.range_from(2);
1317
1318 let mut info = FunctionInfo {
1319 flags: ValidationFlags::all(),
1320 available_stages: ShaderStages::all(),
1321 uniformity: Uniformity::new(),
1322 may_kill: false,
1323 sampling_set: crate::FastHashSet::default(),
1324 global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1325 expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1326 sampling: crate::FastHashSet::default(),
1327 dual_source_blending: false,
1328 diagnostic_filter_leaf: None,
1329 };
1330 let resolve_context = ResolveContext {
1331 constants: &Arena::new(),
1332 overrides: &Arena::new(),
1333 types: &type_arena,
1334 special_types: &crate::SpecialTypes::default(),
1335 global_vars: &global_var_arena,
1336 local_vars: &Arena::new(),
1337 functions: &Arena::new(),
1338 arguments: &[],
1339 };
1340 for (handle, _) in expressions.iter() {
1341 info.process_expression(
1342 handle,
1343 &expressions,
1344 &[],
1345 &resolve_context,
1346 super::Capabilities::empty(),
1347 )
1348 .unwrap();
1349 }
1350 assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1351 assert_eq!(info[uniform_global_expr].ref_count, 1);
1352 assert_eq!(info[query_expr].ref_count, 0);
1353 assert_eq!(info[access_expr].ref_count, 0);
1354 assert_eq!(info[non_uniform_global], GlobalUse::empty());
1355 assert_eq!(info[uniform_global], GlobalUse::QUERY);
1356
1357 let stmt_emit1 = S::Emit(emit_range_globals.clone());
1358 let stmt_if_uniform = S::If {
1359 condition: uniform_global_expr,
1360 accept: crate::Block::new(),
1361 reject: vec![
1362 S::Emit(emit_range_constant_derivative.clone()),
1363 S::Store {
1364 pointer: constant_expr,
1365 value: derivative_expr,
1366 },
1367 ]
1368 .into(),
1369 };
1370 assert_eq!(
1371 info.process_block(
1372 &vec![stmt_emit1, stmt_if_uniform].into(),
1373 &[],
1374 None,
1375 &expressions,
1376 &Arena::new(),
1377 ),
1378 Ok(FunctionUniformity {
1379 result: Uniformity {
1380 non_uniform_result: None,
1381 requirements: UniformityRequirements::DERIVATIVE,
1382 },
1383 exit: ExitFlags::empty(),
1384 }),
1385 );
1386 assert_eq!(info[constant_expr].ref_count, 2);
1387 assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1388
1389 let stmt_emit2 = S::Emit(emit_range_globals.clone());
1390 let stmt_if_non_uniform = S::If {
1391 condition: non_uniform_global_expr,
1392 accept: vec![
1393 S::Emit(emit_range_constant_derivative),
1394 S::Store {
1395 pointer: constant_expr,
1396 value: derivative_expr,
1397 },
1398 ]
1399 .into(),
1400 reject: crate::Block::new(),
1401 };
1402 {
1403 let block_info = info.process_block(
1404 &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1405 &[],
1406 None,
1407 &expressions,
1408 &Arena::new(),
1409 );
1410 if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1411 assert_eq!(info[derivative_expr].ref_count, 2);
1412 } else {
1413 assert_eq!(
1414 block_info,
1415 Err(FunctionError::NonUniformControlFlow(
1416 UniformityRequirements::DERIVATIVE,
1417 derivative_expr,
1418 UniformityDisruptor::Expression(non_uniform_global_expr)
1419 )
1420 .with_span()),
1421 );
1422 assert_eq!(info[derivative_expr].ref_count, 1);
1423
1424 let mut diagnostic_filters = Arena::new();
1426 let diagnostic_filter_leaf = diagnostic_filters.append(
1427 DiagnosticFilterNode {
1428 inner: crate::diagnostic_filter::DiagnosticFilter {
1429 new_severity: crate::diagnostic_filter::Severity::Off,
1430 triggering_rule:
1431 crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1432 StandardFilterableTriggeringRule::DerivativeUniformity,
1433 ),
1434 },
1435 parent: None,
1436 },
1437 crate::Span::default(),
1438 );
1439 let mut info = FunctionInfo {
1440 diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1441 ..info.clone()
1442 };
1443
1444 let block_info = info.process_block(
1445 &vec![stmt_emit2, stmt_if_non_uniform].into(),
1446 &[],
1447 None,
1448 &expressions,
1449 &diagnostic_filters,
1450 );
1451 assert_eq!(
1452 block_info,
1453 Ok(FunctionUniformity {
1454 result: Uniformity {
1455 non_uniform_result: None,
1456 requirements: UniformityRequirements::DERIVATIVE,
1457 },
1458 exit: ExitFlags::empty()
1459 }),
1460 );
1461 assert_eq!(info[derivative_expr].ref_count, 2);
1462 }
1463 }
1464 assert_eq!(info[non_uniform_global], GlobalUse::READ);
1465
1466 let stmt_emit3 = S::Emit(emit_range_globals);
1467 let stmt_return_non_uniform = S::Return {
1468 value: Some(non_uniform_global_expr),
1469 };
1470 assert_eq!(
1471 info.process_block(
1472 &vec![stmt_emit3, stmt_return_non_uniform].into(),
1473 &[],
1474 Some(UniformityDisruptor::Return),
1475 &expressions,
1476 &Arena::new(),
1477 ),
1478 Ok(FunctionUniformity {
1479 result: Uniformity {
1480 non_uniform_result: Some(non_uniform_global_expr),
1481 requirements: UniformityRequirements::empty(),
1482 },
1483 exit: ExitFlags::MAY_RETURN,
1484 }),
1485 );
1486 assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1487
1488 let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1490 let stmt_assign = S::Store {
1491 pointer: access_expr,
1492 value: query_expr,
1493 };
1494 let stmt_return_pointer = S::Return {
1495 value: Some(access_expr),
1496 };
1497 let stmt_kill = S::Kill;
1498 assert_eq!(
1499 info.process_block(
1500 &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1501 &[],
1502 Some(UniformityDisruptor::Discard),
1503 &expressions,
1504 &Arena::new(),
1505 ),
1506 Ok(FunctionUniformity {
1507 result: Uniformity {
1508 non_uniform_result: Some(non_uniform_global_expr),
1509 requirements: UniformityRequirements::empty(),
1510 },
1511 exit: ExitFlags::all(),
1512 }),
1513 );
1514 assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1515}