1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Outcome of a uniformity query on an expression.
///
/// `None` means the value is treated as uniform; `Some(handle)` names the
/// expression that the analysis holds responsible for the non-uniformity.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// [`UniformityRequirements`] are defined as `0`, so expressions carrying
/// only those requirements report an empty requirement set.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Reasons why an expression or statement needs uniform control flow.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Set by control/memory barrier statements and by
        /// `WorkGroupUniformLoad` statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Set by derivative expressions. Collapses to `0` while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Set by image sampling whose level uses implicit derivatives.
        /// Collapses to `0` while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE`
        /// is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity characteristics of an expression, statement or function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// The expression blamed for a non-uniform result, if any.
    ///
    /// Operations that consume a non-uniform operand propagate the first
    /// such operand they find (see the `.or(..)` chains in the analyzer).
    pub non_uniform_result: NonUniformResult,
    /// Why this item needs uniform control flow; empty if it does not.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    /// Ways in which a block may leave non-uniformly, as tracked by
    /// `FunctionInfo::process_block`.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Recorded for a `Return` statement that is reached while a
        /// uniformity disruptor is already active.
        const MAY_RETURN = 0x1;
        /// Recorded for a `Kill` statement that is reached while a
        /// uniformity disruptor is already active, and for calls to
        /// functions whose `may_kill` is set.
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity summary of a statement or block: its [`Uniformity`] plus the
/// non-uniform exits it may take.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Combined result/requirements of the covered statements.
    result: Uniformity,
    /// Exit kinds accumulated from the covered statements.
    exit: ExitFlags,
}
87
/// Mesh-shader output types observed in a function (and inherited from its
/// callees via `FunctionInfo::try_update_mesh_info`).
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct FunctionMeshShaderInfo {
    /// Type handle of the value passed to `SetVertex`, together with the
    /// expression that first established it. `None` until one is seen.
    pub vertex_type: Option<(Handle<crate::Type>, Handle<crate::Expression>)>,

    /// Type handle of the value passed to `SetPrimitive`, together with the
    /// expression that first established it. `None` until one is seen.
    pub primitive_type: Option<(Handle<crate::Type>, Handle<crate::Expression>)>,
}
106
107impl ops::BitOr for FunctionUniformity {
108 type Output = Self;
109 fn bitor(self, other: Self) -> Self {
110 FunctionUniformity {
111 result: Uniformity {
112 non_uniform_result: self
113 .result
114 .non_uniform_result
115 .or(other.result.non_uniform_result),
116 requirements: self.result.requirements | other.result.requirements,
117 },
118 exit: self.exit | other.exit,
119 }
120 }
121}
122
123impl FunctionUniformity {
124 const fn new() -> Self {
125 FunctionUniformity {
126 result: Uniformity::new(),
127 exit: ExitFlags::empty(),
128 }
129 }
130
131 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
133 if self.exit.contains(ExitFlags::MAY_RETURN) {
134 Some(UniformityDisruptor::Return)
135 } else if self.exit.contains(ExitFlags::MAY_KILL) {
136 Some(UniformityDisruptor::Discard)
137 } else {
138 None
139 }
140 }
141}
142
bitflags::bitflags! {
    /// How a function uses a particular global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// The variable is referenced as an ordinary (read) operand.
        const READ = 0x1;
        /// The variable is the target of a store, image store or atomic.
        const WRITE = 0x2;
        /// Only information about the variable is queried (image queries,
        /// array length, or a named expression with no other recorded use).
        const QUERY = 0x4;
        /// The variable is the target of an image atomic operation.
        const ATOMIC = 0x8;
    }
}
159
/// An image/sampler pair, both resolved to global variables, that is used
/// together in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the sampled image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
167
/// Per-expression results of the analysis.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity computed for this expression.
    pub uniformity: Uniformity,

    /// Number of times this expression was referenced; bumped by the
    /// `add_ref*` helpers of [`FunctionInfo`].
    pub ref_count: usize,

    /// The global variable this expression refers to, if it is a global
    /// variable expression or an `Access`/`AccessIndex` chain rooted in one.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
