1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Non-uniformity witness for a value.
///
/// `None` means the value is treated as uniform. `Some(handle)` names an
/// expression whose (possibly) non-uniform result this value depends on; that
/// handle is what gets reported in `UniformityDisruptor::Expression`.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
/// Master switch for derivative-related uniformity requirements.
///
/// While `true`, `UniformityRequirements::DERIVATIVE` and
/// `UniformityRequirements::IMPLICIT_LEVEL` are both defined as `0`. Derivative
/// expressions and implicit-level image sampling then contribute no
/// requirement, and so never trigger a non-uniform-control-flow diagnostic.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Operations recorded as needing uniform control flow.
    ///
    /// Every requirement is accumulated into the enclosing `Uniformity`. The
    /// requirements of emitted expressions are additionally checked in
    /// `FunctionInfo::process_block` against any active `UniformityDisruptor`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// A control barrier, memory barrier, or `WorkGroupUniformLoad` statement.
        const WORK_GROUP_BARRIER = 0x1;
        /// A `Derivative` expression. Defined as `0` (no requirement) while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// An `ImageSample` whose level uses implicit derivatives. Defined as
        /// `0` while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
        /// A cooperative load, multiply-add, or store.
        const COOP_OPS = 0x8;
    }
}
35
/// Uniformity summary of an expression, a statement block, or a whole function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// `Some(handle)` if the value may be non-uniform, naming the expression
    /// that introduced the non-uniformity; `None` if it is treated as uniform.
    pub non_uniform_result: NonUniformResult,
    /// Operations encountered here that are recorded as needing uniform
    /// control flow.
    pub requirements: UniformityRequirements,
}
57
58impl Uniformity {
59 const fn new() -> Self {
60 Uniformity {
61 non_uniform_result: None,
62 requirements: UniformityRequirements::empty(),
63 }
64 }
65}
66
bitflags::bitflags! {
    /// Early exits observed in a block that may have happened for only some
    /// invocations.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` statement was reached while a uniformity disruptor was
        /// already active.
        const MAY_RETURN = 0x1;
        /// A `Kill` statement was reached while a uniformity disruptor was
        /// active, or a called function may kill.
        const MAY_KILL = 0x2;
    }
}
81
/// Uniformity summary for a block of statements (or a whole function body).
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniformity of any returned value, plus the requirements of the
    /// statements in the block.
    result: Uniformity,
    /// Early-exit behavior observed while processing the block.
    exit: ExitFlags,
}
88
89impl ops::BitOr for FunctionUniformity {
90 type Output = Self;
91 fn bitor(self, other: Self) -> Self {
92 FunctionUniformity {
93 result: Uniformity {
94 non_uniform_result: self
95 .result
96 .non_uniform_result
97 .or(other.result.non_uniform_result),
98 requirements: self.result.requirements | other.result.requirements,
99 },
100 exit: self.exit | other.exit,
101 }
102 }
103}
104
105impl FunctionUniformity {
106 const fn new() -> Self {
107 FunctionUniformity {
108 result: Uniformity::new(),
109 exit: ExitFlags::empty(),
110 }
111 }
112
113 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
115 if self.exit.contains(ExitFlags::MAY_RETURN) {
116 Some(UniformityDisruptor::Return)
117 } else if self.exit.contains(ExitFlags::MAY_KILL) {
118 Some(UniformityDisruptor::Discard)
119 } else {
120 None
121 }
122 }
123}
124
bitflags::bitflags! {
    /// How a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Referenced as an ordinary operand (via `add_ref`), or by an atomic
        /// statement.
        const READ = 0x1;
        /// Target of a `Store`, `ImageStore`, atomic, or cooperative store.
        const WRITE = 0x2;
        /// Operand of an `ImageQuery` or `ArrayLength`. It is also assigned to
        /// globals that are otherwise unused but referenced by a named
        /// expression.
        const QUERY = 0x4;
        /// Image operand of an `ImageAtomic` statement.
        const ATOMIC = 0x8;
    }
}
141
/// An image and a sampler, both global variables, that are used together by an
/// `ImageSample` in this function or one of its callees.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The sampled image's global variable.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler's global variable.
    pub sampler: Handle<crate::GlobalVariable>,
}
149
/// Per-expression analysis results.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's result and the requirements it imposes.
    pub uniformity: Uniformity,

    /// Number of references to this expression from other expressions,
    /// statements, and local-variable initializers.
    pub ref_count: usize,

    /// The global variable this expression designates, if any.
    ///
    /// Set for `GlobalVariable` expressions and inherited by
    /// `Access`/`AccessIndex` expressions on them (see `add_assignable_ref`).
    /// When this expression is referenced, the usage is charged to this global.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression. Before the expression is processed,
    /// this holds the placeholder set by `ExpressionInfo::new`.
    pub ty: TypeResolution,
}
187
188impl ExpressionInfo {
189 const fn new() -> Self {
190 ExpressionInfo {
191 uniformity: Uniformity::new(),
192 ref_count: 0,
193 assignable_global: None,
194 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
196 kind: crate::ScalarKind::Bool,
197 width: 0,
198 })),
199 }
200 }
201}
202
/// Where an image or sampler operand comes from, as far as a single function
/// can tell.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A global variable (possibly accessed through `Access`/`AccessIndex`).
    Global(Handle<crate::GlobalVariable>),
    /// The function argument at this index; resolved at each call site.
    Argument(u32),
}
210
211impl GlobalOrArgument {
212 fn from_expression(
213 expression_arena: &Arena<crate::Expression>,
214 expression: Handle<crate::Expression>,
215 ) -> Result<GlobalOrArgument, ExpressionError> {
216 Ok(match expression_arena[expression] {
217 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
218 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
219 crate::Expression::Access { base, .. }
220 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
221 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
222 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
223 },
224 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
225 })
226 }
227}
228
/// An image/sampler pairing in which at least one side is a function argument.
///
/// Such pairs cannot be resolved to globals inside the function itself.
/// `process_call` resolves them against the actual arguments at each call site.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Source of the sampled image.
    image: GlobalOrArgument,
    /// Source of the sampler.
    sampler: GlobalOrArgument,
}
236
/// Analysis results for one function or entry point.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the function was analyzed with; consulted by
    /// `process_block` for `CONTROL_FLOW_UNIFORMITY`.
    // NOTE(review): this field is read in `process_block`, so the `dead_code`
    // allowance may be stale — confirm before removing.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used from. Starts as all stages;
    /// presumably narrowed elsewhere in the validator — TODO confirm.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result and its accumulated requirements.
    pub uniformity: Uniformity,
    /// Whether the function (or a callee) may kill under non-uniform control flow.
    pub may_kill: bool,

    /// Global image/sampler pairs used together, directly or through callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Usage flags per global variable, indexed by the global's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression results, indexed by the expression's handle index.
    expressions: Box<[ExpressionInfo]>,

    /// Image/sampler pairs involving function arguments; callers resolve them
    /// in `process_call`.
    sampling: crate::FastHashSet<Sampling>,

    /// Whether dual-source blending is used. Not computed in this analysis:
    /// `process_function` initializes it to `false`.
    pub dual_source_blending: bool,

    /// Innermost diagnostic filter in effect for the function. It is used to
    /// look up the severity of uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
