1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Handle of an expression that makes a result non-uniform, or `None` when
/// the result is considered uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` requirement flags are
/// defined with the value `0`, so they contribute no bits to a requirement set.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Kinds of operations that place a requirement on the uniformity of the
    /// control flow in which they are evaluated.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Attached to control/memory barriers and workgroup-uniform loads.
        const WORK_GROUP_BARRIER = 0x1;
        /// Attached to derivative expressions. Evaluates to `0` (no bit set)
        /// while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Attached to image samples whose level relies on implicit
        /// derivatives. Evaluates to `0` (no bit set) while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity characteristics of an expression, a statement, or a whole function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// Handle of an expression responsible for the value being non-uniform,
    /// or `None` if the value is uniform.
    pub non_uniform_result: NonUniformResult,
    /// Uniformity requirements that evaluating this item imposes on the
    /// surrounding control flow.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    /// Ways in which control flow may leave a block early.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` statement was encountered while a uniformity disruptor
        /// was already active.
        const MAY_RETURN = 0x1;
        /// A `Kill` statement was encountered while a uniformity disruptor
        /// was already active, or a called function reported `may_kill`.
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity summary of a statement or block: the uniformity of its result
/// together with the ways it may exit early.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Result uniformity and accumulated requirements.
    result: Uniformity,
    /// Early-exit possibilities collected so far.
    exit: ExitFlags,
}
87
88impl ops::BitOr for FunctionUniformity {
89 type Output = Self;
90 fn bitor(self, other: Self) -> Self {
91 FunctionUniformity {
92 result: Uniformity {
93 non_uniform_result: self
94 .result
95 .non_uniform_result
96 .or(other.result.non_uniform_result),
97 requirements: self.result.requirements | other.result.requirements,
98 },
99 exit: self.exit | other.exit,
100 }
101 }
102}
103
104impl FunctionUniformity {
105 const fn new() -> Self {
106 FunctionUniformity {
107 result: Uniformity::new(),
108 exit: ExitFlags::empty(),
109 }
110 }
111
112 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114 if self.exit.contains(ExitFlags::MAY_RETURN) {
115 Some(UniformityDisruptor::Return)
116 } else if self.exit.contains(ExitFlags::MAY_KILL) {
117 Some(UniformityDisruptor::Discard)
118 } else {
119 None
120 }
121 }
122}
123
bitflags::bitflags! {
    /// The ways in which a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// The global's data is read (the default for a referenced expression).
        const READ = 0x1;
        /// The global's data is written (stores, image stores, atomics).
        const WRITE = 0x2;
        /// Information about the global is queried (image queries, array
        /// length), or the global is only named without being otherwise used.
        const QUERY = 0x4;
        /// The global is the target of an image atomic operation.
        const ATOMIC = 0x8;
    }
}
140
/// A pair of global variables (image and sampler) that are used together in
/// an image sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The image global being sampled.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global used to sample it.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
/// Analysis results for a single expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Whether the expression's value is uniform, and what it requires of
    /// the control flow evaluating it.
    pub uniformity: Uniformity,

    /// Number of uses of this expression recorded by the analysis (each
    /// recorded reference from another expression, statement, or local
    /// variable initializer increments it).
    pub ref_count: usize,

    /// The global variable this expression refers to, if any. Set for
    /// `GlobalVariable` expressions and carried through `Access` /
    /// `AccessIndex` chains; used to attribute uses to the global.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The resolved type of the expression.
    pub ty: TypeResolution,
}
186
187impl ExpressionInfo {
188 const fn new() -> Self {
189 ExpressionInfo {
190 uniformity: Uniformity::new(),
191 ref_count: 0,
192 assignable_global: None,
193 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
195 kind: crate::ScalarKind::Bool,
196 width: 0,
197 })),
198 }
199 }
200}
201
/// Identifies where a sampled image or sampler comes from: either a global
/// variable or one of the enclosing function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The resource is the given global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The resource is the function argument with the given index.
    Argument(u32),
}
209
210impl GlobalOrArgument {
211 fn from_expression(
212 expression_arena: &Arena<crate::Expression>,
213 expression: Handle<crate::Expression>,
214 ) -> Result<GlobalOrArgument, ExpressionError> {
215 Ok(match expression_arena[expression] {
216 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
217 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
218 crate::Expression::Access { base, .. }
219 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
220 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
221 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
222 },
223 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
224 })
225 }
226}
227
/// An image/sampler pairing in which each side may be either a global
/// variable or a function argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Source of the sampled image.
    image: GlobalOrArgument,
    /// Source of the sampler.
    sampler: GlobalOrArgument,
}
235
/// Analysis results for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages in which this function may be used.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result and the requirements it imposes.
    pub uniformity: Uniformity,
    /// Whether the function (or something it calls) may kill the invocation.
    pub may_kill: bool,

    /// Pairs of global image and sampler variables that are sampled together
    /// by this function or by functions it calls.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function uses each global variable, indexed by the global's
    /// handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by the expression's handle
    /// index.
    expressions: Box<[ExpressionInfo]>,

    /// Image/sampler pairings in which at least one side is a function
    /// argument; they are resolved against actual arguments at call sites.
    sampling: crate::FastHashSet<Sampling>,

    /// NOTE(review): presumably records whether the function uses
    /// dual-source blending; nothing in this part of the file sets it to
    /// anything but `false` — confirm against the code that writes it.
    pub dual_source_blending: bool,

    /// Leaf of the diagnostic filter chain that applies to this function,
    /// used when looking up the severity of uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
