1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// A handle to an expression responsible for a value being non-uniform, or
/// `None` if the analysis considers the value uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// When `true`, [`UniformityRequirements::DERIVATIVE`] and
/// [`UniformityRequirements::IMPLICIT_LEVEL`] are defined as empty flags, so
/// those requirements are never recorded or reported.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Kinds of uniform-control-flow requirements imposed by an expression or
    /// statement.
    ///
    /// When an emitted expression carries a non-empty requirement while control
    /// flow has already been disrupted, `process_block` reports
    /// `FunctionError::NonUniformControlFlow` through the diagnostic filters.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Recorded by control/memory barriers and `WorkGroupUniformLoad`.
        const WORK_GROUP_BARRIER = 0x1;
        /// Recorded by `Expression::Derivative`. Empty while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Recorded by `ImageSample` whose level uses implicit derivatives.
        /// Empty while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity of a value (an expression's result, or a function's return
/// value), together with the requirements its computation imposes.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// An expression making this value non-uniform, or `None` if the value
    /// is considered uniform.
    pub non_uniform_result: NonUniformResult,
    /// Uniform-control-flow requirements imposed by computing this value.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    /// Ways control may leave a block or function early, as tracked by the
    /// uniformity analysis.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` may be taken while control flow is already disrupted.
        const MAY_RETURN = 0x1;
        /// A `Kill` may be executed while control flow is already disrupted,
        /// or a called function has `may_kill` set.
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity summary of a statement, block, or whole function body.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniform source of any returned value, plus all collected requirements.
    result: Uniformity,
    /// Early exits that may have been taken.
    exit: ExitFlags,
}
87
88impl ops::BitOr for FunctionUniformity {
89 type Output = Self;
90 fn bitor(self, other: Self) -> Self {
91 FunctionUniformity {
92 result: Uniformity {
93 non_uniform_result: self
94 .result
95 .non_uniform_result
96 .or(other.result.non_uniform_result),
97 requirements: self.result.requirements | other.result.requirements,
98 },
99 exit: self.exit | other.exit,
100 }
101 }
102}
103
104impl FunctionUniformity {
105 const fn new() -> Self {
106 FunctionUniformity {
107 result: Uniformity::new(),
108 exit: ExitFlags::empty(),
109 }
110 }
111
112 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114 if self.exit.contains(ExitFlags::MAY_RETURN) {
115 Some(UniformityDisruptor::Return)
116 } else if self.exit.contains(ExitFlags::MAY_KILL) {
117 Some(UniformityDisruptor::Discard)
118 } else {
119 None
120 }
121 }
122}
123
bitflags::bitflags! {
    /// How a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Recorded for plain references via `add_ref`, and by `Atomic`.
        const READ = 0x1;
        /// Recorded by `Store`, `ImageStore`, and `Atomic`.
        const WRITE = 0x2;
        /// Recorded by `ImageQuery` and `ArrayLength`. Also set by
        /// `process_function` for globals reached only through named
        /// expressions.
        const QUERY = 0x4;
        /// Recorded by `ImageAtomic`.
        const ATOMIC = 0x8;
    }
}
140
/// A pair of global variables (an image and a sampler) used together in an
/// `ImageSample` expression, in this function or a callee.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The sampled image global.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
/// Analysis results for a single expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's value and the requirements it imposes.
    pub uniformity: Uniformity,

    /// How many times this expression is referenced by other expressions,
    /// statements, or local-variable initializers.
    pub ref_count: usize,

    /// The global variable this expression is rooted at. It is set for
    /// `GlobalVariable` expressions and propagated through the bases of
    /// `Access`/`AccessIndex`. References to this expression record
    /// [`GlobalUse`]s on that global.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The resolved type of the expression.
    pub ty: TypeResolution,
}
186
187impl ExpressionInfo {
188 const fn new() -> Self {
189 ExpressionInfo {
190 uniformity: Uniformity::new(),
191 ref_count: 0,
192 assignable_global: None,
193 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
195 kind: crate::ScalarKind::Bool,
196 width: 0,
197 })),
198 }
199 }
200}
201
/// Where a sampled image or sampler comes from: a global variable, or the
/// function argument at the given index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The function argument with this index.
    Argument(u32),
}
209
210impl GlobalOrArgument {
211 fn from_expression(
212 expression_arena: &Arena<crate::Expression>,
213 expression: Handle<crate::Expression>,
214 ) -> Result<GlobalOrArgument, ExpressionError> {
215 Ok(match expression_arena[expression] {
216 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
217 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
218 crate::Expression::Access { base, .. }
219 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
220 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
221 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
222 },
223 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
224 })
225 }
226}
227
/// An image/sampler pair used by `ImageSample` where at least one operand is
/// a function argument.
///
/// Kept so callers can resolve the arguments in `process_call`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Source of the sampled image.
    image: GlobalOrArgument,
    /// Source of the sampler.
    sampler: GlobalOrArgument,
}
235
/// Analysis results for a single function or entry point.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags in effect. `process_block` checks
    /// `CONTROL_FLOW_UNIFORMITY` here.
    // NOTE(review): the reason for this `allow` is not visible here, since the
    // field is read by `process_block`.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used in. It is initialized to all
    /// stages here and presumably narrowed elsewhere by the validator.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result and the requirements of its body.
    pub uniformity: Uniformity,
    /// Whether the body's summary has `ExitFlags::MAY_KILL` set.
    pub may_kill: bool,

    /// Image/sampler global pairs sampled by this function or its callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Use flags for each global variable, indexed by its handle's index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by the expression handle's index.
    expressions: Box<[ExpressionInfo]>,

    /// Samplings involving function arguments, left for callers to resolve.
    sampling: crate::FastHashSet<Sampling>,

    /// Initialized to `false` here. Presumably set elsewhere by the validator.
    pub dual_source_blending: bool,

    /// Leaf of this function's diagnostic filter chain, used when reporting
    /// derivative-uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