205
206impl ExpressionInfo {
207 const fn new() -> Self {
208 ExpressionInfo {
209 uniformity: Uniformity::new(),
210 ref_count: 0,
211 assignable_global: None,
212 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
214 kind: crate::ScalarKind::Bool,
215 width: 0,
216 })),
217 }
218 }
219}
220
/// Where an image or sampler operand comes from: a global variable, or a
/// function argument identified by its position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The operand is (or indexes into) this global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The operand is the function argument at this index.
    Argument(u32),
}
228
229impl GlobalOrArgument {
230 fn from_expression(
231 expression_arena: &Arena<crate::Expression>,
232 expression: Handle<crate::Expression>,
233 ) -> Result<GlobalOrArgument, ExpressionError> {
234 Ok(match expression_arena[expression] {
235 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
236 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
237 crate::Expression::Access { base, .. }
238 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
239 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
240 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
241 },
242 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
243 })
244 }
245}
246
/// An image/sampler pair used together for sampling where at least one side
/// may still be a function argument rather than a known global.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Origin of the image operand.
    image: GlobalOrArgument,
    /// Origin of the sampler operand.
    sampler: GlobalOrArgument,
}
254
/// Analysis results for a single function or entry point.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with.
    // NOTE(review): the reason for `allow(dead_code)` is not visible here;
    // the field is read when checking `CONTROL_FLOW_UNIFORMITY`.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used in; starts as all stages.
    pub available_stages: ShaderStages,
    /// Uniformity of the function body as a whole.
    pub uniformity: Uniformity,
    /// Whether the body's accumulated exit flags include `MAY_KILL`.
    pub may_kill: bool,

    /// Image/sampler pairs, both known globals, sampled by this function or
    /// by its callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Usage flags per global variable, indexed by the variable's handle
    /// index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression info, indexed by the expression's handle index.
    expressions: Box<[ExpressionInfo]>,

    /// Image/sampler pairs where at least one side is a function argument.
    sampling: crate::FastHashSet<Sampling>,

    /// Initialised to `false` by the analyzer; not modified in this part of
    /// the file.
    pub dual_source_blending: bool,

    /// Diagnostic filter leaf copied from the analyzed function; used to
    /// look up the severity of uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,

    /// Mesh shader output types seen in this function and its callees.
    pub mesh_shader_info: FunctionMeshShaderInfo,
}
328
329impl FunctionInfo {
330 pub const fn global_variable_count(&self) -> usize {
331 self.global_uses.len()
332 }
333 pub const fn expression_count(&self) -> usize {
334 self.expressions.len()
335 }
336 pub fn dominates_global_use(&self, other: &Self) -> bool {
337 for (self_global_uses, other_global_uses) in
338 self.global_uses.iter().zip(other.global_uses.iter())
339 {
340 if !self_global_uses.contains(*other_global_uses) {
341 return false;
342 }
343 }
344 true
345 }
346}
347
348impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
349 type Output = GlobalUse;
350 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
351 &self.global_uses[handle.index()]
352 }
353}
354
355impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
356 type Output = ExpressionInfo;
357 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
358 &self.expressions[handle.index()]
359 }
360}
361
/// The reason control flow is considered non-uniform at some point.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch condition or switch selector depends on this non-uniform
    /// expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// Derived from `ExitFlags::MAY_RETURN`.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// Derived from `ExitFlags::MAY_KILL`.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