306
307impl FunctionInfo {
308 pub const fn global_variable_count(&self) -> usize {
309 self.global_uses.len()
310 }
311 pub const fn expression_count(&self) -> usize {
312 self.expressions.len()
313 }
314 pub fn dominates_global_use(&self, other: &Self) -> bool {
315 for (self_global_uses, other_global_uses) in
316 self.global_uses.iter().zip(other.global_uses.iter())
317 {
318 if !self_global_uses.contains(*other_global_uses) {
319 return false;
320 }
321 }
322 true
323 }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327 type Output = GlobalUse;
328 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329 &self.global_uses[handle.index()]
330 }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334 type Output = ExpressionInfo;
335 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336 &self.expressions[handle.index()]
337 }
338}
339
/// The reason why control flow is considered non-uniform at some point.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch or switch depends on the given non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A return may already have happened under non-uniform control flow.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A discard (kill) may already have happened, in this function or in
    /// one it calls.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
351
impl FunctionInfo {
    /// Record one use of `expr` and return its non-uniformity source, if any.
    ///
    /// Increments the expression's `ref_count`. If the expression is tied to
    /// a global variable, `global_use` is OR-ed into that global's usage.
    #[must_use]
    fn add_ref_impl(
        &mut self,
        expr: Handle<crate::Expression>,
        global_use: GlobalUse,
    ) -> NonUniformResult {
        let info = &mut self.expressions[expr.index()];
        info.ref_count += 1;
        // Attribute the use to the global behind this expression, if there is one.
        if let Some(global) = info.assignable_global {
            self.global_uses[global.index()] |= global_use;
        }
        info.uniformity.non_uniform_result
    }

    /// OR `global_use` into the recorded usage of `global`, without touching
    /// any expression's reference count.
    pub(super) fn insert_global_use(
        &mut self,
        global_use: GlobalUse,
        global: Handle<crate::GlobalVariable>,
    ) {
        self.global_uses[global.index()] |= global_use;
    }