307
308impl FunctionInfo {
309 pub const fn global_variable_count(&self) -> usize {
310 self.global_uses.len()
311 }
312 pub const fn expression_count(&self) -> usize {
313 self.expressions.len()
314 }
315 pub fn dominates_global_use(&self, other: &Self) -> bool {
316 for (self_global_uses, other_global_uses) in
317 self.global_uses.iter().zip(other.global_uses.iter())
318 {
319 if !self_global_uses.contains(*other_global_uses) {
320 return false;
321 }
322 }
323 true
324 }
325}
326
327impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
328 type Output = GlobalUse;
329 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
330 &self.global_uses[handle.index()]
331 }
332}
333
334impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
335 type Output = ExpressionInfo;
336 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
337 &self.expressions[handle.index()]
338 }
339}
340
/// Reason why control flow may be non-uniform at some point in a function.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// An enclosing `If` condition or `Switch` selector depends on this
    /// expression's possibly non-uniform result.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier `Return` was reached while control flow was already
    /// non-uniform.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier `Kill` was reached under non-uniform control flow, or an
    /// earlier call was to a function that may kill.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
352
353impl FunctionInfo {
354 #[must_use]
362 fn add_ref_impl(
363 &mut self,
364 expr: Handle<crate::Expression>,
365 global_use: GlobalUse,
366 ) -> NonUniformResult {
367 let info = &mut self.expressions[expr.index()];
368 info.ref_count += 1;
369 if let Some(global) = info.assignable_global {
371 self.global_uses[global.index()] |= global_use;
372 }
373 info.uniformity.non_uniform_result
374 }
375
376 pub(super) fn insert_global_use(
385 &mut self,
386 global_use: GlobalUse,
387 global: Handle<crate::GlobalVariable>,
388 ) {
389 self.global_uses[global.index()] |= global_use;
390 }
391
392 #[must_use]
399 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
400 self.add_ref_impl(expr, GlobalUse::READ)
401 }
402
403 #[must_use]
422 fn add_assignable_ref(
423 &mut self,
424 expr: Handle<crate::Expression>,
425 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
426 ) -> NonUniformResult {
427 let info = &mut self.expressions[expr.index()];
428 info.ref_count += 1;
429 if let Some(global) = info.assignable_global {
432 if let Some(_old) = assignable_global.replace(global) {
433 unreachable!()
434 }
435 }
436 info.uniformity.non_uniform_result
437 }
438
439 fn process_call(
441 &mut self,
442 callee: &Self,
443 arguments: &[Handle<crate::Expression>],
444 expression_arena: &Arena<crate::Expression>,
445 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
446 self.sampling_set
447 .extend(callee.sampling_set.iter().cloned());
448 for sampling in callee.sampling.iter() {
449 let image_storage = match sampling.image {
452 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
453 GlobalOrArgument::Argument(i) => {
454 let Some(handle) = arguments.get(i as usize).cloned() else {
455 break;
457 };
458 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
459 |source| {
460 FunctionError::Expression { handle, source }
461 .with_span_handle(handle, expression_arena)
462 },
463 )?
464 }
465 };
466
467 let sampler_storage = match sampling.sampler {
468 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
469 GlobalOrArgument::Argument(i) => {
470 let Some(handle) = arguments.get(i as usize).cloned() else {
471 break;
473 };
474 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
475 |source| {
476 FunctionError::Expression { handle, source }
477 .with_span_handle(handle, expression_arena)
478 },
479 )?
480 }
481 };
482
483 match (image_storage, sampler_storage) {
488 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
489 self.sampling_set.insert(SamplingKey { image, sampler });
490 }
491 (image, sampler) => {
492 self.sampling.insert(Sampling { image, sampler });
493 }
494 }
495 }
496
497 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
499 *mine |= *other;
500 }
501
502 Ok(FunctionUniformity {
503 result: callee.uniformity.clone(),
504 exit: if callee.may_kill {
505 ExitFlags::MAY_KILL
506 } else {
507 ExitFlags::empty()
508 },
509 })
510 }
511
    /// Analyzes the expression `handle` and stores the result in this
    /// function's per-expression table.
    ///
    /// For the expression it computes:
    /// - its `Uniformity` (non-uniformity witness plus requirements);
    /// - its `assignable_global`;
    /// - its resolved type.
    ///
    /// It also records references to the operands via `add_ref`,
    /// `add_ref_impl`, or `add_assignable_ref`. Operand entries are read here,
    /// so they are expected to have been processed already; `process_function`
    /// walks the expression arena in order.
    ///
    /// # Errors
    ///
    /// - `ExpressionError::MissingCapabilities` if a binding array is indexed
    ///   with a non-uniform index and `capabilities` lacks the matching
    ///   capability.
    /// - `ExpressionError::ExpectedGlobalOrArgument` if an `ImageSample`'s image
    ///   or sampler operand is not a global, an argument, or an access into a
    ///   global.
    /// - Any error from type resolution.
    //
    // `Option::or` (not `or_else`) is used deliberately throughout. Every
    // operand's `add_ref` must run for its reference-counting and global-use
    // side effects, even when an earlier operand already determined the result.
    // Clippy's `or_fun_call` would suggest the lazy form, which would skip them.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Dynamically indexing a binding array with a non-uniform index
                // needs a capability. Which one depends on the element type,
                // or, for buffers, on the global's address space.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => {
                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                                _ => {
                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                            },
                            crate::TypeInner::Sampler { .. } => {
                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                            }
                            _ => {
                                // Assumes any other binding array is a buffer
                                // global referenced directly, in the uniform or
                                // storage address space.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => {
                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        crate::AddressSpace::Storage { .. } => {
                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                // Only the listed per-workgroup built-ins are treated as
                // uniform; every other argument is conservatively non-uniform.
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                // References to this expression are charged to the global.
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup | As::TaskPayload => true,
                    As::Uniform | As::Immediate => true,
                    // Storage is treated as uniform only when it is not writable.
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Locals are conservatively treated as non-uniform.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Record the image/sampler pairing. Pairs involving arguments
                // are resolved later, at call sites, by `process_call`.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is charged a QUERY use rather than a READ.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call's result inherits the callee's summary. This assumes
            // `other_functions` already holds the callee's analysis.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            // Results of these operations are conservatively non-uniform.
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // The result of a workgroup-uniform load is treated as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // The array is charged a QUERY use rather than a READ.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::CooperativeLoad { ref data, .. } => Uniformity {
                non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)),
                requirements: UniformityRequirements::COOP_OPS,
            },
            E::CooperativeMultiplyAdd { a, b, c } => Uniformity {
                non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))),
                requirements: UniformityRequirements::COOP_OPS,
            },
        };

        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // Replace the placeholder entry wholesale. The reference count starts
        // at zero and is bumped later as other expressions and statements refer
        // to this one.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