373
374impl FunctionInfo {
375 #[must_use]
383 fn add_ref_impl(
384 &mut self,
385 expr: Handle<crate::Expression>,
386 global_use: GlobalUse,
387 ) -> NonUniformResult {
388 let info = &mut self.expressions[expr.index()];
389 info.ref_count += 1;
390 if let Some(global) = info.assignable_global {
392 self.global_uses[global.index()] |= global_use;
393 }
394 info.uniformity.non_uniform_result
395 }
396
397 pub(super) fn insert_global_use(
406 &mut self,
407 global_use: GlobalUse,
408 global: Handle<crate::GlobalVariable>,
409 ) {
410 self.global_uses[global.index()] |= global_use;
411 }
412
413 #[must_use]
420 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
421 self.add_ref_impl(expr, GlobalUse::READ)
422 }
423
424 #[must_use]
443 fn add_assignable_ref(
444 &mut self,
445 expr: Handle<crate::Expression>,
446 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
447 ) -> NonUniformResult {
448 let info = &mut self.expressions[expr.index()];
449 info.ref_count += 1;
450 if let Some(global) = info.assignable_global {
453 if let Some(_old) = assignable_global.replace(global) {
454 unreachable!()
455 }
456 }
457 info.uniformity.non_uniform_result
458 }
459
460 fn process_call(
462 &mut self,
463 callee: &Self,
464 arguments: &[Handle<crate::Expression>],
465 expression_arena: &Arena<crate::Expression>,
466 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
467 self.sampling_set
468 .extend(callee.sampling_set.iter().cloned());
469 for sampling in callee.sampling.iter() {
470 let image_storage = match sampling.image {
473 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
474 GlobalOrArgument::Argument(i) => {
475 let Some(handle) = arguments.get(i as usize).cloned() else {
476 break;
478 };
479 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
480 |source| {
481 FunctionError::Expression { handle, source }
482 .with_span_handle(handle, expression_arena)
483 },
484 )?
485 }
486 };
487
488 let sampler_storage = match sampling.sampler {
489 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
490 GlobalOrArgument::Argument(i) => {
491 let Some(handle) = arguments.get(i as usize).cloned() else {
492 break;
494 };
495 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
496 |source| {
497 FunctionError::Expression { handle, source }
498 .with_span_handle(handle, expression_arena)
499 },
500 )?
501 }
502 };
503
504 match (image_storage, sampler_storage) {
509 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
510 self.sampling_set.insert(SamplingKey { image, sampler });
511 }
512 (image, sampler) => {
513 self.sampling.insert(Sampling { image, sampler });
514 }
515 }
516 }
517
518 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
520 *mine |= *other;
521 }
522
523 self.try_update_mesh_info(&callee.mesh_shader_info)?;
525
526 Ok(FunctionUniformity {
527 result: callee.uniformity.clone(),
528 exit: if callee.may_kill {
529 ExitFlags::MAY_KILL
530 } else {
531 ExitFlags::empty()
532 },
533 })
534 }
535
/// Compute the [`ExpressionInfo`] for `handle` and store it in
/// `self.expressions`.
///
/// For each operand of the expression, a reference is recorded through
/// `add_ref`/`add_ref_impl`/`add_assignable_ref`, and the first non-uniform
/// operand (in the order written below) becomes this expression's
/// non-uniform cause. The expression's type is resolved through
/// `resolve_context`.
///
/// # Errors
///
/// - [`ExpressionError::MissingCapabilities`] when a binding array is
///   indexed by a non-uniform value without the needed capability.
/// - [`ExpressionError::ExpectedGlobalOrArgument`] when a sampled image or
///   sampler is neither a global nor a function argument.
/// - Any error produced while resolving the expression's type.
// The `.or(self.add_ref(..))` chains are intentional: every operand must be
// ref-counted even if an earlier one is already non-uniform.
#[allow(clippy::or_fun_call)]
fn process_expression(
    &mut self,
    handle: Handle<crate::Expression>,
    expression_arena: &Arena<crate::Expression>,
    other_functions: &[FunctionInfo],
    resolve_context: &ResolveContext,
    capabilities: super::Capabilities,
) -> Result<(), ExpressionError> {
    use crate::{Expression as E, SampleLevel as Sl};

    let expression = &expression_arena[handle];
    // Filled in for global variables and for access chains rooted in one.
    let mut assignable_global = None;
    let uniformity = match *expression {
        E::Access { base, index } => {
            let base_ty = self[base].ty.inner_with(resolve_context.types);

            // Work out which capability non-uniform indexing would need.
            let mut needed_caps = super::Capabilities::empty();
            let is_binding_array = match *base_ty {
                crate::TypeInner::BindingArray {
                    base: array_element_ty_handle,
                    ..
                } => {
                    let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                    let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                    let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                    let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                    let array_element_ty =
                        &resolve_context.types[array_element_ty_handle].inner;

                    needed_caps |= match *array_element_ty {
                        // Images: storage images have their own capability.
                        crate::TypeInner::Image { class, .. } => match class {
                            crate::ImageClass::Storage { .. } => sto,
                            _ => st_sb,
                        },
                        crate::TypeInner::Sampler { .. } => sampler,
                        // Anything else is treated as a buffer and classified
                        // by the address space of the global it lives in.
                        _ => {
                            if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                let global = &resolve_context.global_vars[global_handle];
                                match global.space {
                                    crate::AddressSpace::Uniform => uni,
                                    crate::AddressSpace::Storage { .. } => st_sb,
                                    _ => unreachable!(),
                                }
                            } else {
                                unreachable!()
                            }
                        }
                    };

                    true
                }
                _ => false,
            };

            // Non-uniform index into a binding array without the capability.
            if self[index].uniformity.non_uniform_result.is_some()
                && !capabilities.contains(needed_caps)
                && is_binding_array
            {
                return Err(ExpressionError::MissingCapabilities(needed_caps));
            }

            Uniformity {
                non_uniform_result: self
                    .add_assignable_ref(base, &mut assignable_global)
                    .or(self.add_ref(index)),
                requirements: UniformityRequirements::empty(),
            }
        }
        E::AccessIndex { base, .. } => Uniformity {
            non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
            requirements: UniformityRequirements::empty(),
        },
        E::Splat { size: _, value } => Uniformity {
            non_uniform_result: self.add_ref(value),
            requirements: UniformityRequirements::empty(),
        },
        E::Swizzle { vector, .. } => Uniformity {
            non_uniform_result: self.add_ref(vector),
            requirements: UniformityRequirements::empty(),
        },
        // Literals, constants, overrides and zero values are uniform.
        E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
        E::Compose { ref components, .. } => {
            // First non-uniform component wins; all components are ref-counted.
            let non_uniform_result = components
                .iter()
                .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
            Uniformity {
                non_uniform_result,
                requirements: UniformityRequirements::empty(),
            }
        }
        E::FunctionArgument(index) => {
            let arg = &resolve_context.arguments[index as usize];
            // Only these work-group-level built-ins are treated as uniform.
            let uniform = match arg.binding {
                Some(crate::Binding::BuiltIn(
                    crate::BuiltIn::WorkGroupId
                    | crate::BuiltIn::WorkGroupSize
                    | crate::BuiltIn::NumWorkGroups,
                )) => true,
                _ => false,
            };
            Uniformity {
                non_uniform_result: if uniform { None } else { Some(handle) },
                requirements: UniformityRequirements::empty(),
            }
        }
        E::GlobalVariable(gh) => {
            use crate::AddressSpace as As;
            assignable_global = Some(gh);
            let var = &resolve_context.global_vars[gh];
            // Uniformity is decided purely by the variable's address space.
            let uniform = match var.space {
                As::Function | As::Private => false,
                As::WorkGroup | As::TaskPayload => true,
                As::Uniform | As::PushConstant => true,
                // Storage is uniform only when it cannot be stored to.
                As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                As::Handle => false,
            };
            Uniformity {
                non_uniform_result: if uniform { None } else { Some(handle) },
                requirements: UniformityRequirements::empty(),
            }
        }
        // Local variables are always treated as non-uniform.
        E::LocalVariable(_) => Uniformity {
            non_uniform_result: Some(handle),
            requirements: UniformityRequirements::empty(),
        },
        E::Load { pointer } => Uniformity {
            non_uniform_result: self.add_ref(pointer),
            requirements: UniformityRequirements::empty(),
        },
        E::ImageSample {
            image,
            sampler,
            gather: _,
            coordinate,
            array_index,
            offset,
            level,
            depth_ref,
            clamp_to_edge: _,
        } => {
            let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
            let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

            // Record which image/sampler pair is sampled together.
            match (image_storage, sampler_storage) {
                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                    self.sampling_set.insert(SamplingKey { image, sampler });
                }
                _ => {
                    self.sampling.insert(Sampling {
                        image: image_storage,
                        sampler: sampler_storage,
                    });
                }
            }

            // Optional operands are ref-counted up front, before the
            // mandatory ones in the chain below.
            let array_nur = array_index.and_then(|h| self.add_ref(h));
            let level_nur = match level {
                Sl::Auto | Sl::Zero => None,
                Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
            };
            let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
            let offset_nur = offset.and_then(|h| self.add_ref(h));
            Uniformity {
                non_uniform_result: self
                    .add_ref(image)
                    .or(self.add_ref(sampler))
                    .or(self.add_ref(coordinate))
                    .or(array_nur)
                    .or(level_nur)
                    .or(dref_nur)
                    .or(offset_nur),
                requirements: if level.implicit_derivatives() {
                    UniformityRequirements::IMPLICIT_LEVEL
                } else {
                    UniformityRequirements::empty()
                },
            }
        }
        E::ImageLoad {
            image,
            coordinate,
            array_index,
            sample,
            level,
        } => {
            let array_nur = array_index.and_then(|h| self.add_ref(h));
            let sample_nur = sample.and_then(|h| self.add_ref(h));
            let level_nur = level.and_then(|h| self.add_ref(h));
            Uniformity {
                non_uniform_result: self
                    .add_ref(image)
                    .or(self.add_ref(coordinate))
                    .or(array_nur)
                    .or(sample_nur)
                    .or(level_nur),
                requirements: UniformityRequirements::empty(),
            }
        }
        E::ImageQuery { image, query } => {
            let query_nur = match query {
                crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                _ => None,
            };
            Uniformity {
                // The image is only queried, not read.
                non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                requirements: UniformityRequirements::empty(),
            }
        }
        E::Unary { expr, .. } => Uniformity {
            non_uniform_result: self.add_ref(expr),
            requirements: UniformityRequirements::empty(),
        },
        E::Binary { left, right, .. } => Uniformity {
            non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
            requirements: UniformityRequirements::empty(),
        },
        E::Select {
            condition,
            accept,
            reject,
        } => Uniformity {
            non_uniform_result: self
                .add_ref(condition)
                .or(self.add_ref(accept))
                .or(self.add_ref(reject)),
            requirements: UniformityRequirements::empty(),
        },
        E::Derivative { expr, .. } => Uniformity {
            // The derivative inherits its operand's uniformity unchanged.
            non_uniform_result: self.add_ref(expr),
            requirements: UniformityRequirements::DERIVATIVE,
        },
        E::Relational { argument, .. } => Uniformity {
            non_uniform_result: self.add_ref(argument),
            requirements: UniformityRequirements::empty(),
        },
        E::Math {
            fun: _,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            let arg1_nur = arg1.and_then(|h| self.add_ref(h));
            let arg2_nur = arg2.and_then(|h| self.add_ref(h));
            let arg3_nur = arg3.and_then(|h| self.add_ref(h));
            Uniformity {
                non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                requirements: UniformityRequirements::empty(),
            }
        }
        E::As { expr, .. } => Uniformity {
            non_uniform_result: self.add_ref(expr),
            requirements: UniformityRequirements::empty(),
        },
        // A call result takes on the uniformity recorded for the callee.
        E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
        E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
            non_uniform_result: Some(handle),
            requirements: UniformityRequirements::empty(),
        },
        // Treated as uniform, with no requirements on the expression itself.
        E::WorkGroupUniformLoadResult { .. } => Uniformity {
            non_uniform_result: None,
            requirements: UniformityRequirements::empty(),
        },
        E::ArrayLength(expr) => Uniformity {
            // The array is only queried, not read.
            non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
            requirements: UniformityRequirements::empty(),
        },
        E::RayQueryGetIntersection {
            query,
            committed: _,
        } => Uniformity {
            non_uniform_result: self.add_ref(query),
            requirements: UniformityRequirements::empty(),
        },
        E::SubgroupBallotResult => Uniformity {
            non_uniform_result: Some(handle),
            requirements: UniformityRequirements::empty(),
        },
        E::SubgroupOperationResult { .. } => Uniformity {
            non_uniform_result: Some(handle),
            requirements: UniformityRequirements::empty(),
        },
        E::RayQueryVertexPositions {
            query,
            committed: _,
        } => Uniformity {
            non_uniform_result: self.add_ref(query),
            requirements: UniformityRequirements::empty(),
        },
    };

    // Operand types are looked up from the entries already stored in `self`.
    let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
    // The reference count starts at zero; users of this expression bump it.
    self.expressions[handle.index()] = ExpressionInfo {
        uniformity,
        ref_count: 0,
        assignable_global,
        ty,
    };
    Ok(())
}
878
/// Analyze the statements of a block, recursing into nested blocks.
///
/// `disruptor` is the reason, if any, that control flow is already
/// considered non-uniform on entry. It is updated after every statement
/// with that statement's exit disruptor, so later statements see the effect
/// of earlier conditional returns/kills.
///
/// Returns the union of the uniformity of all statements in the block.
///
/// # Errors
///
/// Fails when a statement with uniformity requirements is reached under a
/// disruptor and the diagnostic severity reports it as an error, when a
/// call cannot be processed, or when mesh output types are invalid or
/// conflict.
// The `.or(self.add_ref(..))` style is intentional: operands must always be
// ref-counted.
#[allow(clippy::or_fun_call)]
fn process_block(
    &mut self,
    statements: &crate::Block,
    other_functions: &[FunctionInfo],
    mut disruptor: Option<UniformityDisruptor>,
    expression_arena: &Arena<crate::Expression>,
    diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
    use crate::Statement as S;

    let mut combined_uniformity = FunctionUniformity::new();
    for statement in statements {
        let uniformity = match *statement {
            S::Emit(ref range) => {
                let mut requirements = UniformityRequirements::empty();
                for expr in range.clone() {
                    let req = self.expressions[expr.index()].uniformity.requirements;
                    // Only checked when uniformity validation is enabled and
                    // the expression actually requires uniform control flow.
                    if self
                        .flags
                        .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                        && !req.is_empty()
                    {
                        if let Some(cause) = disruptor {
                            // The configured severity decides whether this is
                            // an error or merely logged.
                            let severity = DiagnosticFilterNode::search(
                                self.diagnostic_filter_leaf,
                                diagnostic_filter_arena,
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            );
                            severity.report_diag(
                                FunctionError::NonUniformControlFlow(req, expr, cause)
                                    .with_span_handle(expr, expression_arena),
                                |e, level| log::log!(level, "{e}"),
                            )?;
                        }
                    }
                    requirements |= req;
                }
                FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements,
                    },
                    exit: ExitFlags::empty(),
                }
            }
            S::Break | S::Continue => FunctionUniformity::new(),
            // A kill only counts as a non-uniform exit if control flow is
            // already disrupted when it is reached.
            S::Kill => FunctionUniformity {
                result: Uniformity::new(),
                exit: if disruptor.is_some() {
                    ExitFlags::MAY_KILL
                } else {
                    ExitFlags::empty()
                },
            },
            S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                result: Uniformity {
                    non_uniform_result: None,
                    requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                },
                exit: ExitFlags::empty(),
            },
            S::WorkGroupUniformLoad { pointer, .. } => {
                // The pointer's uniformity is recorded but not acted upon.
                let _condition_nur = self.add_ref(pointer);

                FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                }
            }
            S::Block(ref b) => self.process_block(
                b,
                other_functions,
                disruptor,
                expression_arena,
                diagnostic_filter_arena,
            )?,
            S::If {
                condition,
                ref accept,
                ref reject,
            } => {
                let condition_nur = self.add_ref(condition);
                // An existing disruptor takes precedence over the condition.
                let branch_disruptor =
                    disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                let accept_uniformity = self.process_block(
                    accept,
                    other_functions,
                    branch_disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?;
                let reject_uniformity = self.process_block(
                    reject,
                    other_functions,
                    branch_disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?;
                accept_uniformity | reject_uniformity
            }
            S::Switch {
                selector,
                ref cases,
            } => {
                let selector_nur = self.add_ref(selector);
                let branch_disruptor =
                    disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                let mut uniformity = FunctionUniformity::new();
                let mut case_disruptor = branch_disruptor;
                for case in cases.iter() {
                    let case_uniformity = self.process_block(
                        &case.body,
                        other_functions,
                        case_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // A fall-through case passes its exits on to the next
                    // case; otherwise the next case starts afresh.
                    case_disruptor = if case.fall_through {
                        case_disruptor.or(case_uniformity.exit_disruptor())
                    } else {
                        branch_disruptor
                    };
                    uniformity = uniformity | case_uniformity;
                }
                uniformity
            }
            S::Loop {
                ref body,
                ref continuing,
                break_if,
            } => {
                let body_uniformity = self.process_block(
                    body,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?;
                // Exits taken in the body disrupt the continuing block.
                let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                let continuing_uniformity = self.process_block(
                    continuing,
                    other_functions,
                    continuing_disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?;
                if let Some(expr) = break_if {
                    let _ = self.add_ref(expr);
                }
                body_uniformity | continuing_uniformity
            }
            // A return only counts as a non-uniform exit if control flow is
            // already disrupted when it is reached.
            S::Return { value } => FunctionUniformity {
                result: Uniformity {
                    non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                    requirements: UniformityRequirements::empty(),
                },
                exit: if disruptor.is_some() {
                    ExitFlags::MAY_RETURN
                } else {
                    ExitFlags::empty()
                },
            },
            // From here on, operands are only ref-counted; their
            // non-uniformity is deliberately discarded.
            S::Store { pointer, value } => {
                let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                let _ = self.add_ref(value);
                FunctionUniformity::new()
            }
            S::ImageStore {
                image,
                coordinate,
                array_index,
                value,
            } => {
                let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                if let Some(expr) = array_index {
                    let _ = self.add_ref(expr);
                }
                let _ = self.add_ref(coordinate);
                let _ = self.add_ref(value);
                FunctionUniformity::new()
            }
            S::Call {
                function,
                ref arguments,
                result: _,
            } => {
                for &argument in arguments {
                    let _ = self.add_ref(argument);
                }
                let info = &other_functions[function.index()];
                self.process_call(info, arguments, expression_arena)?
            }
            S::Atomic {
                pointer,
                ref fun,
                value,
                result: _,
            } => {
                // Atomics both read and write their target.
                let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                let _ = self.add_ref(value);
                if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                    let _ = self.add_ref(cmp);
                }
                FunctionUniformity::new()
            }
            S::ImageAtomic {
                image,
                coordinate,
                array_index,
                fun: _,
                value,
            } => {
                let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                let _ = self.add_ref(coordinate);
                if let Some(expr) = array_index {
                    let _ = self.add_ref(expr);
                }
                let _ = self.add_ref(value);
                FunctionUniformity::new()
            }
            S::RayQuery { query, ref fun } => {
                let _ = self.add_ref(query);
                match *fun {
                    crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } => {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    crate::RayQueryFunction::Proceed { result: _ } => {}
                    crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                        let _ = self.add_ref(hit_t);
                    }
                    crate::RayQueryFunction::ConfirmIntersection => {}
                    crate::RayQueryFunction::Terminate => {}
                }
                FunctionUniformity::new()
            }
            S::MeshFunction(func) => {
                // NOTE(review): `available_stages` starts as all stages, so
                // this `|=` looks like a no-op — confirm the intent.
                self.available_stages |= ShaderStages::MESH;
                match &func {
                    &crate::MeshFunction::SetMeshOutputs {
                        vertex_count,
                        primitive_count,
                    } => {
                        let _ = self.add_ref(vertex_count);
                        let _ = self.add_ref(primitive_count);
                        FunctionUniformity::new()
                    }
                    &crate::MeshFunction::SetVertex { index, value }
                    | &crate::MeshFunction::SetPrimitive { index, value } => {
                        let _ = self.add_ref(index);
                        let _ = self.add_ref(value);
                        // The output value must resolve to a type handle.
                        let ty = self.expressions[value.index()].ty.handle().ok_or(
                            FunctionError::InvalidMeshShaderOutputType(value).with_span(),
                        )?;

                        if matches!(func, crate::MeshFunction::SetVertex { .. }) {
                            self.try_update_mesh_vertex_type(ty, value)?;
                        } else {
                            self.try_update_mesh_primitive_type(ty, value)?;
                        };

                        FunctionUniformity::new()
                    }
                }
            }
            S::SubgroupBallot {
                result: _,
                predicate,
            } => {
                if let Some(predicate) = predicate {
                    let _ = self.add_ref(predicate);
                }
                FunctionUniformity::new()
            }
            S::SubgroupCollectiveOperation {
                op: _,
                collective_op: _,
                argument,
                result: _,
            } => {
                let _ = self.add_ref(argument);
                FunctionUniformity::new()
            }
            S::SubgroupGather {
                mode,
                argument,
                result: _,
            } => {
                let _ = self.add_ref(argument);
                match mode {
                    crate::GatherMode::BroadcastFirst => {}
                    crate::GatherMode::Broadcast(index)
                    | crate::GatherMode::Shuffle(index)
                    | crate::GatherMode::ShuffleDown(index)
                    | crate::GatherMode::ShuffleUp(index)
                    | crate::GatherMode::ShuffleXor(index)
                    | crate::GatherMode::QuadBroadcast(index) => {
                        let _ = self.add_ref(index);
                    }
                    crate::GatherMode::QuadSwap(_) => {}
                }
                FunctionUniformity::new()
            }
        };

        // Statements after a conditional return/kill run under that disruptor.
        disruptor = disruptor.or(uniformity.exit_disruptor());
        combined_uniformity = combined_uniformity | uniformity;
    }
    Ok(combined_uniformity)
}
1233
1234 fn try_update_mesh_vertex_type(
1244 &mut self,
1245 ty: Handle<crate::Type>,
1246 value: Handle<crate::Expression>,
1247 ) -> Result<(), WithSpan<FunctionError>> {
1248 if let &Some(ref existing) = &self.mesh_shader_info.vertex_type {
1249 if existing.0 != ty {
1250 return Err(
1251 FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span()
1252 );
1253 }
1254 } else {
1255 self.mesh_shader_info.vertex_type = Some((ty, value));
1256 }
1257 Ok(())
1258 }
1259
1260 fn try_update_mesh_primitive_type(
1270 &mut self,
1271 ty: Handle<crate::Type>,
1272 value: Handle<crate::Expression>,
1273 ) -> Result<(), WithSpan<FunctionError>> {
1274 if let &Some(ref existing) = &self.mesh_shader_info.primitive_type {
1275 if existing.0 != ty {
1276 return Err(
1277 FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span()
1278 );
1279 }
1280 } else {
1281 self.mesh_shader_info.primitive_type = Some((ty, value));
1282 }
1283 Ok(())
1284 }
1285
1286 fn try_update_mesh_info(
1288 &mut self,
1289 callee: &FunctionMeshShaderInfo,
1290 ) -> Result<(), WithSpan<FunctionError>> {
1291 if let &Some(ref other_vertex) = &callee.vertex_type {
1292 self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?;
1293 }
1294 if let &Some(ref other_primitive) = &callee.primitive_type {
1295 self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?;
1296 }
1297 Ok(())
1298 }
1299}
1300
1301impl ModuleInfo {
1302 pub(super) fn process_const_expression(
1304 &mut self,
1305 handle: Handle<crate::Expression>,
1306 resolve_context: &ResolveContext,
1307 gctx: crate::proc::GlobalCtx,
1308 ) -> Result<(), super::ConstExpressionError> {
1309 self.const_expression_types[handle.index()] =
1310 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1311 Ok(())
1312 }
1313
1314 pub(super) fn process_function(
1317 &self,
1318 fun: &crate::Function,
1319 module: &crate::Module,
1320 flags: ValidationFlags,
1321 capabilities: super::Capabilities,
1322 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1323 let mut info = FunctionInfo {
1324 flags,
1325 available_stages: ShaderStages::all(),
1326 uniformity: Uniformity::new(),
1327 may_kill: false,
1328 sampling_set: crate::FastHashSet::default(),
1329 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1330 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1331 sampling: crate::FastHashSet::default(),
1332 dual_source_blending: false,
1333 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1334 mesh_shader_info: FunctionMeshShaderInfo::default(),
1335 };
1336 let resolve_context =
1337 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1338
1339 for (handle, _) in fun.expressions.iter() {
1340 if let Err(source) = info.process_expression(
1341 handle,
1342 &fun.expressions,
1343 &self.functions,
1344 &resolve_context,
1345 capabilities,
1346 ) {
1347 return Err(FunctionError::Expression { handle, source }
1348 .with_span_handle(handle, &fun.expressions));
1349 }
1350 }
1351
1352 for (_, expr) in fun.local_variables.iter() {
1353 if let Some(init) = expr.init {
1354 let _ = info.add_ref(init);
1355 }
1356 }
1357
1358 let uniformity = info.process_block(
1359 &fun.body,
1360 &self.functions,
1361 None,
1362 &fun.expressions,
1363 &module.diagnostic_filters,
1364 )?;
1365 info.uniformity = uniformity.result;
1366 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1367
1368 for &handle in fun.named_expressions.keys() {
1374 if let Some(global) = info[handle].assignable_global {
1375 if info.global_uses[global.index()].is_empty() {
1376 info.global_uses[global.index()] = GlobalUse::QUERY;
1377 }
1378 }
1379 }
1380
1381 Ok(info)
1382 }
1383
1384 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1385 &self.entry_points[index]
1386 }
1387}
1388
/// Exercises expression and block analysis on a tiny hand-built function:
/// reference counting, global usage flags, uniformity requirements, the
/// reported non-uniform source and the exit flags.
///
/// The checks are cumulative: every `process_block` call keeps updating the
/// same `FunctionInfo`, so the asserted reference counts and global uses
/// depend on the order of the steps below.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    /// Assembles the block analysis result an assertion expects.
    fn expected(
        non_uniform_result: NonUniformResult,
        requirements: UniformityRequirements,
        exit: ExitFlags,
    ) -> FunctionUniformity {
        FunctionUniformity {
            result: Uniformity {
                non_uniform_result,
                requirements,
            },
            exit,
        }
    }

    // A single two-component `f32` vector type shared by both globals.
    let mut types = crate::UniqueArena::new();
    let vec2_f32 = types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Two globals: one in the `Handle` address space (the one the assertions
    // below expect to be reported as non-uniform) and one in `Uniform`.
    let mut globals = Arena::new();
    let non_uniform_var = globals.append(
        crate::GlobalVariable {
            name: None,
            space: crate::AddressSpace::Handle,
            binding: None,
            ty: vec2_f32,
            init: None,
        },
        crate::Span::default(),
    );
    let uniform_var = globals.append(
        crate::GlobalVariable {
            name: None,
            space: crate::AddressSpace::Uniform,
            binding: None,
            ty: vec2_f32,
            init: None,
        },
        crate::Span::default(),
    );

    // Expression arena layout (indices matter for the `range_from` calls):
    //   0: literal, 1: derivative of the literal,
    //   2: non-uniform global, 3: uniform global,
    //   4: array length query on the uniform global,
    //   5: component access into the non-uniform global.
    let mut expressions = Arena::new();
    let zero = expressions.append(E::Literal(crate::Literal::U32(0)), crate::Span::default());
    let derivative = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: zero,
        },
        crate::Span::default(),
    );
    let literal_and_derivative = expressions.range_from(0);
    let non_uniform_ref =
        expressions.append(E::GlobalVariable(non_uniform_var), crate::Span::default());
    let uniform_ref = expressions.append(E::GlobalVariable(uniform_var), crate::Span::default());
    let global_refs = expressions.range_from(2);

    let length_query = expressions.append(E::ArrayLength(uniform_ref), crate::Span::default());
    let component = expressions.append(
        E::AccessIndex {
            base: non_uniform_ref,
            index: 1,
        },
        crate::Span::default(),
    );
    let globals_and_their_uses = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); globals.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
        mesh_shader_info: FunctionMeshShaderInfo::default(),
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &types,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &globals,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };

    // Step 1: expression analysis only.
    for (expression, _) in expressions.iter() {
        info.process_expression(
            expression,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    assert_eq!(info[non_uniform_ref].ref_count, 1);
    assert_eq!(info[uniform_ref].ref_count, 1);
    assert_eq!(info[length_query].ref_count, 0);
    assert_eq!(info[component].ref_count, 0);
    assert_eq!(info[non_uniform_var], GlobalUse::empty());
    assert_eq!(info[uniform_var], GlobalUse::QUERY);

    // Step 2: a derivative behind a branch on the uniform global.
    let uniform_branch = S::If {
        condition: uniform_ref,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(literal_and_derivative.clone()),
            S::Store {
                pointer: zero,
                value: derivative,
            },
        ]
        .into(),
    };
    let uniform_body: crate::Block = vec![S::Emit(global_refs.clone()), uniform_branch].into();
    assert_eq!(
        info.process_block(&uniform_body, &[], None, &expressions, &Arena::new()),
        Ok(expected(
            None,
            UniformityRequirements::DERIVATIVE,
            ExitFlags::empty(),
        )),
    );
    assert_eq!(info[zero].ref_count, 2);
    assert_eq!(info[uniform_var], GlobalUse::READ | GlobalUse::QUERY);

    // Step 3: the same derivative behind a branch on the non-uniform global.
    let non_uniform_branch = S::If {
        condition: non_uniform_ref,
        accept: vec![
            S::Emit(literal_and_derivative),
            S::Store {
                pointer: zero,
                value: derivative,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    let non_uniform_body: crate::Block =
        vec![S::Emit(global_refs.clone()), non_uniform_branch].into();
    {
        let first_attempt = info.process_block(
            &non_uniform_body,
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            assert_eq!(info[derivative].ref_count, 2);
        } else {
            assert_eq!(
                first_attempt,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative,
                    UniformityDisruptor::Expression(non_uniform_ref)
                )
                .with_span()),
            );
            assert_eq!(info[derivative].ref_count, 1);

            // Analyze the same block again with a diagnostic filter that
            // switches the derivative uniformity rule off.
            let mut diagnostic_filters = Arena::new();
            let derivative_uniformity_off = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            let mut filtered_info = FunctionInfo {
                diagnostic_filter_leaf: Some(derivative_uniformity_off),
                ..info.clone()
            };

            let second_attempt = filtered_info.process_block(
                &non_uniform_body,
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                second_attempt,
                Ok(expected(
                    None,
                    UniformityRequirements::DERIVATIVE,
                    ExitFlags::empty(),
                )),
            );
            assert_eq!(filtered_info[derivative].ref_count, 2);
        }
    }
    assert_eq!(info[non_uniform_var], GlobalUse::READ);

    // Step 4: returning the non-uniform global.
    let returning_body: crate::Block = vec![
        S::Emit(global_refs),
        S::Return {
            value: Some(non_uniform_ref),
        },
    ]
    .into();
    assert_eq!(
        info.process_block(
            &returning_body,
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(expected(
            Some(non_uniform_ref),
            UniformityRequirements::empty(),
            ExitFlags::MAY_RETURN,
        )),
    );
    assert_eq!(info[non_uniform_ref].ref_count, 3);

    // Step 5: store through the component access, kill, then return the
    // access expression; the non-uniform global is expected as the source.
    let storing_body: crate::Block = vec![
        S::Emit(globals_and_their_uses),
        S::Store {
            pointer: component,
            value: length_query,
        },
        S::Kill,
        S::Return {
            value: Some(component),
        },
    ]
    .into();
    assert_eq!(
        info.process_block(
            &storing_body,
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(expected(
            Some(non_uniform_ref),
            UniformityRequirements::empty(),
            ExitFlags::all(),
        )),
    );
    assert_eq!(info[non_uniform_var], GlobalUse::READ | GlobalUse::WRITE);
}