306
307impl FunctionInfo {
308 pub const fn global_variable_count(&self) -> usize {
309 self.global_uses.len()
310 }
311 pub const fn expression_count(&self) -> usize {
312 self.expressions.len()
313 }
314 pub fn dominates_global_use(&self, other: &Self) -> bool {
315 for (self_global_uses, other_global_uses) in
316 self.global_uses.iter().zip(other.global_uses.iter())
317 {
318 if !self_global_uses.contains(*other_global_uses) {
319 return false;
320 }
321 }
322 true
323 }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327 type Output = GlobalUse;
328 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329 &self.global_uses[handle.index()]
330 }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334 type Output = ExpressionInfo;
335 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336 &self.expressions[handle.index()]
337 }
338}
339
/// The reason control flow may be non-uniform at some point in a function.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// An `If` condition or `Switch` selector is this non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A `Return` may have been taken earlier under disrupted control flow.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A `Kill` may have been executed earlier, in this function or a callee.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
351
352impl FunctionInfo {
353 #[must_use]
361 fn add_ref_impl(
362 &mut self,
363 expr: Handle<crate::Expression>,
364 global_use: GlobalUse,
365 ) -> NonUniformResult {
366 let info = &mut self.expressions[expr.index()];
367 info.ref_count += 1;
368 if let Some(global) = info.assignable_global {
370 self.global_uses[global.index()] |= global_use;
371 }
372 info.uniformity.non_uniform_result
373 }
374
375 pub(super) fn insert_global_use(
384 &mut self,
385 global_use: GlobalUse,
386 global: Handle<crate::GlobalVariable>,
387 ) {
388 self.global_uses[global.index()] |= global_use;
389 }
390
391 #[must_use]
398 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
399 self.add_ref_impl(expr, GlobalUse::READ)
400 }
401
402 #[must_use]
421 fn add_assignable_ref(
422 &mut self,
423 expr: Handle<crate::Expression>,
424 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
425 ) -> NonUniformResult {
426 let info = &mut self.expressions[expr.index()];
427 info.ref_count += 1;
428 if let Some(global) = info.assignable_global {
431 if let Some(_old) = assignable_global.replace(global) {
432 unreachable!()
433 }
434 }
435 info.uniformity.non_uniform_result
436 }
437
438 fn process_call(
440 &mut self,
441 callee: &Self,
442 arguments: &[Handle<crate::Expression>],
443 expression_arena: &Arena<crate::Expression>,
444 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
445 self.sampling_set
446 .extend(callee.sampling_set.iter().cloned());
447 for sampling in callee.sampling.iter() {
448 let image_storage = match sampling.image {
451 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
452 GlobalOrArgument::Argument(i) => {
453 let Some(handle) = arguments.get(i as usize).cloned() else {
454 break;
456 };
457 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
458 |source| {
459 FunctionError::Expression { handle, source }
460 .with_span_handle(handle, expression_arena)
461 },
462 )?
463 }
464 };
465
466 let sampler_storage = match sampling.sampler {
467 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
468 GlobalOrArgument::Argument(i) => {
469 let Some(handle) = arguments.get(i as usize).cloned() else {
470 break;
472 };
473 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
474 |source| {
475 FunctionError::Expression { handle, source }
476 .with_span_handle(handle, expression_arena)
477 },
478 )?
479 }
480 };
481
482 match (image_storage, sampler_storage) {
487 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
488 self.sampling_set.insert(SamplingKey { image, sampler });
489 }
490 (image, sampler) => {
491 self.sampling.insert(Sampling { image, sampler });
492 }
493 }
494 }
495
496 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
498 *mine |= *other;
499 }
500
501 Ok(FunctionUniformity {
502 result: callee.uniformity.clone(),
503 exit: if callee.may_kill {
504 ExitFlags::MAY_KILL
505 } else {
506 ExitFlags::empty()
507 },
508 })
509 }
510
    // `.or(self.add_ref(..))` is deliberately eager throughout. `add_ref` bumps
    // `ref_count` and records global uses, so every operand must be visited even
    // when an earlier operand has already made the result non-uniform.
    #[allow(clippy::or_fun_call)]
    /// Analyzes expression `handle` and stores its [`ExpressionInfo`] in `self`.
    ///
    /// This does the following:
    /// - Computes the expression's [`Uniformity`].
    /// - References its operands, bumping their `ref_count` and recording
    ///   [`GlobalUse`]s.
    /// - Records the image/sampler pairs used by `ImageSample`.
    /// - Propagates `assignable_global` through access chains.
    /// - Resolves the expression's type.
    ///
    /// Operand entries in `self.expressions` are read here (`self[base]`,
    /// `self[index]`, and the type-resolution callback). This relies on operands
    /// being processed first, as `process_function` does by iterating the arena
    /// in order.
    ///
    /// # Errors
    ///
    /// - [`ExpressionError::MissingCapabilities`] when a binding array is
    ///   indexed by a non-uniform index without the required capability.
    /// - Errors from `GlobalOrArgument::from_expression` for `ImageSample`
    ///   operands.
    /// - Errors from type resolution.
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Set by `GlobalVariable`, or propagated from an access chain's base by
        // `add_assignable_ref`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capabilities required if this turns out to be a binding array
                // indexed with a non-uniform index.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        // Short aliases for the long capability names.
                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        // The element type decides which capability applies.
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            // Storage images and other images use different capabilities.
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => sto,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            // Any other element type: the capability is chosen by the
                            // global's address space (Uniform or Storage).
                            _ => {
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => uni,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Constant-like values are always uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // Every component is referenced. The first non-uniform one is kept.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                // Only these work-group built-ins are treated as uniform. Any
                // other argument makes this expression a non-uniform source.
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                // References to this expression will record uses of `gh`.
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided by the variable's address space.
                let uniform = match var.space {
                    // Function and private variables are treated as non-uniform.
                    As::Function | As::Private => false,
                    // Work-group and task-payload variables are treated as uniform.
                    As::WorkGroup | As::TaskPayload => true,
                    // Uniform and push-constant variables are treated as uniform.
                    As::Uniform | As::PushConstant => true,
                    // Storage is uniform only when it cannot be stored to.
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always treated as non-uniform sources.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // A pair of globals is recorded directly. A pair involving a
                // function argument is kept in `sampling` for callers to resolve
                // in `process_call`.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                // "nur" == "non-uniform result"
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    // Implicit-derivative sampling carries the IMPLICIT_LEVEL requirement.
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // Querying an image counts as a QUERY use, not a READ.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                // A derivative of a uniform value stays uniform. The expression
                // still carries the DERIVATIVE requirement.
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result has the callee's analyzed uniformity.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            // These results are treated as non-uniform sources.
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                // Treated as uniform. The barrier requirement is attached to the
                // `WorkGroupUniformLoad` statement instead (see `process_block`).
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
        };

        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // `ref_count` starts at zero. Later expressions and statements that
        // reference this one increment it.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
853
    // The `.or(...)` arguments in this function (`Option::map`, `exit_disruptor`)
    // are cheap and side-effect free, so the eager form is kept.
    #[allow(clippy::or_fun_call)]
    /// Walks `statements` and returns their combined [`FunctionUniformity`].
    ///
    /// `disruptor` is the reason, if any, that control flow reaching this block
    /// may already be non-uniform. When an emitted expression has requirements,
    /// a disruptor is active, and `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is
    /// set, a `FunctionError::NonUniformControlFlow` diagnostic is produced. Its
    /// severity comes from the `DerivativeUniformity` diagnostic filter.
    /// `report_diag` either propagates it as an error via `?` or logs it.
    ///
    /// An early exit in one statement becomes the disruptor for the statements
    /// that follow it in the block.
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Emitting is where expression requirements are checked
                    // against the current disruptor.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as an early exit when reached under an
                // active disruptor.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer's non-uniformity is not checked here. Only the
                    // barrier requirement is recorded.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // A non-uniform condition disrupts both branches.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // After a fall-through case, the next case also inherits that
                        // case's early exits. Otherwise it starts from `branch_disruptor`.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // `continuing` also sees early exits taken in `body`.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // The returned value's non-uniformity becomes the function
                // result's. A return only counts as an early exit when reached
                // under an active disruptor.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // From here on, operands are referenced for bookkeeping (ref
                // counts, global uses). Their non-uniformity does not affect the
                // statement's summary.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    // The callee's summary (uniformity, may-kill) becomes this
                    // statement's.
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
            };

            // An early exit here disrupts every later statement in this block.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1178}