866
    /// Analyzes `statements` and returns their combined `FunctionUniformity`.
    ///
    /// `disruptor` is the reason control flow may already be non-uniform when
    /// the block is entered, or `None`.
    /// - It is passed into nested blocks, combined with any non-uniform `If`
    ///   condition or `Switch` selector.
    /// - It is updated after every statement that may exit early, so later
    ///   statements in the same block see it.
    ///
    /// Along the way this records references to every operand expression,
    /// updating reference counts and global-use flags. At `Call` statements it
    /// merges the callee's information via `process_call`.
    ///
    /// # Errors
    ///
    /// - A `FunctionError::NonUniformControlFlow` diagnostic can be raised. This
    ///   happens when `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set and an
    ///   emitted expression with non-empty requirements is reached under a
    ///   disruptor. The diagnostic goes to `report_diag` with the severity found
    ///   for the `DerivativeUniformity` rule. Depending on that severity, it is
    ///   presumably either returned as an error or only logged.
    /// - Errors from `process_call` and from nested blocks are propagated.
    //
    // Eager `Option::or` is intentional, e.g. in `CooperativeStore`. Every
    // `add_ref`/`add_ref_impl` call must run for its bookkeeping side effects,
    // even when an earlier operand already determined the result.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Requirements are checked at the point the expressions are
                    // emitted, against the disruptor active at this statement.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // Non-error diagnostics are logged using only
                                    // the error's `Display` text, without
                                    // rendering the span.
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as a disrupting exit when control flow is
                // already non-uniform here.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer's non-uniformity is deliberately unused. This
                    // statement is not checked against `disruptor` here; it only
                    // contributes a `WORK_GROUP_BARRIER` requirement.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // Both branches run under the outer disruptor. If there is
                    // none, they run under a non-uniform condition.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // Falling through carries this case's early exits into
                        // the next case. Otherwise the next case starts from the
                        // selector-based disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // `continuing` runs after the body, so the body's early
                    // exits act as disruptors for it.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // A return only counts as a disrupting exit when control flow is
                // already non-uniform here.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    // The call's result value is not checked here; this only
                    // merges the callee's sampling, global uses, and uniformity.
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
                S::CooperativeStore { target, ref data } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: self
                            .add_ref(target)
                            .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE))
                            .or(self.add_ref(data.stride)),
                        requirements: UniformityRequirements::COOP_OPS,
                    },
                    exit: ExitFlags::empty(),
                },
            };

            // Once a statement may have exited early for some invocations,
            // every later statement in this block runs under a disruptor.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1201}