    /// Record a read use of `expr`; shorthand for
    /// `add_ref_impl(expr, GlobalUse::READ)`.
    #[must_use]
    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
        self.add_ref_impl(expr, GlobalUse::READ)
    }

    /// Record a use of `expr` as the base of an access chain.
    ///
    /// Increments the expression's `ref_count` but does not mark any global
    /// as used. Instead, if `expr` is tied to a global, that global is stored
    /// into `assignable_global` so the caller can tie the enclosing
    /// expression to it.
    ///
    /// Reaches `unreachable!()` if `assignable_global` was already `Some`
    /// when a global has to be stored into it.
    #[must_use]
    fn add_assignable_ref(
        &mut self,
        expr: Handle<crate::Expression>,
        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
    ) -> NonUniformResult {
        let info = &mut self.expressions[expr.index()];
        info.ref_count += 1;
        if let Some(global) = info.assignable_global {
            // An expression is expected to pick up at most one global.
            if let Some(_old) = assignable_global.replace(global) {
                unreachable!()
            }
        }
        info.uniformity.non_uniform_result
    }

    /// Fold the analysis of a called function (`callee`) into `self`.
    ///
    /// - The callee's global image/sampler pairs are added to `sampling_set`.
    /// - The callee's argument-dependent pairings are resolved against
    ///   `arguments`; pairs that become fully global go to `sampling_set`,
    ///   the rest to `sampling`.
    /// - The callee's global uses are OR-ed into this function's.
    ///
    /// Returns the callee's result uniformity, with `MAY_KILL` set when the
    /// callee may kill.
    ///
    /// # Errors
    /// Returns a spanned `FunctionError::Expression` when an argument used as
    /// an image or sampler is neither a global nor a function argument.
    fn process_call(
        &mut self,
        callee: &Self,
        arguments: &[Handle<crate::Expression>],
        expression_arena: &Arena<crate::Expression>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        self.sampling_set
            .extend(callee.sampling_set.iter().cloned());
        for sampling in callee.sampling.iter() {
            // Substitute the actual argument for an argument-sourced image.
            let image_storage = match sampling.image {
                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
                GlobalOrArgument::Argument(i) => {
                    // No such argument at this call site: stop processing
                    // the remaining pairings (note: `break`, not `continue`).
                    // NOTE(review): presumably an argument-count mismatch
                    // that is reported by call validation elsewhere — confirm.
                    let Some(handle) = arguments.get(i as usize).cloned() else {
                        break;
                    };
                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
                        |source| {
                            FunctionError::Expression { handle, source }
                                .with_span_handle(handle, expression_arena)
                        },
                    )?
                }
            };

            // Same substitution for the sampler side.
            let sampler_storage = match sampling.sampler {
                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
                GlobalOrArgument::Argument(i) => {
                    let Some(handle) = arguments.get(i as usize).cloned() else {
                        break;
                    };
                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
                        |source| {
                            FunctionError::Expression { handle, source }
                                .with_span_handle(handle, expression_arena)
                        },
                    )?
                }
            };

            match (image_storage, sampler_storage) {
                // Both sides resolved to globals: a concrete sampling pair.
                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                    self.sampling_set.insert(SamplingKey { image, sampler });
                }
                // At least one side still depends on this function's own
                // arguments; defer to this function's callers.
                (image, sampler) => {
                    self.sampling.insert(Sampling { image, sampler });
                }
            }
        }

        // Inherit the callee's global usage (pairwise by position).
        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
            *mine |= *other;
        }

        Ok(FunctionUniformity {
            result: callee.uniformity.clone(),
            exit: if callee.may_kill {
                ExitFlags::MAY_KILL
            } else {
                ExitFlags::empty()
            },
        })
    }

    /// Compute and store the [`ExpressionInfo`] for the expression `handle`.
    ///
    /// Determines the expression's uniformity from its operands, records a
    /// reference on each operand (see `add_ref` and friends), resolves the
    /// expression's type, and writes the result into `self.expressions` with
    /// a fresh `ref_count` of 0.
    ///
    /// Operand information is read from `self`, so operands must already
    /// have been processed. `other_functions` supplies the analysis of
    /// functions referenced by `CallResult` expressions.
    ///
    /// Throughout the body, `Option::or` is called with an `add_ref` call as
    /// its argument. The argument is evaluated eagerly, so every operand's
    /// reference is recorded even when an earlier operand is already
    /// non-uniform; this is presumably why the `or_fun_call` lint is allowed.
    ///
    /// # Errors
    /// - `ExpressionError::MissingCapabilities` when a binding array is
    ///   indexed by a non-uniform value without the needed capability.
    /// - `ExpressionError::ExpectedGlobalOrArgument` when a sampled image or
    ///   sampler is neither a global nor a function argument.
    /// - Any error produced by type resolution.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Work out which capability would be needed to index this
                // base non-uniformly; only binding arrays need one.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => {
                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                                _ => {
                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                }
                            },
                            crate::TypeInner::Sampler { .. } => {
                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                            }
                            _ => {
                                // Any other element type is treated as a
                                // buffer: the base must be a global in the
                                // uniform or storage address space, otherwise
                                // `unreachable!()` is hit.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => {
                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        crate::AddressSpace::Storage { .. } => {
                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                                        }
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Non-uniform index into a binding array without the
                // matching capability is rejected.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Values fixed for the whole dispatch/draw are uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // The first non-uniform component (if any) is reported, but
                // every component gets its reference recorded.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                // Only arguments bound to these workgroup-level built-ins
                // are treated as uniform; every other argument is not.
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                // Tie this expression to the global so later uses are
                // attributed to it.
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided purely by address space.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup | As::TaskPayload => true,
                    As::Uniform | As::Immediate => true,
                    // Storage counts as uniform only without store access.
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always treated as non-uniform.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Record which image is sampled with which sampler: globally
                // known pairs directly, argument-dependent ones for later
                // resolution at call sites.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    // Sampling with implicit derivatives carries a
                    // uniformity requirement.
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image itself is only queried, not read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result takes on the callee's recorded uniformity.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // The result of a workgroup-uniform load is treated as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // Taking the length only queries the global.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are taken from the entries already stored in `self`.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // Replace the whole entry; its reference count starts at zero.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }

    /// Analyze a block of statements and return its combined uniformity.
    ///
    /// Walks `statements` in order, recording references for every
    /// expression a statement uses and recursing into nested blocks.
    /// `disruptor` is the reason (if any) why control flow is already
    /// non-uniform on entry; it is extended as statements that may exit
    /// early are encountered.
    ///
    /// The returned value is the `|`-combination of all statement results.
    ///
    /// # Errors
    /// Propagates errors from nested blocks and calls, and whatever
    /// `report_diag` returns for a uniformity requirement that is evaluated
    /// while a disruptor is active.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        // Only checked when control-flow uniformity
                        // validation is enabled and the expression actually
                        // requires something.
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                // The severity of the diagnostic comes from
                                // the function's diagnostic filter chain.
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    // A kill only counts as a possible exit when control
                    // flow is already disrupted.
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer's reference is recorded; its uniformity is
                    // not otherwise used here.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // Inside the branches, a non-uniform condition becomes
                    // the disruptor unless one is already active.
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case passes its exit disruption on
                        // to the next case; otherwise the next case starts
                        // from the switch-level disruptor again.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block runs after the body, so early
                    // exits in the body disrupt it.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The break condition is only reference-counted; its
                    // uniformity is discarded.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    // A return only counts as a possible early exit when
                    // control flow is already disrupted.
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    // The pointer is written to; the value is read.
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // An atomic both reads and writes its target.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
            };

            // Anything after a statement that may exit early runs in
            // disrupted control flow.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
}
1183
1184impl ModuleInfo {
1185 pub(super) fn process_const_expression(
1187 &mut self,
1188 handle: Handle<crate::Expression>,
1189 resolve_context: &ResolveContext,
1190 gctx: crate::proc::GlobalCtx,
1191 ) -> Result<(), super::ConstExpressionError> {
1192 self.const_expression_types[handle.index()] =
1193 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1194 Ok(())
1195 }
1196
1197 pub(super) fn process_function(
1200 &self,
1201 fun: &crate::Function,
1202 module: &crate::Module,
1203 flags: ValidationFlags,
1204 capabilities: super::Capabilities,
1205 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1206 let mut info = FunctionInfo {
1207 flags,
1208 available_stages: ShaderStages::all(),
1209 uniformity: Uniformity::new(),
1210 may_kill: false,
1211 sampling_set: crate::FastHashSet::default(),
1212 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1213 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1214 sampling: crate::FastHashSet::default(),
1215 dual_source_blending: false,
1216 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1217 };
1218 let resolve_context =
1219 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1220
1221 for (handle, _) in fun.expressions.iter() {
1222 if let Err(source) = info.process_expression(
1223 handle,
1224 &fun.expressions,
1225 &self.functions,
1226 &resolve_context,
1227 capabilities,
1228 ) {
1229 return Err(FunctionError::Expression { handle, source }
1230 .with_span_handle(handle, &fun.expressions));
1231 }
1232 }
1233
1234 for (_, expr) in fun.local_variables.iter() {
1235 if let Some(init) = expr.init {
1236 let _ = info.add_ref(init);
1237 }
1238 }
1239
1240 let uniformity = info.process_block(
1241 &fun.body,
1242 &self.functions,
1243 None,
1244 &fun.expressions,
1245 &module.diagnostic_filters,
1246 )?;
1247 info.uniformity = uniformity.result;
1248 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1249
1250 for &handle in fun.named_expressions.keys() {
1256 if let Some(global) = info[handle].assignable_global {
1257 if info.global_uses[global.index()].is_empty() {
1258 info.global_uses[global.index()] = GlobalUse::QUERY;
1259 }
1260 }
1261 }
1262
1263 Ok(info)
1264 }
1265
1266 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1267 &self.entry_points[index]
1268 }
1269}
1270
/// Drives `FunctionInfo::process_expression` and `FunctionInfo::process_block` over a
/// hand-built set of globals, expressions and statements, then checks the reported
/// uniformity result, exit flags, expression reference counts and global usage flags.
///
/// The assertions are cumulative: the same `info` is mutated by every call, so later
/// reference counts and `GlobalUse` flags include the effects of the earlier blocks.
/// Keep the sections in this order.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    // A single `vec2<f32>` type, shared by both globals below.
    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // Global in the `Handle` address space; the assertions below expect the analysis
    // to report expressions reading it as non-uniform.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // Global in the `Uniform` address space; branching on it is expected to keep
    // control flow uniform.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // Expression 0: a literal, used as the derivative operand and as a store target.
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // Expression 1: a derivative, i.e. the operation carrying the `DERIVATIVE`
    // uniformity requirement that the `If` statements below are tested against.
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    // Taken while only the literal and the derivative exist; presumably spans
    // expressions 0..2 (confirm against `Arena::range_from`).
    let emit_range_constant_derivative = expressions.range_from(0);
    // Expressions 2 and 3: direct references to the two globals.
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    // Presumably spans just the two global-variable expressions.
    let emit_range_globals = expressions.range_from(2);

    // Expression 4: an array-length query, expected to set the `QUERY` usage flag.
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // Expression 5: an access into the non-uniform global; storing through it later is
    // expected to propagate a `WRITE` flag back to the global.
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    // Starts at index 2 again, so it presumably spans the globals plus the query and
    // the access expressions.
    let emit_range_query_access_globals = expressions.range_from(2);

    // Blank analysis state, with one slot per global and per expression.
    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    // Analyze every expression once, in arena order.
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Each global expression is an operand of exactly one other expression
    // (`access_expr` and `query_expr` respectively); those two are used by nothing yet.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    // Only the array-length query has touched a global so far.
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Section 1: a derivative inside a branch on the uniform global is accepted, and
    // the derivative requirement is reported for the block.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // One reference from being the derivative operand, one from the `Store` pointer.
    assert_eq!(info[constant_expr].ref_count, 2);
    // The uniform global has now also been read (presumably via the `If` condition).
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Section 2: the same derivative, now inside a branch on the non-uniform global.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the requirement compiled out (`DERIVATIVE` is then the empty flag
            // set), only the second `Store` of the derivative is checked here.
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            // With the requirement active, the block must be rejected, naming the
            // derivative and the non-uniform condition that disrupted uniformity.
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            // The count is unchanged from section 1; presumably the failing block
            // bailed out before its `Store` was counted.
            assert_eq!(info[derivative_expr].ref_count, 1);

            // The same block must pass once a diagnostic filter turns the
            // derivative-uniformity rule off.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            // Shadows the outer `info` for the rest of this branch only; the outer
            // value is not touched by the filtered run.
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    // In either configuration the non-uniform global has been read by now.
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Section 3: returning a non-uniform value makes the block's result non-uniform
    // and marks the block as one that may return.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    // Presumably one reference each from `access_expr`, the non-uniform `If`
    // condition in section 2, and this `Return`.
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Section 4: store through an access into the non-uniform global, then kill and
    // return. Both exit flags are expected, and the non-uniformity is attributed to
    // the underlying global expression rather than to `access_expr` itself.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    // The store through `access_expr` is recorded as a write to the global it indexes.
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}