1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{
12 ExpressionError, FunctionError, ImmediateSlots, ModuleInfo, ShaderStages, ValidationFlags,
13};
14use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
15use crate::span::{AddSpan as _, WithSpan};
16use crate::{
17 arena::{Arena, Handle},
18 proc::{ResolveContext, TypeResolution},
19};
20
/// The expression responsible for a value being non-uniform, if any.
///
/// `None` means the value is uniform. `Some(h)` names the expression `h`
/// whose non-uniformity flows into this value. `process_expression` records
/// an expression's own handle when it is the source, and otherwise forwards
/// its operands' results.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// When `true`, `DERIVATIVE` and `IMPLICIT_LEVEL` below are defined as `0`.
///
/// A zero-valued requirement never makes a requirement set non-empty, so the
/// control-flow uniformity check in `process_block` never fires for them.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;

bitflags::bitflags! {
    /// Operations that require uniform control flow to be valid.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Control/memory barriers and workgroup-uniform loads.
        const WORK_GROUP_BARRIER = 0x1;
        /// Derivative expressions (`0` while the disable flag above is set).
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Image sampling with implicit derivatives (`0` while the disable flag is set).
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
        /// Cooperative load, store, and multiply-add operations.
        const COOP_OPS = 0x8;
    }
}
37
38#[derive(Clone, Debug)]
40#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
41#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
42#[cfg_attr(test, derive(PartialEq))]
43pub struct Uniformity {
44 pub non_uniform_result: NonUniformResult,
56 pub requirements: UniformityRequirements,
58}
59
60impl Uniformity {
61 const fn new() -> Self {
62 Uniformity {
63 non_uniform_result: None,
64 requirements: UniformityRequirements::empty(),
65 }
66 }
67}
68
bitflags::bitflags! {
    /// Ways in which control may leave a block or a function early.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` was reached while a uniformity disruptor was active.
        const MAY_RETURN = 0x1;
        /// A `Kill` was reached while a disruptor was active, or a called
        /// function reports `may_kill`.
        const MAY_KILL = 0x2;
    }
}
83
84#[cfg_attr(test, derive(Debug, PartialEq))]
86struct FunctionUniformity {
87 result: Uniformity,
88 exit: ExitFlags,
89}
90
91impl ops::BitOr for FunctionUniformity {
92 type Output = Self;
93 fn bitor(self, other: Self) -> Self {
94 FunctionUniformity {
95 result: Uniformity {
96 non_uniform_result: self
97 .result
98 .non_uniform_result
99 .or(other.result.non_uniform_result),
100 requirements: self.result.requirements | other.result.requirements,
101 },
102 exit: self.exit | other.exit,
103 }
104 }
105}
106
107impl FunctionUniformity {
108 const fn new() -> Self {
109 FunctionUniformity {
110 result: Uniformity::new(),
111 exit: ExitFlags::empty(),
112 }
113 }
114
115 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
117 if self.exit.contains(ExitFlags::MAY_RETURN) {
118 Some(UniformityDisruptor::Return)
119 } else if self.exit.contains(ExitFlags::MAY_KILL) {
120 Some(UniformityDisruptor::Discard)
121 } else {
122 None
123 }
124 }
125}
126
bitflags::bitflags! {
    /// How a function uses a particular global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Recorded for ordinary references (`add_ref`) and for `Atomic` statements.
        const READ = 0x1;
        /// Recorded for stores, image stores, cooperative stores, and `Atomic` statements.
        const WRITE = 0x2;
        /// Recorded by `ImageQuery` and `ArrayLength`. Also assigned to a named
        /// global expression that has no other recorded use.
        const QUERY = 0x4;
        /// Recorded for the image operand of `ImageAtomic` statements.
        const ATOMIC = 0x8;
    }
}
143
/// An (image, sampler) pair of global variables that are used together in a
/// sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The image global variable.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global variable.
    pub sampler: Handle<crate::GlobalVariable>,
}
151
152#[derive(Clone, Debug)]
153#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
154#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
155pub struct ExpressionInfo {
157 pub uniformity: Uniformity,
163
164 pub ref_count: usize,
170
171 assignable_global: Option<Handle<crate::GlobalVariable>>,
185
186 pub ty: TypeResolution,
188}
189
190impl ExpressionInfo {
191 const fn new() -> Self {
192 ExpressionInfo {
193 uniformity: Uniformity::new(),
194 ref_count: 0,
195 assignable_global: None,
196 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
198 kind: crate::ScalarKind::Bool,
199 width: 0,
200 })),
201 }
202 }
203}
204
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
206#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
207#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
208enum GlobalOrArgument {
209 Global(Handle<crate::GlobalVariable>),
210 Argument(u32),
211}
212
213impl GlobalOrArgument {
214 fn from_expression(
215 expression_arena: &Arena<crate::Expression>,
216 expression: Handle<crate::Expression>,
217 ) -> Result<GlobalOrArgument, ExpressionError> {
218 Ok(match expression_arena[expression] {
219 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
220 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
221 crate::Expression::Access { base, .. }
222 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
223 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
224 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
225 },
226 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
227 })
228 }
229}
230
231#[derive(Debug, Clone, PartialEq, Eq, Hash)]
232#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
233#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
234struct Sampling {
235 image: GlobalOrArgument,
236 sampler: GlobalOrArgument,
237}
238
239#[derive(Debug, Clone)]
240#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
241#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
242pub struct FunctionInfo {
243 flags: ValidationFlags,
245 pub available_stages: ShaderStages,
247 pub uniformity: Uniformity,
249 pub may_kill: bool,
251
252 pub sampling_set: crate::FastHashSet<SamplingKey>,
267
268 pub global_uses: Box<[GlobalUse]>,
275
276 expressions: Box<[ExpressionInfo]>,
283
284 sampling: crate::FastHashSet<Sampling>,
297
298 pub dual_source_blending: bool,
300
301 diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
307
308 pub immediate_slots_used: ImmediateSlots,
311}
312
313impl FunctionInfo {
314 pub const fn global_variable_count(&self) -> usize {
315 self.global_uses.len()
316 }
317 pub const fn expression_count(&self) -> usize {
318 self.expressions.len()
319 }
320 pub fn dominates_global_use(&self, other: &Self) -> bool {
321 for (self_global_uses, other_global_uses) in
322 self.global_uses.iter().zip(other.global_uses.iter())
323 {
324 if !self_global_uses.contains(*other_global_uses) {
325 return false;
326 }
327 }
328 true
329 }
330}
331
332impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
333 type Output = GlobalUse;
334 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
335 &self.global_uses[handle.index()]
336 }
337}
338
339impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
340 type Output = ExpressionInfo;
341 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
342 &self.expressions[handle.index()]
343 }
344}
345
/// Why control flow at some point may no longer be uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch (`If`/`Switch`) depends on this non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier statement may have returned early (`ExitFlags::MAY_RETURN`).
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier statement or callee may have killed (`ExitFlags::MAY_KILL`).
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
357
358impl FunctionInfo {
359 #[must_use]
367 fn add_ref_impl(
368 &mut self,
369 expr: Handle<crate::Expression>,
370 global_use: GlobalUse,
371 ) -> NonUniformResult {
372 let info = &mut self.expressions[expr.index()];
373 info.ref_count += 1;
374 if let Some(global) = info.assignable_global {
376 self.global_uses[global.index()] |= global_use;
377 }
378 info.uniformity.non_uniform_result
379 }
380
381 pub(super) fn insert_global_use(
390 &mut self,
391 global_use: GlobalUse,
392 global: Handle<crate::GlobalVariable>,
393 ) {
394 self.global_uses[global.index()] |= global_use;
395 }
396
397 #[must_use]
404 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
405 self.add_ref_impl(expr, GlobalUse::READ)
406 }
407
408 #[must_use]
427 fn add_assignable_ref(
428 &mut self,
429 expr: Handle<crate::Expression>,
430 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
431 ) -> NonUniformResult {
432 let info = &mut self.expressions[expr.index()];
433 info.ref_count += 1;
434 if let Some(global) = info.assignable_global {
437 if let Some(_old) = assignable_global.replace(global) {
438 unreachable!()
439 }
440 }
441 info.uniformity.non_uniform_result
442 }
443
444 fn process_call(
446 &mut self,
447 callee: &Self,
448 arguments: &[Handle<crate::Expression>],
449 expression_arena: &Arena<crate::Expression>,
450 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
451 self.sampling_set
452 .extend(callee.sampling_set.iter().cloned());
453 for sampling in callee.sampling.iter() {
454 let image_storage = match sampling.image {
457 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
458 GlobalOrArgument::Argument(i) => {
459 let Some(handle) = arguments.get(i as usize).cloned() else {
460 break;
462 };
463 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
464 |source| {
465 FunctionError::Expression { handle, source }
466 .with_span_handle(handle, expression_arena)
467 },
468 )?
469 }
470 };
471
472 let sampler_storage = match sampling.sampler {
473 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
474 GlobalOrArgument::Argument(i) => {
475 let Some(handle) = arguments.get(i as usize).cloned() else {
476 break;
478 };
479 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
480 |source| {
481 FunctionError::Expression { handle, source }
482 .with_span_handle(handle, expression_arena)
483 },
484 )?
485 }
486 };
487
488 match (image_storage, sampler_storage) {
493 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
494 self.sampling_set.insert(SamplingKey { image, sampler });
495 }
496 (image, sampler) => {
497 self.sampling.insert(Sampling { image, sampler });
498 }
499 }
500 }
501
502 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
504 *mine |= *other;
505 }
506 self.immediate_slots_used |= callee.immediate_slots_used;
507
508 Ok(FunctionUniformity {
509 result: callee.uniformity.clone(),
510 exit: if callee.may_kill {
511 ExitFlags::MAY_KILL
512 } else {
513 ExitFlags::empty()
514 },
515 })
516 }
517
    /// Analyzes expression `handle` and stores the result in `self.expressions`.
    ///
    /// Every operand is counted through `add_ref`, `add_ref_impl`, or
    /// `add_assignable_ref`; these also accumulate global usage. The stored
    /// entry holds three things:
    /// - the expression's uniformity;
    /// - its resolved type;
    /// - for globals and access chains into them, the designated global.
    ///
    /// Operand entries are read directly (e.g. `self[base]`, `self[index]`),
    /// so operands are assumed to have been processed already.
    ///
    /// # Errors
    ///
    /// - `ExpressionError::MissingCapabilities` when a binding array is indexed
    ///   with a non-uniform index and the matching non-uniform-indexing
    ///   capability is absent.
    /// - Errors from `GlobalOrArgument::from_expression` for `ImageSample`
    ///   image/sampler operands.
    /// - Errors from type resolution.
    ///
    /// `clippy::or_fun_call` is allowed on purpose. Each `add_ref` call has side
    /// effects (ref counts, global uses), so the `.or(self.add_ref(..))` chains
    /// must evaluate every operand even after an earlier one was non-uniform.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in only by `GlobalVariable` or by access chains rooted in a global.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capability required to index `base` with a non-uniform index.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => {
                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                                _ => {
                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                            },
                            crate::TypeInner::Sampler { .. } => {
                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                            }
                            _ => {
                                // Buffer binding arrays: the capability depends on the
                                // address space of the global being indexed.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => {
                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        crate::AddressSpace::Storage { .. } => {
                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // Fold rather than short-circuit so every component is ref-counted.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                // Only these workgroup-level built-ins are treated as uniform arguments.
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Storage buffers count as uniform only when they are not writable.
                let uniform = match var.space {
                    As::Function | As::Private | As::RayPayload | As::IncomingRayPayload => false,
                    As::WorkGroup | As::TaskPayload => true,
                    As::Uniform | As::Immediate => true,
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => {
                let non_uniform_result = self.add_ref(pointer);
                // Loads through an `Immediate`-space global mark the slots they read.
                if let Some(global) = self.expressions[pointer.index()].assignable_global {
                    if resolve_context.global_vars[global].space == crate::AddressSpace::Immediate {
                        self.immediate_slots_used |= ImmediateSlots::for_pointer(
                            pointer,
                            global,
                            expression_arena,
                            resolve_context.global_vars,
                            resolve_context.types,
                        );
                    }
                }
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Global/global pairs are final. Pairs involving an argument
                // are resolved later by `process_call` in each caller.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // Querying an image is recorded as `QUERY`, not `READ`.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                // The result is only as non-uniform as the operand.
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result inherits the callee's already-computed uniformity.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::CooperativeLoad { ref data, .. } => Uniformity {
                non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)),
                requirements: UniformityRequirements::COOP_OPS,
            },
            E::CooperativeMultiplyAdd { a, b, c } => Uniformity {
                non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))),
                requirements: UniformityRequirements::COOP_OPS,
            },
        };

        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // `ref_count` starts at zero; later users increment it via `add_ref*`.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