1202
1203impl ModuleInfo {
1204 pub(super) fn process_const_expression(
1206 &mut self,
1207 handle: Handle<crate::Expression>,
1208 resolve_context: &ResolveContext,
1209 gctx: crate::proc::GlobalCtx,
1210 ) -> Result<(), super::ConstExpressionError> {
1211 self.const_expression_types[handle.index()] =
1212 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1213 Ok(())
1214 }
1215
1216 pub(super) fn process_function(
1219 &self,
1220 fun: &crate::Function,
1221 module: &crate::Module,
1222 flags: ValidationFlags,
1223 capabilities: super::Capabilities,
1224 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1225 let mut info = FunctionInfo {
1226 flags,
1227 available_stages: ShaderStages::all(),
1228 uniformity: Uniformity::new(),
1229 may_kill: false,
1230 sampling_set: crate::FastHashSet::default(),
1231 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1232 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1233 sampling: crate::FastHashSet::default(),
1234 dual_source_blending: false,
1235 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1236 };
1237 let resolve_context =
1238 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1239
1240 for (handle, _) in fun.expressions.iter() {
1241 if let Err(source) = info.process_expression(
1242 handle,
1243 &fun.expressions,
1244 &self.functions,
1245 &resolve_context,
1246 capabilities,
1247 ) {
1248 return Err(FunctionError::Expression { handle, source }
1249 .with_span_handle(handle, &fun.expressions));
1250 }
1251 }
1252
1253 for (_, expr) in fun.local_variables.iter() {
1254 if let Some(init) = expr.init {
1255 let _ = info.add_ref(init);
1256 }
1257 }
1258
1259 let uniformity = info.process_block(
1260 &fun.body,
1261 &self.functions,
1262 None,
1263 &fun.expressions,
1264 &module.diagnostic_filters,
1265 )?;
1266 info.uniformity = uniformity.result;
1267 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1268
1269 for &handle in fun.named_expressions.keys() {
1275 if let Some(global) = info[handle].assignable_global {
1276 if info.global_uses[global.index()].is_empty() {
1277 info.global_uses[global.index()] = GlobalUse::QUERY;
1278 }
1279 }
1280 }
1281
1282 Ok(info)
1283 }
1284
1285 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1286 &self.entry_points[index]
1287 }
1288}
1289
/// Exercises the uniformity analysis on hand-built blocks and checks the
/// bookkeeping it leaves behind in `FunctionInfo` (reference counts, global
/// uses, uniformity requirements and exit flags).
///
/// All steps share one `FunctionInfo`, so reference counts and global uses
/// accumulate across the `process_block` calls: the expected values below
/// depend on the order in which the steps run.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // Global in the `Handle` address space; the assertions below expect
    // values read from it to be non-uniform.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // Global in the `Uniform` address space; the first `If` below branches on it.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    // Expression arena layout (index in brackets):
    //   [0] constant_expr        [1] derivative_expr
    //   [2] non_uniform_global   [3] uniform_global
    //   [4] query_expr           [5] access_expr
    let mut expressions = Arena::new();
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // The derivative carries the `DERIVATIVE` uniformity requirement that the
    // `If` statements below are checked against.
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // `ArrayLength` is expected to register a `QUERY` use of the uniform
    // global rather than a `READ`.
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // Storing through this access later is expected to register a `WRITE`
    // use of the non-uniform global.
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Only the operands of `ArrayLength` / `AccessIndex` have been referenced
    // so far, and only `ArrayLength` records a global use.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Step 1: the derivative under a *uniform* condition is accepted, and its
    // requirement propagates into the block's result. The `Store` uses a
    // literal as its pointer; the analysis under test accepts this, since
    // pointer typing is not what this test checks.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Step 2: the same derivative under a *non-uniform* condition.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the fragment-stage requirement disabled, `DERIVATIVE` is the
            // empty set: the non-uniform branch is accepted, the result matches
            // step 1, and the `Store` inside it is still analyzed.
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty(),
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            // The derivative under non-uniform control flow is an error; the
            // `Store` after it is never reached, so the ref count stays at 1.
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // A `DerivativeUniformity` diagnostic filter with severity `Off`
            // silences the error, so the same block now analyzes successfully.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            // Work on a copy so the outer `info` keeps the state the later
            // steps expect.
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Step 3: returning a non-uniform value makes the block's result
    // non-uniform and sets `MAY_RETURN`.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Step 4: store through the access chain, discard, then return. Both exit
    // flags are set, and the store is recorded as a `WRITE` of the global the
    // access chain is rooted in.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}