1179
1180impl ModuleInfo {
1181 pub(super) fn process_const_expression(
1183 &mut self,
1184 handle: Handle<crate::Expression>,
1185 resolve_context: &ResolveContext,
1186 gctx: crate::proc::GlobalCtx,
1187 ) -> Result<(), super::ConstExpressionError> {
1188 self.const_expression_types[handle.index()] =
1189 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1190 Ok(())
1191 }
1192
1193 pub(super) fn process_function(
1196 &self,
1197 fun: &crate::Function,
1198 module: &crate::Module,
1199 flags: ValidationFlags,
1200 capabilities: super::Capabilities,
1201 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1202 let mut info = FunctionInfo {
1203 flags,
1204 available_stages: ShaderStages::all(),
1205 uniformity: Uniformity::new(),
1206 may_kill: false,
1207 sampling_set: crate::FastHashSet::default(),
1208 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1209 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1210 sampling: crate::FastHashSet::default(),
1211 dual_source_blending: false,
1212 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1213 };
1214 let resolve_context =
1215 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1216
1217 for (handle, _) in fun.expressions.iter() {
1218 if let Err(source) = info.process_expression(
1219 handle,
1220 &fun.expressions,
1221 &self.functions,
1222 &resolve_context,
1223 capabilities,
1224 ) {
1225 return Err(FunctionError::Expression { handle, source }
1226 .with_span_handle(handle, &fun.expressions));
1227 }
1228 }
1229
1230 for (_, expr) in fun.local_variables.iter() {
1231 if let Some(init) = expr.init {
1232 let _ = info.add_ref(init);
1233 }
1234 }
1235
1236 let uniformity = info.process_block(
1237 &fun.body,
1238 &self.functions,
1239 None,
1240 &fun.expressions,
1241 &module.diagnostic_filters,
1242 )?;
1243 info.uniformity = uniformity.result;
1244 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1245
1246 for &handle in fun.named_expressions.keys() {
1252 if let Some(global) = info[handle].assignable_global {
1253 if info.global_uses[global.index()].is_empty() {
1254 info.global_uses[global.index()] = GlobalUse::QUERY;
1255 }
1256 }
1257 }
1258
1259 Ok(info)
1260 }
1261
1262 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1263 &self.entry_points[index]
1264 }
1265}
1266
/// Exercises `FunctionInfo::process_expression` and `FunctionInfo::process_block`
/// on a small hand-built expression arena. It checks reference counts, global
/// usage flags, and the uniformity/exit information computed for several blocks:
///
/// 1. a derivative inside an `if` on a uniform condition;
/// 2. the same derivative inside an `if` on a non-uniform condition. This is an
///    error when derivative uniformity is enforced, unless a diagnostic filter
///    turns the `DerivativeUniformity` rule off;
/// 3. returning a non-uniform value;
/// 4. a store, a `kill`, and a return of a pointer derived from the
///    non-uniform global.
///
/// Reference counts accumulate across the successive `process_block` calls,
/// so the order of the scenarios matters.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    // The assertions below rely on the analyzer treating the `Handle`-space
    // global as non-uniform and the `Uniform`-space global as uniform.
    let mut global_var_arena = Arena::new();
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // Expressions 0..2: a constant and a derivative of it. The derivative is
    // the expression that carries the `DERIVATIVE` uniformity requirement.
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    // Expressions 2..4: one `GlobalVariable` expression per global.
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // Expressions 4..6: an array-length query on the uniform global and an
    // index access into the non-uniform global.
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    // Covers the two global expressions plus the query and access expressions.
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    // Analyze every expression once, in arena order.
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // At this point only operand uses are counted. Each global expression is
    // used once, by the access or the query, and nothing uses the query or
    // the access yet.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    // The array-length query marks the uniform global as `QUERY` only. The
    // non-uniform global has no recorded use yet.
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Scenario 1: a derivative under a uniform condition is accepted, and its
    // requirement propagates to the block's result.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // The constant is now used both by the derivative and as the store pointer.
    assert_eq!(info[constant_expr].ref_count, 2);
    // Evaluating the `if` condition reads the uniform global.
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Scenario 2: the same derivative, now under a non-uniform condition.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // Derivative uniformity is not enforced, so the block must still
            // analyze successfully. Previously this result was silently
            // discarded, which would have hidden an unexpected error.
            assert!(
                block_info.is_ok(),
                "unexpected analysis error: {:?}",
                block_info
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            // Enforced: the derivative under non-uniform control flow is rejected.
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // Re-run with a diagnostic filter that sets the severity of the
            // `DerivativeUniformity` rule to `Off`. The same block is then accepted.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    // Evaluating the non-uniform `if` condition reads the non-uniform global.
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Scenario 3: returning a non-uniform value under a `Return` disruptor
    // reports that value as the non-uniform result and sets `MAY_RETURN`.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Scenario 4: store into, kill, then return a pointer derived from the
    // non-uniform global. Both exit flags are set, the non-uniformity is traced
    // back to the global expression, and the global is now read and written.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}