887
    /// Analyzes a block of statements and returns their combined uniformity.
    ///
    /// `disruptor` is the reason, if any, why control flow entering this block
    /// may already be non-uniform. Inside the block it is updated after each
    /// statement: an early `Return`/`Kill` under a disruptor affects the
    /// statements that follow.
    ///
    /// When `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set, an emitted
    /// expression with non-empty requirements under an active disruptor is
    /// reported. The severity is looked up under the `DerivativeUniformity`
    /// rule for every requirement kind.
    ///
    /// The per-statement results are combined with `|`, so the first
    /// non-uniform cause wins and requirements and exit flags are unioned.
    ///
    /// # Errors
    ///
    /// - Uniformity violations reported as errors by the diagnostic filter.
    /// - Errors from nested blocks.
    /// - Errors from `process_call`.
    ///
    /// `clippy::or_fun_call` is allowed on purpose. `.or(self.add_ref(..))` must
    /// evaluate eagerly so every operand's reference is recorded.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Check each emitted expression's requirements against the
                    // current disruptor, and collect them for the block result.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as an early exit when control flow is already disrupted.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is ref-counted, but its uniformity is not
                    // folded into the result.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // Both branches run under the outer disruptor, or else a
                    // non-uniform condition.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // Early exits propagate into the next case only through fall-through.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block also sees any early exit from the body.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
                S::CooperativeStore { target, ref data } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: self
                            .add_ref(target)
                            .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE))
                            .or(self.add_ref(data.stride)),
                        requirements: UniformityRequirements::COOP_OPS,
                    },
                    exit: ExitFlags::empty(),
                },
                S::RayPipelineFunction(ref fun) => {
                    match *fun {
                        crate::RayPipelineFunction::TraceRay {
                            acceleration_structure,
                            descriptor,
                            payload,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                            let _ = self.add_ref(payload);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // An early exit under a disruptor disrupts every statement after it.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1236}
1237
1238impl ModuleInfo {
1239 pub(super) fn process_const_expression(
1241 &mut self,
1242 handle: Handle<crate::Expression>,
1243 resolve_context: &ResolveContext,
1244 gctx: crate::proc::GlobalCtx,
1245 ) -> Result<(), super::ConstExpressionError> {
1246 self.const_expression_types[handle.index()] =
1247 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1248 Ok(())
1249 }
1250
1251 pub(super) fn process_function(
1254 &self,
1255 fun: &crate::Function,
1256 module: &crate::Module,
1257 flags: ValidationFlags,
1258 capabilities: super::Capabilities,
1259 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1260 let mut info = FunctionInfo {
1261 flags,
1262 available_stages: ShaderStages::all(),
1263 uniformity: Uniformity::new(),
1264 may_kill: false,
1265 sampling_set: crate::FastHashSet::default(),
1266 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1267 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1268 sampling: crate::FastHashSet::default(),
1269 dual_source_blending: false,
1270 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1271 immediate_slots_used: ImmediateSlots::default(),
1272 };
1273 let resolve_context =
1274 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1275
1276 for (handle, _) in fun.expressions.iter() {
1277 if let Err(source) = info.process_expression(
1278 handle,
1279 &fun.expressions,
1280 &self.functions,
1281 &resolve_context,
1282 capabilities,
1283 ) {
1284 return Err(FunctionError::Expression { handle, source }
1285 .with_span_handle(handle, &fun.expressions));
1286 }
1287 }
1288
1289 for (_, expr) in fun.local_variables.iter() {
1290 if let Some(init) = expr.init {
1291 let _ = info.add_ref(init);
1292 }
1293 }
1294
1295 let uniformity = info.process_block(
1296 &fun.body,
1297 &self.functions,
1298 None,
1299 &fun.expressions,
1300 &module.diagnostic_filters,
1301 )?;
1302 info.uniformity = uniformity.result;
1303 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1304
1305 for &handle in fun.named_expressions.keys() {
1311 if let Some(global) = info[handle].assignable_global {
1312 if info.global_uses[global.index()].is_empty() {
1313 info.global_uses[global.index()] = GlobalUse::QUERY;
1314 }
1315 }
1316 }
1317
1318 Ok(info)
1319 }
1320
1321 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1322 &self.entry_points[index]
1323 }
1324}
1325
/// Exercises the uniformity analysis end to end on a hand-built function.
///
/// One `FunctionInfo` is reused across several `process_block` calls, so
/// reference counts and global-use flags accumulate from phase to phase; the
/// assertions after each phase depend on that ordering.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // `Handle`-space global: the assertions below expect values read from it
    // to be reported as non-uniform.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
            memory_decorations: crate::MemoryDecorations::empty(),
        },
        Default::default(),
    );
    // `Uniform`-space global: branching on it must not disrupt uniformity.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
            memory_decorations: crate::MemoryDecorations::empty(),
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // Expressions [0, 2): a constant and a derivative of it. The derivative is
    // the expression that carries `UniformityRequirements::DERIVATIVE`.
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    // Expressions [2, 4): one load per global.
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // `ArrayLength` only queries its operand, so it should tag `uniform_global`
    // with `GlobalUse::QUERY` rather than `READ`.
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // Storing through this access should propagate `WRITE` back to
    // `non_uniform_global`.
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    // Expressions [2, 6): both globals plus the query and the access.
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
        immediate_slots_used: ImmediateSlots::default(),
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };

    // Phase 1: analyze every expression on its own. Only operand references
    // are counted here; statement-level uses are added by the later phases.
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Phase 2: a derivative inside a branch on a *uniform* condition is
    // accepted, and its requirement surfaces in the block's summary.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Phase 3: the same derivative under a *non-uniform* condition. Whether
    // that is rejected depends on `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE`.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the derivative requirement compiled out (`DERIVATIVE` is 0),
            // the non-uniform branch must be accepted with the same summary the
            // uniform branch produced in phase 2. Checking the result here
            // keeps this configuration from silently ignoring an error.
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty(),
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // A diagnostic filter that turns the derivative-uniformity rule
            // `Off` must let the very same block through.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    // The `if` condition in phase 3 read the non-uniform global.
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Phase 4: returning a non-uniform value marks the block as possibly
    // returning and reports that value as its non-uniform result.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Phase 5: a store through the access, a kill, and a return together set
    // every exit flag, and the store adds `WRITE` to `non_uniform_global`.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}