naga/valid/analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15    arena::{Arena, Handle},
16    proc::{ResolveContext, TypeResolution},
17};
18
/// Uniformity of a single value.
///
/// `None` means the value is uniform. `Some(handle)` names the expression
/// from which this value's non-uniformity originates.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// [`UniformityRequirements`] are defined as zero, so recording them adds no bits.
// NOTE(review): the name suggests these requirements only matter for the
// fragment stage; that link is not visible in this file — confirm.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
23bitflags::bitflags! {
24    /// Kinds of expressions that require uniform control flow.
25    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
26    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
27    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
28    pub struct UniformityRequirements: u8 {
29        const WORK_GROUP_BARRIER = 0x1;
30        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
31        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
32        const COOP_OPS = 0x8;
33    }
34}
35
/// Uniform control flow characteristics.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with non-uniform result, or `None` if the value is uniform.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operations that depend on non-uniform results also produce non-uniform results.
    pub non_uniform_result: NonUniformResult,
    /// The reasons, if any, this expression requires uniform control flow.
    /// Empty when there is no such requirement.
    pub requirements: UniformityRequirements,
}
57
58impl Uniformity {
59    const fn new() -> Self {
60        Uniformity {
61            non_uniform_result: None,
62            requirements: UniformityRequirements::empty(),
63        }
64    }
65}
66
67bitflags::bitflags! {
68    #[derive(Clone, Copy, Debug, PartialEq)]
69    struct ExitFlags: u8 {
70        /// Control flow may return from the function, which makes all the
71        /// subsequent statements within the current function (only!)
72        /// to be executed in a non-uniform control flow.
73        const MAY_RETURN = 0x1;
74        /// Control flow may be killed. Anything after [`Statement::Kill`] is
75        /// considered inside non-uniform context.
76        ///
77        /// [`Statement::Kill`]: crate::Statement::Kill
78        const MAY_KILL = 0x2;
79    }
80}
81
/// Uniformity characteristics of a function (or of a piece of its body).
///
/// Two summaries are merged with `|`: see the `BitOr` impl below.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the produced result, plus any requirements it imposes.
    result: Uniformity,
    /// Whether control flow may return or be killed along the way.
    exit: ExitFlags,
}
88
89impl ops::BitOr for FunctionUniformity {
90    type Output = Self;
91    fn bitor(self, other: Self) -> Self {
92        FunctionUniformity {
93            result: Uniformity {
94                non_uniform_result: self
95                    .result
96                    .non_uniform_result
97                    .or(other.result.non_uniform_result),
98                requirements: self.result.requirements | other.result.requirements,
99            },
100            exit: self.exit | other.exit,
101        }
102    }
103}
104
105impl FunctionUniformity {
106    const fn new() -> Self {
107        FunctionUniformity {
108            result: Uniformity::new(),
109            exit: ExitFlags::empty(),
110        }
111    }
112
113    /// Returns a disruptor based on the stored exit flags, if any.
114    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
115        if self.exit.contains(ExitFlags::MAY_RETURN) {
116            Some(UniformityDisruptor::Return)
117        } else if self.exit.contains(ExitFlags::MAY_KILL) {
118            Some(UniformityDisruptor::Discard)
119        } else {
120            None
121        }
122    }
123}
124
125bitflags::bitflags! {
126    /// Indicates how a global variable is used.
127    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
128    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
129    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
130    pub struct GlobalUse: u8 {
131        /// Data will be read from the variable.
132        const READ = 0x1;
133        /// Data will be written to the variable.
134        const WRITE = 0x2;
135        /// The information about the data is queried.
136        const QUERY = 0x4;
137        /// Atomic operations will be performed on the variable.
138        const ATOMIC = 0x8;
139    }
140}
141
/// A (texture, sampler) pair of global variables that may be used together
/// in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The global holding the sampled image.
    pub image: Handle<crate::GlobalVariable>,
    /// The global holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
149
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
/// Information about an expression in a function body.
pub struct ExpressionInfo {
    /// Whether this expression is uniform, and why.
    ///
    /// If this expression's value is not uniform,
    /// [`non_uniform_result`](Uniformity::non_uniform_result) holds the handle
    /// of the expression from which this one's non-uniformity originates;
    /// otherwise it is `None`. [`requirements`](Uniformity::requirements)
    /// records any reasons this expression needs uniform control flow.
    pub uniformity: Uniformity,

    /// The number of direct references to this expression in statements and
    /// other expressions.
    ///
    /// This is a _local_ reference count only: it may be non-zero for
    /// expressions that are ultimately unused.
    pub ref_count: usize,

    /// The global variable into which this expression produces a pointer.
    ///
    /// This is `None` unless this expression is either a
    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
    /// ultimately refers to some part of a global.
    ///
    /// [`Load`] expressions applied to pointer-typed arguments could
    /// refer to globals, but we leave this as `None` for them.
    ///
    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    /// [`Load`]: crate::Expression::Load
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The type of this expression.
    pub ty: TypeResolution,
}
187
188impl ExpressionInfo {
189    const fn new() -> Self {
190        ExpressionInfo {
191            uniformity: Uniformity::new(),
192            ref_count: 0,
193            assignable_global: None,
194            // this doesn't matter at this point, will be overwritten
195            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
196                kind: crate::ScalarKind::Bool,
197                width: 0,
198            })),
199        }
200    }
201}
202
/// Where a resource referenced by a function comes from: a module global
/// variable, or one of the function's own arguments (identified by index).
///
/// Used to track sampling pairs whose participants may only be known in
/// terms of the current function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A module-level global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The function argument at this index.
    Argument(u32),
}
210
211impl GlobalOrArgument {
212    fn from_expression(
213        expression_arena: &Arena<crate::Expression>,
214        expression: Handle<crate::Expression>,
215    ) -> Result<GlobalOrArgument, ExpressionError> {
216        Ok(match expression_arena[expression] {
217            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
218            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
219            crate::Expression::Access { base, .. }
220            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
221                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
222                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
223            },
224            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
225        })
226    }
227}
228
/// A (texture, sampler) pair that may be used together in sampling, where at
/// least one participant may still be one of this function's arguments
/// rather than a known global.
///
/// Pairs made only of globals are recorded as [`SamplingKey`]s instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Origin of the sampled image.
    image: GlobalOrArgument,
    /// Origin of the sampler.
    sampler: GlobalOrArgument,
}
236
/// Analysis results for a single function.
///
/// Covers the function's uniformity, how it uses globals and texture/sampler
/// pairs, and per-expression information. When another function calls this
/// one, these results are folded into the caller's own analysis.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    // The field is stored but presumably never read, hence the lint
    // suppression below. Confirm before removing either.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics of the function as a whole.
    ///
    /// Callers inherit these at each call site.
    pub uniformity: Uniformity,
    /// Whether this function may kill the invocation.
    pub may_kill: bool,

    /// All pairs of (texture, sampler) globals that may be used together in
    /// sampling operations by this function and its callees. This includes
    /// pairings that arise when this function passes textures and samplers as
    /// arguments to its callees.
    ///
    /// This table does not include uses of textures and samplers passed as
    /// arguments to this function itself, since we do not know which globals
    /// those will be. However, this table *is* exhaustive when computed for an
    /// entry point function: entry points never receive textures or samplers as
    /// arguments, so all an entry point's sampling can be reported in terms of
    /// globals.
    ///
    /// The GLSL back end uses this table to construct reflection info that
    /// clients need to construct texture-combined sampler values.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function and its callees use this module's globals.
    ///
    /// This is indexed by `Handle<GlobalVariable>` indices. However,
    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
    /// so you can simply index this struct with a global handle to retrieve
    /// its usage information.
    global_uses: Box<[GlobalUse]>,

    /// Information about each expression in this function's body.
    ///
    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
    /// index this struct with an expression handle to retrieve its
    /// `ExpressionInfo`.
    expressions: Box<[ExpressionInfo]>,

    /// All (texture, sampler) pairs that may be used together in sampling
    /// operations by this function and its callees, whether they are accessed
    /// as globals or passed as arguments.
    ///
    /// Participants are represented by [`GlobalVariable`] handles whenever
    /// possible, and otherwise by indices of this function's arguments.
    ///
    /// When analyzing a function call, we combine this data about the callee
    /// with the actual arguments being passed to produce the callers' own
    /// `sampling_set` and `sampling` tables.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    sampling: crate::FastHashSet<Sampling>,

    /// Indicates that the function is using dual source blending.
    pub dual_source_blending: bool,

    /// The leaf of the tree of module-wide diagnostic filter rules parsed from
    /// directives in this module, or `None` if there are no such rules.
    ///
    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
    /// validation.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
307
308impl FunctionInfo {
309    pub const fn global_variable_count(&self) -> usize {
310        self.global_uses.len()
311    }
312    pub const fn expression_count(&self) -> usize {
313        self.expressions.len()
314    }
315    pub fn dominates_global_use(&self, other: &Self) -> bool {
316        for (self_global_uses, other_global_uses) in
317            self.global_uses.iter().zip(other.global_uses.iter())
318        {
319            if !self_global_uses.contains(*other_global_uses) {
320                return false;
321            }
322        }
323        true
324    }
325}
326
327impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
328    type Output = GlobalUse;
329    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
330        &self.global_uses[handle.index()]
331    }
332}
333
334impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
335    type Output = ExpressionInfo;
336    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
337        &self.expressions[handle.index()]
338    }
339}
340
/// Disruptor of the uniform control flow.
///
/// Explains why control flow at some point may not be uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on the non-uniform value of this expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A return may already have happened earlier in the function.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// The invocation may already have been killed earlier, possibly inside a
    /// called function.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
352
353impl FunctionInfo {
354    /// Record a use of `expr` of the sort given by `global_use`.
355    ///
356    /// Bump `expr`'s reference count, and return its uniformity.
357    ///
358    /// If `expr` is a pointer to a global variable, or some part of
359    /// a global variable, add `global_use` to that global's set of
360    /// uses.
361    #[must_use]
362    fn add_ref_impl(
363        &mut self,
364        expr: Handle<crate::Expression>,
365        global_use: GlobalUse,
366    ) -> NonUniformResult {
367        let info = &mut self.expressions[expr.index()];
368        info.ref_count += 1;
369        // Record usage if this expression may access a global
370        if let Some(global) = info.assignable_global {
371            self.global_uses[global.index()] |= global_use;
372        }
373        info.uniformity.non_uniform_result
374    }
375
376    /// Note an entry point's use of `global` not recorded by [`ModuleInfo::process_function`].
377    ///
378    /// Most global variable usage should be recorded via [`add_ref_impl`] in the process
379    /// of expression behavior analysis by [`ModuleInfo::process_function`]. But that code
380    /// has no access to entrypoint-specific information, so interface analysis uses this
381    /// function to record global uses there (like task shader payloads).
382    ///
383    /// [`add_ref_impl`]: Self::add_ref_impl
384    pub(super) fn insert_global_use(
385        &mut self,
386        global_use: GlobalUse,
387        global: Handle<crate::GlobalVariable>,
388    ) {
389        self.global_uses[global.index()] |= global_use;
390    }
391
392    /// Record a use of `expr` for its value.
393    ///
394    /// This is used for almost all expression references. Anything
395    /// that writes to the value `expr` points to, or otherwise wants
396    /// contribute flags other than `GlobalUse::READ`, should use
397    /// `add_ref_impl` directly.
398    #[must_use]
399    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
400        self.add_ref_impl(expr, GlobalUse::READ)
401    }
402
403    /// Record a use of `expr`, and indicate which global variable it
404    /// refers to, if any.
405    ///
406    /// Bump `expr`'s reference count, and return its uniformity.
407    ///
408    /// If `expr` is a pointer to a global variable, or some part
409    /// thereof, store that global in `*assignable_global`. Leave the
410    /// global's uses unchanged.
411    ///
412    /// This is used to determine the [`assignable_global`] for
413    /// [`Access`] and [`AccessIndex`] expressions that ultimately
414    /// refer to a global variable. Those expressions don't contribute
415    /// any usage to the global themselves; that depends on how other
416    /// expressions use them.
417    ///
418    /// [`assignable_global`]: ExpressionInfo::assignable_global
419    /// [`Access`]: crate::Expression::Access
420    /// [`AccessIndex`]: crate::Expression::AccessIndex
421    #[must_use]
422    fn add_assignable_ref(
423        &mut self,
424        expr: Handle<crate::Expression>,
425        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
426    ) -> NonUniformResult {
427        let info = &mut self.expressions[expr.index()];
428        info.ref_count += 1;
429        // propagate the assignable global up the chain, till it either hits
430        // a value-type expression, or the assignment statement.
431        if let Some(global) = info.assignable_global {
432            if let Some(_old) = assignable_global.replace(global) {
433                unreachable!()
434            }
435        }
436        info.uniformity.non_uniform_result
437    }
438
439    /// Inherit information from a called function.
440    fn process_call(
441        &mut self,
442        callee: &Self,
443        arguments: &[Handle<crate::Expression>],
444        expression_arena: &Arena<crate::Expression>,
445    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
446        self.sampling_set
447            .extend(callee.sampling_set.iter().cloned());
448        for sampling in callee.sampling.iter() {
449            // If the callee was passed the texture or sampler as an argument,
450            // we may now be able to determine which globals those referred to.
451            let image_storage = match sampling.image {
452                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
453                GlobalOrArgument::Argument(i) => {
454                    let Some(handle) = arguments.get(i as usize).cloned() else {
455                        // Argument count mismatch, will be reported later by validate_call
456                        break;
457                    };
458                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
459                        |source| {
460                            FunctionError::Expression { handle, source }
461                                .with_span_handle(handle, expression_arena)
462                        },
463                    )?
464                }
465            };
466
467            let sampler_storage = match sampling.sampler {
468                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
469                GlobalOrArgument::Argument(i) => {
470                    let Some(handle) = arguments.get(i as usize).cloned() else {
471                        // Argument count mismatch, will be reported later by validate_call
472                        break;
473                    };
474                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
475                        |source| {
476                            FunctionError::Expression { handle, source }
477                                .with_span_handle(handle, expression_arena)
478                        },
479                    )?
480                }
481            };
482
483            // If we've managed to pin both the image and sampler down to
484            // specific globals, record that in our `sampling_set`. Otherwise,
485            // record as much as we do know in our own `sampling` table, for our
486            // callers to sort out.
487            match (image_storage, sampler_storage) {
488                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
489                    self.sampling_set.insert(SamplingKey { image, sampler });
490                }
491                (image, sampler) => {
492                    self.sampling.insert(Sampling { image, sampler });
493                }
494            }
495        }
496
497        // Inherit global use from our callees.
498        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
499            *mine |= *other;
500        }
501
502        Ok(FunctionUniformity {
503            result: callee.uniformity.clone(),
504            exit: if callee.may_kill {
505                ExitFlags::MAY_KILL
506            } else {
507                ExitFlags::empty()
508            },
509        })
510    }
511
512    /// Compute the [`ExpressionInfo`] for `handle`.
513    ///
514    /// Replace the dummy entry in [`self.expressions`] for `handle`
515    /// with a real `ExpressionInfo` value describing that expression.
516    ///
517    /// This function is called as part of a forward sweep through the
518    /// arena, so we can assume that all earlier expressions in the
519    /// arena already have valid info. Since expressions only depend
520    /// on earlier expressions, this includes all our subexpressions.
521    ///
522    /// Adjust the reference counts on all expressions we use.
523    ///
524    /// Also populate the [`sampling_set`], [`sampling`] and
525    /// [`global_uses`] fields of `self`.
526    ///
527    /// [`self.expressions`]: FunctionInfo::expressions
528    /// [`sampling_set`]: FunctionInfo::sampling_set
529    /// [`sampling`]: FunctionInfo::sampling
530    /// [`global_uses`]: FunctionInfo::global_uses
531    #[allow(clippy::or_fun_call)]
532    fn process_expression(
533        &mut self,
534        handle: Handle<crate::Expression>,
535        expression_arena: &Arena<crate::Expression>,
536        other_functions: &[FunctionInfo],
537        resolve_context: &ResolveContext,
538        capabilities: super::Capabilities,
539    ) -> Result<(), ExpressionError> {
540        use crate::{Expression as E, SampleLevel as Sl};
541
542        let expression = &expression_arena[handle];
543        let mut assignable_global = None;
544        let uniformity = match *expression {
545            E::Access { base, index } => {
546                let base_ty = self[base].ty.inner_with(resolve_context.types);
547
548                // build up the caps needed if this is indexed non-uniformly
549                let mut needed_caps = super::Capabilities::empty();
550                let is_binding_array = match *base_ty {
551                    crate::TypeInner::BindingArray {
552                        base: array_element_ty_handle,
553                        ..
554                    } => {
555                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
556                        let array_element_ty =
557                            &resolve_context.types[array_element_ty_handle].inner;
558
559                        needed_caps |= match *array_element_ty {
560                            // If we're an image, use the appropriate capability.
561                            crate::TypeInner::Image { class, .. } => match class {
562                                crate::ImageClass::Storage { .. } => {
563                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
564                                }
565                                _ => {
566                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
567                                }
568                            },
569                            crate::TypeInner::Sampler { .. } => {
570                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
571                            }
572                            // If we're anything but an image or sampler, assume we're a buffer and use the address space.
573                            _ => {
574                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
575                                    let global = &resolve_context.global_vars[global_handle];
576                                    match global.space {
577                                        crate::AddressSpace::Uniform => {
578                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
579                                        }
580                                        crate::AddressSpace::Storage { .. } => {
581                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
582                                        }
583                                        _ => unreachable!(),
584                                    }
585                                } else {
586                                    unreachable!()
587                                }
588                            }
589                        };
590
591                        true
592                    }
593                    _ => false,
594                };
595
596                if self[index].uniformity.non_uniform_result.is_some()
597                    && !capabilities.contains(needed_caps)
598                    && is_binding_array
599                {
600                    return Err(ExpressionError::MissingCapabilities(needed_caps));
601                }
602
603                Uniformity {
604                    non_uniform_result: self
605                        .add_assignable_ref(base, &mut assignable_global)
606                        .or(self.add_ref(index)),
607                    requirements: UniformityRequirements::empty(),
608                }
609            }
610            E::AccessIndex { base, .. } => Uniformity {
611                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
612                requirements: UniformityRequirements::empty(),
613            },
614            // always uniform
615            E::Splat { size: _, value } => Uniformity {
616                non_uniform_result: self.add_ref(value),
617                requirements: UniformityRequirements::empty(),
618            },
619            E::Swizzle { vector, .. } => Uniformity {
620                non_uniform_result: self.add_ref(vector),
621                requirements: UniformityRequirements::empty(),
622            },
623            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
624            E::Compose { ref components, .. } => {
625                let non_uniform_result = components
626                    .iter()
627                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
628                Uniformity {
629                    non_uniform_result,
630                    requirements: UniformityRequirements::empty(),
631                }
632            }
633            // depends on the builtin
634            E::FunctionArgument(index) => {
635                let arg = &resolve_context.arguments[index as usize];
636                let uniform = match arg.binding {
637                    Some(crate::Binding::BuiltIn(
638                        // per-work-group built-ins are uniform
639                        crate::BuiltIn::WorkGroupId
640                        | crate::BuiltIn::WorkGroupSize
641                        | crate::BuiltIn::NumWorkGroups,
642                    )) => true,
643                    _ => false,
644                };
645                Uniformity {
646                    non_uniform_result: if uniform { None } else { Some(handle) },
647                    requirements: UniformityRequirements::empty(),
648                }
649            }
650            // depends on the address space
651            E::GlobalVariable(gh) => {
652                use crate::AddressSpace as As;
653                assignable_global = Some(gh);
654                let var = &resolve_context.global_vars[gh];
655                let uniform = match var.space {
656                    // local data is non-uniform
657                    As::Function | As::Private => false,
658                    // workgroup memory is exclusively accessed by the group
659                    // task payload memory is very similar to workgroup memory
660                    As::WorkGroup | As::TaskPayload => true,
661                    // uniform data
662                    As::Uniform | As::Immediate => true,
663                    // storage data is only uniform when read-only
664                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
665                    As::Handle => false,
666                };
667                Uniformity {
668                    non_uniform_result: if uniform { None } else { Some(handle) },
669                    requirements: UniformityRequirements::empty(),
670                }
671            }
672            E::LocalVariable(_) => Uniformity {
673                non_uniform_result: Some(handle),
674                requirements: UniformityRequirements::empty(),
675            },
676            E::Load { pointer } => Uniformity {
677                non_uniform_result: self.add_ref(pointer),
678                requirements: UniformityRequirements::empty(),
679            },
680            E::ImageSample {
681                image,
682                sampler,
683                gather: _,
684                coordinate,
685                array_index,
686                offset,
687                level,
688                depth_ref,
689                clamp_to_edge: _,
690            } => {
691                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
692                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
693
694                match (image_storage, sampler_storage) {
695                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
696                        self.sampling_set.insert(SamplingKey { image, sampler });
697                    }
698                    _ => {
699                        self.sampling.insert(Sampling {
700                            image: image_storage,
701                            sampler: sampler_storage,
702                        });
703                    }
704                }
705
706                // "nur" == "Non-Uniform Result"
707                let array_nur = array_index.and_then(|h| self.add_ref(h));
708                let level_nur = match level {
709                    Sl::Auto | Sl::Zero => None,
710                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
711                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
712                };
713                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
714                let offset_nur = offset.and_then(|h| self.add_ref(h));
715                Uniformity {
716                    non_uniform_result: self
717                        .add_ref(image)
718                        .or(self.add_ref(sampler))
719                        .or(self.add_ref(coordinate))
720                        .or(array_nur)
721                        .or(level_nur)
722                        .or(dref_nur)
723                        .or(offset_nur),
724                    requirements: if level.implicit_derivatives() {
725                        UniformityRequirements::IMPLICIT_LEVEL
726                    } else {
727                        UniformityRequirements::empty()
728                    },
729                }
730            }
731            E::ImageLoad {
732                image,
733                coordinate,
734                array_index,
735                sample,
736                level,
737            } => {
738                let array_nur = array_index.and_then(|h| self.add_ref(h));
739                let sample_nur = sample.and_then(|h| self.add_ref(h));
740                let level_nur = level.and_then(|h| self.add_ref(h));
741                Uniformity {
742                    non_uniform_result: self
743                        .add_ref(image)
744                        .or(self.add_ref(coordinate))
745                        .or(array_nur)
746                        .or(sample_nur)
747                        .or(level_nur),
748                    requirements: UniformityRequirements::empty(),
749                }
750            }
751            E::ImageQuery { image, query } => {
752                let query_nur = match query {
753                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
754                    _ => None,
755                };
756                Uniformity {
757                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
758                    requirements: UniformityRequirements::empty(),
759                }
760            }
761            E::Unary { expr, .. } => Uniformity {
762                non_uniform_result: self.add_ref(expr),
763                requirements: UniformityRequirements::empty(),
764            },
765            E::Binary { left, right, .. } => Uniformity {
766                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
767                requirements: UniformityRequirements::empty(),
768            },
769            E::Select {
770                condition,
771                accept,
772                reject,
773            } => Uniformity {
774                non_uniform_result: self
775                    .add_ref(condition)
776                    .or(self.add_ref(accept))
777                    .or(self.add_ref(reject)),
778                requirements: UniformityRequirements::empty(),
779            },
780            // explicit derivatives require uniform
781            E::Derivative { expr, .. } => Uniformity {
782                //Note: taking a derivative of a uniform doesn't make it non-uniform
783                non_uniform_result: self.add_ref(expr),
784                requirements: UniformityRequirements::DERIVATIVE,
785            },
786            E::Relational { argument, .. } => Uniformity {
787                non_uniform_result: self.add_ref(argument),
788                requirements: UniformityRequirements::empty(),
789            },
790            E::Math {
791                fun: _,
792                arg,
793                arg1,
794                arg2,
795                arg3,
796            } => {
797                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
798                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
799                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
800                Uniformity {
801                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
802                    requirements: UniformityRequirements::empty(),
803                }
804            }
805            E::As { expr, .. } => Uniformity {
806                non_uniform_result: self.add_ref(expr),
807                requirements: UniformityRequirements::empty(),
808            },
809            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
810            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
811                non_uniform_result: Some(handle),
812                requirements: UniformityRequirements::empty(),
813            },
814            E::WorkGroupUniformLoadResult { .. } => Uniformity {
815                // The result of WorkGroupUniformLoad is always uniform by definition
816                non_uniform_result: None,
817                // The call is what cares about uniformity, not the expression
818                // This expression is never emitted, so this requirement should never be used anyway?
819                requirements: UniformityRequirements::empty(),
820            },
821            E::ArrayLength(expr) => Uniformity {
822                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
823                requirements: UniformityRequirements::empty(),
824            },
825            E::RayQueryGetIntersection {
826                query,
827                committed: _,
828            } => Uniformity {
829                non_uniform_result: self.add_ref(query),
830                requirements: UniformityRequirements::empty(),
831            },
832            E::SubgroupBallotResult => Uniformity {
833                non_uniform_result: Some(handle),
834                requirements: UniformityRequirements::empty(),
835            },
836            E::SubgroupOperationResult { .. } => Uniformity {
837                non_uniform_result: Some(handle),
838                requirements: UniformityRequirements::empty(),
839            },
840            E::RayQueryVertexPositions {
841                query,
842                committed: _,
843            } => Uniformity {
844                non_uniform_result: self.add_ref(query),
845                requirements: UniformityRequirements::empty(),
846            },
847            E::CooperativeLoad { ref data, .. } => Uniformity {
848                non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)),
849                requirements: UniformityRequirements::COOP_OPS,
850            },
851            E::CooperativeMultiplyAdd { a, b, c } => Uniformity {
852                non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))),
853                requirements: UniformityRequirements::COOP_OPS,
854            },
855        };
856
857        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
858        self.expressions[handle.index()] = ExpressionInfo {
859            uniformity,
860            ref_count: 0,
861            assignable_global,
862            ty,
863        };
864        Ok(())
865    }
866
867    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
868    /// Returns the uniformity characteristics at the *function* level, i.e.
869    /// whether or not the function requires to be called in uniform control flow,
870    /// and whether the produced result is not disrupting the control flow.
871    ///
872    /// The parent control flow is uniform if `disruptor.is_none()`.
873    ///
874    /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
875    /// require uniformity, but the current flow is non-uniform.
876    #[allow(clippy::or_fun_call)]
877    fn process_block(
878        &mut self,
879        statements: &crate::Block,
880        other_functions: &[FunctionInfo],
881        mut disruptor: Option<UniformityDisruptor>,
882        expression_arena: &Arena<crate::Expression>,
883        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
884    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
885        use crate::Statement as S;
886
887        let mut combined_uniformity = FunctionUniformity::new();
888        for statement in statements {
889            let uniformity = match *statement {
890                S::Emit(ref range) => {
891                    let mut requirements = UniformityRequirements::empty();
892                    for expr in range.clone() {
893                        let req = self.expressions[expr.index()].uniformity.requirements;
894                        if self
895                            .flags
896                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
897                            && !req.is_empty()
898                        {
899                            if let Some(cause) = disruptor {
900                                let severity = DiagnosticFilterNode::search(
901                                    self.diagnostic_filter_leaf,
902                                    diagnostic_filter_arena,
903                                    StandardFilterableTriggeringRule::DerivativeUniformity,
904                                );
905                                severity.report_diag(
906                                    FunctionError::NonUniformControlFlow(req, expr, cause)
907                                        .with_span_handle(expr, expression_arena),
908                                    // TODO: Yes, this isn't contextualized with source, because
909                                    // the user is supposed to render what would normally be an
910                                    // error here. Once we actually support warning-level
911                                    // diagnostic items, then we won't need this non-compliant hack:
912                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
913                                    |e, level| log::log!(level, "{e}"),
914                                )?;
915                            }
916                        }
917                        requirements |= req;
918                    }
919                    FunctionUniformity {
920                        result: Uniformity {
921                            non_uniform_result: None,
922                            requirements,
923                        },
924                        exit: ExitFlags::empty(),
925                    }
926                }
927                S::Break | S::Continue => FunctionUniformity::new(),
928                S::Kill => FunctionUniformity {
929                    result: Uniformity::new(),
930                    exit: if disruptor.is_some() {
931                        ExitFlags::MAY_KILL
932                    } else {
933                        ExitFlags::empty()
934                    },
935                },
936                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
937                    result: Uniformity {
938                        non_uniform_result: None,
939                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
940                    },
941                    exit: ExitFlags::empty(),
942                },
943                S::WorkGroupUniformLoad { pointer, .. } => {
944                    let _condition_nur = self.add_ref(pointer);
945
946                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
947                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
948                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
949                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.
950
951                    /*
952                    if self
953                        .flags
954                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
955                    {
956                        let condition_nur = self.add_ref(pointer);
957                        let this_disruptor =
958                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
959                        if let Some(cause) = this_disruptor {
960                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
961                                .with_span_static(*span, "WorkGroupUniformLoad"));
962                        }
963                    } */
964                    FunctionUniformity {
965                        result: Uniformity {
966                            non_uniform_result: None,
967                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
968                        },
969                        exit: ExitFlags::empty(),
970                    }
971                }
972                S::Block(ref b) => self.process_block(
973                    b,
974                    other_functions,
975                    disruptor,
976                    expression_arena,
977                    diagnostic_filter_arena,
978                )?,
979                S::If {
980                    condition,
981                    ref accept,
982                    ref reject,
983                } => {
984                    let condition_nur = self.add_ref(condition);
985                    let branch_disruptor =
986                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
987                    let accept_uniformity = self.process_block(
988                        accept,
989                        other_functions,
990                        branch_disruptor,
991                        expression_arena,
992                        diagnostic_filter_arena,
993                    )?;
994                    let reject_uniformity = self.process_block(
995                        reject,
996                        other_functions,
997                        branch_disruptor,
998                        expression_arena,
999                        diagnostic_filter_arena,
1000                    )?;
1001                    accept_uniformity | reject_uniformity
1002                }
1003                S::Switch {
1004                    selector,
1005                    ref cases,
1006                } => {
1007                    let selector_nur = self.add_ref(selector);
1008                    let branch_disruptor =
1009                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
1010                    let mut uniformity = FunctionUniformity::new();
1011                    let mut case_disruptor = branch_disruptor;
1012                    for case in cases.iter() {
1013                        let case_uniformity = self.process_block(
1014                            &case.body,
1015                            other_functions,
1016                            case_disruptor,
1017                            expression_arena,
1018                            diagnostic_filter_arena,
1019                        )?;
1020                        case_disruptor = if case.fall_through {
1021                            case_disruptor.or(case_uniformity.exit_disruptor())
1022                        } else {
1023                            branch_disruptor
1024                        };
1025                        uniformity = uniformity | case_uniformity;
1026                    }
1027                    uniformity
1028                }
1029                S::Loop {
1030                    ref body,
1031                    ref continuing,
1032                    break_if,
1033                } => {
1034                    let body_uniformity = self.process_block(
1035                        body,
1036                        other_functions,
1037                        disruptor,
1038                        expression_arena,
1039                        diagnostic_filter_arena,
1040                    )?;
1041                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
1042                    let continuing_uniformity = self.process_block(
1043                        continuing,
1044                        other_functions,
1045                        continuing_disruptor,
1046                        expression_arena,
1047                        diagnostic_filter_arena,
1048                    )?;
1049                    if let Some(expr) = break_if {
1050                        let _ = self.add_ref(expr);
1051                    }
1052                    body_uniformity | continuing_uniformity
1053                }
1054                S::Return { value } => FunctionUniformity {
1055                    result: Uniformity {
1056                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
1057                        requirements: UniformityRequirements::empty(),
1058                    },
1059                    exit: if disruptor.is_some() {
1060                        ExitFlags::MAY_RETURN
1061                    } else {
1062                        ExitFlags::empty()
1063                    },
1064                },
1065                // Here and below, the used expressions are already emitted,
1066                // and their results do not affect the function return value,
1067                // so we can ignore their non-uniformity.
1068                S::Store { pointer, value } => {
1069                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
1070                    let _ = self.add_ref(value);
1071                    FunctionUniformity::new()
1072                }
1073                S::ImageStore {
1074                    image,
1075                    coordinate,
1076                    array_index,
1077                    value,
1078                } => {
1079                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
1080                    if let Some(expr) = array_index {
1081                        let _ = self.add_ref(expr);
1082                    }
1083                    let _ = self.add_ref(coordinate);
1084                    let _ = self.add_ref(value);
1085                    FunctionUniformity::new()
1086                }
1087                S::Call {
1088                    function,
1089                    ref arguments,
1090                    result: _,
1091                } => {
1092                    for &argument in arguments {
1093                        let _ = self.add_ref(argument);
1094                    }
1095                    let info = &other_functions[function.index()];
1096                    //Note: the result is validated by the Validator, not here
1097                    self.process_call(info, arguments, expression_arena)?
1098                }
1099                S::Atomic {
1100                    pointer,
1101                    ref fun,
1102                    value,
1103                    result: _,
1104                } => {
1105                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
1106                    let _ = self.add_ref(value);
1107                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
1108                        let _ = self.add_ref(cmp);
1109                    }
1110                    FunctionUniformity::new()
1111                }
1112                S::ImageAtomic {
1113                    image,
1114                    coordinate,
1115                    array_index,
1116                    fun: _,
1117                    value,
1118                } => {
1119                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
1120                    let _ = self.add_ref(coordinate);
1121                    if let Some(expr) = array_index {
1122                        let _ = self.add_ref(expr);
1123                    }
1124                    let _ = self.add_ref(value);
1125                    FunctionUniformity::new()
1126                }
1127                S::RayQuery { query, ref fun } => {
1128                    let _ = self.add_ref(query);
1129                    match *fun {
1130                        crate::RayQueryFunction::Initialize {
1131                            acceleration_structure,
1132                            descriptor,
1133                        } => {
1134                            let _ = self.add_ref(acceleration_structure);
1135                            let _ = self.add_ref(descriptor);
1136                        }
1137                        crate::RayQueryFunction::Proceed { result: _ } => {}
1138                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1139                            let _ = self.add_ref(hit_t);
1140                        }
1141                        crate::RayQueryFunction::ConfirmIntersection => {}
1142                        crate::RayQueryFunction::Terminate => {}
1143                    }
1144                    FunctionUniformity::new()
1145                }
1146                S::SubgroupBallot {
1147                    result: _,
1148                    predicate,
1149                } => {
1150                    if let Some(predicate) = predicate {
1151                        let _ = self.add_ref(predicate);
1152                    }
1153                    FunctionUniformity::new()
1154                }
1155                S::SubgroupCollectiveOperation {
1156                    op: _,
1157                    collective_op: _,
1158                    argument,
1159                    result: _,
1160                } => {
1161                    let _ = self.add_ref(argument);
1162                    FunctionUniformity::new()
1163                }
1164                S::SubgroupGather {
1165                    mode,
1166                    argument,
1167                    result: _,
1168                } => {
1169                    let _ = self.add_ref(argument);
1170                    match mode {
1171                        crate::GatherMode::BroadcastFirst => {}
1172                        crate::GatherMode::Broadcast(index)
1173                        | crate::GatherMode::Shuffle(index)
1174                        | crate::GatherMode::ShuffleDown(index)
1175                        | crate::GatherMode::ShuffleUp(index)
1176                        | crate::GatherMode::ShuffleXor(index)
1177                        | crate::GatherMode::QuadBroadcast(index) => {
1178                            let _ = self.add_ref(index);
1179                        }
1180                        crate::GatherMode::QuadSwap(_) => {}
1181                    }
1182                    FunctionUniformity::new()
1183                }
1184                S::CooperativeStore { target, ref data } => FunctionUniformity {
1185                    result: Uniformity {
1186                        non_uniform_result: self
1187                            .add_ref(target)
1188                            .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE))
1189                            .or(self.add_ref(data.stride)),
1190                        requirements: UniformityRequirements::COOP_OPS,
1191                    },
1192                    exit: ExitFlags::empty(),
1193                },
1194            };
1195
1196            disruptor = disruptor.or(uniformity.exit_disruptor());
1197            combined_uniformity = combined_uniformity | uniformity;
1198        }
1199        Ok(combined_uniformity)
1200    }
1201}
1202
1203impl ModuleInfo {
1204    /// Populates `self.const_expression_types`
1205    pub(super) fn process_const_expression(
1206        &mut self,
1207        handle: Handle<crate::Expression>,
1208        resolve_context: &ResolveContext,
1209        gctx: crate::proc::GlobalCtx,
1210    ) -> Result<(), super::ConstExpressionError> {
1211        self.const_expression_types[handle.index()] =
1212            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1213        Ok(())
1214    }
1215
1216    /// Builds the `FunctionInfo` based on the function, and validates the
1217    /// uniform control flow if required by the expressions of this function.
1218    pub(super) fn process_function(
1219        &self,
1220        fun: &crate::Function,
1221        module: &crate::Module,
1222        flags: ValidationFlags,
1223        capabilities: super::Capabilities,
1224    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1225        let mut info = FunctionInfo {
1226            flags,
1227            available_stages: ShaderStages::all(),
1228            uniformity: Uniformity::new(),
1229            may_kill: false,
1230            sampling_set: crate::FastHashSet::default(),
1231            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1232            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1233            sampling: crate::FastHashSet::default(),
1234            dual_source_blending: false,
1235            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1236        };
1237        let resolve_context =
1238            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1239
1240        for (handle, _) in fun.expressions.iter() {
1241            if let Err(source) = info.process_expression(
1242                handle,
1243                &fun.expressions,
1244                &self.functions,
1245                &resolve_context,
1246                capabilities,
1247            ) {
1248                return Err(FunctionError::Expression { handle, source }
1249                    .with_span_handle(handle, &fun.expressions));
1250            }
1251        }
1252
1253        for (_, expr) in fun.local_variables.iter() {
1254            if let Some(init) = expr.init {
1255                let _ = info.add_ref(init);
1256            }
1257        }
1258
1259        let uniformity = info.process_block(
1260            &fun.body,
1261            &self.functions,
1262            None,
1263            &fun.expressions,
1264            &module.diagnostic_filters,
1265        )?;
1266        info.uniformity = uniformity.result;
1267        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1268
1269        // If there are any globals referenced directly by a named expression,
1270        // ensure they are marked as used even if they are not referenced
1271        // anywhere else. An important case where this matters is phony
1272        // assignments used to include a global in the shader's resource
1273        // interface. https://www.w3.org/TR/WGSL/#phony-assignment-section
1274        for &handle in fun.named_expressions.keys() {
1275            if let Some(global) = info[handle].assignable_global {
1276                if info.global_uses[global.index()].is_empty() {
1277                    info.global_uses[global.index()] = GlobalUse::QUERY;
1278                }
1279            }
1280        }
1281
1282        Ok(info)
1283    }
1284
1285    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1286        &self.entry_points[index]
1287    }
1288}
1289
1290#[test]
1291fn uniform_control_flow() {
1292    use crate::{Expression as E, Statement as S};
1293
1294    let mut type_arena = crate::UniqueArena::new();
1295    let ty = type_arena.insert(
1296        crate::Type {
1297            name: None,
1298            inner: crate::TypeInner::Vector {
1299                size: crate::VectorSize::Bi,
1300                scalar: crate::Scalar::F32,
1301            },
1302        },
1303        Default::default(),
1304    );
1305    let mut global_var_arena = Arena::new();
1306    let non_uniform_global = global_var_arena.append(
1307        crate::GlobalVariable {
1308            name: None,
1309            init: None,
1310            ty,
1311            space: crate::AddressSpace::Handle,
1312            binding: None,
1313        },
1314        Default::default(),
1315    );
1316    let uniform_global = global_var_arena.append(
1317        crate::GlobalVariable {
1318            name: None,
1319            init: None,
1320            ty,
1321            binding: None,
1322            space: crate::AddressSpace::Uniform,
1323        },
1324        Default::default(),
1325    );
1326
1327    let mut expressions = Arena::new();
1328    // checks the uniform control flow
1329    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1330    // checks the non-uniform control flow
1331    let derivative_expr = expressions.append(
1332        E::Derivative {
1333            axis: crate::DerivativeAxis::X,
1334            ctrl: crate::DerivativeControl::None,
1335            expr: constant_expr,
1336        },
1337        Default::default(),
1338    );
1339    let emit_range_constant_derivative = expressions.range_from(0);
1340    let non_uniform_global_expr =
1341        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1342    let uniform_global_expr =
1343        expressions.append(E::GlobalVariable(uniform_global), Default::default());
1344    let emit_range_globals = expressions.range_from(2);
1345
1346    // checks the QUERY flag
1347    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1348    // checks the transitive WRITE flag
1349    let access_expr = expressions.append(
1350        E::AccessIndex {
1351            base: non_uniform_global_expr,
1352            index: 1,
1353        },
1354        Default::default(),
1355    );
1356    let emit_range_query_access_globals = expressions.range_from(2);
1357
1358    let mut info = FunctionInfo {
1359        flags: ValidationFlags::all(),
1360        available_stages: ShaderStages::all(),
1361        uniformity: Uniformity::new(),
1362        may_kill: false,
1363        sampling_set: crate::FastHashSet::default(),
1364        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1365        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1366        sampling: crate::FastHashSet::default(),
1367        dual_source_blending: false,
1368        diagnostic_filter_leaf: None,
1369    };
1370    let resolve_context = ResolveContext {
1371        constants: &Arena::new(),
1372        overrides: &Arena::new(),
1373        types: &type_arena,
1374        special_types: &crate::SpecialTypes::default(),
1375        global_vars: &global_var_arena,
1376        local_vars: &Arena::new(),
1377        functions: &Arena::new(),
1378        arguments: &[],
1379    };
1380    for (handle, _) in expressions.iter() {
1381        info.process_expression(
1382            handle,
1383            &expressions,
1384            &[],
1385            &resolve_context,
1386            super::Capabilities::empty(),
1387        )
1388        .unwrap();
1389    }
1390    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1391    assert_eq!(info[uniform_global_expr].ref_count, 1);
1392    assert_eq!(info[query_expr].ref_count, 0);
1393    assert_eq!(info[access_expr].ref_count, 0);
1394    assert_eq!(info[non_uniform_global], GlobalUse::empty());
1395    assert_eq!(info[uniform_global], GlobalUse::QUERY);
1396
1397    let stmt_emit1 = S::Emit(emit_range_globals.clone());
1398    let stmt_if_uniform = S::If {
1399        condition: uniform_global_expr,
1400        accept: crate::Block::new(),
1401        reject: vec![
1402            S::Emit(emit_range_constant_derivative.clone()),
1403            S::Store {
1404                pointer: constant_expr,
1405                value: derivative_expr,
1406            },
1407        ]
1408        .into(),
1409    };
1410    assert_eq!(
1411        info.process_block(
1412            &vec![stmt_emit1, stmt_if_uniform].into(),
1413            &[],
1414            None,
1415            &expressions,
1416            &Arena::new(),
1417        ),
1418        Ok(FunctionUniformity {
1419            result: Uniformity {
1420                non_uniform_result: None,
1421                requirements: UniformityRequirements::DERIVATIVE,
1422            },
1423            exit: ExitFlags::empty(),
1424        }),
1425    );
1426    assert_eq!(info[constant_expr].ref_count, 2);
1427    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1428
1429    let stmt_emit2 = S::Emit(emit_range_globals.clone());
1430    let stmt_if_non_uniform = S::If {
1431        condition: non_uniform_global_expr,
1432        accept: vec![
1433            S::Emit(emit_range_constant_derivative),
1434            S::Store {
1435                pointer: constant_expr,
1436                value: derivative_expr,
1437            },
1438        ]
1439        .into(),
1440        reject: crate::Block::new(),
1441    };
1442    {
1443        let block_info = info.process_block(
1444            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1445            &[],
1446            None,
1447            &expressions,
1448            &Arena::new(),
1449        );
1450        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1451            assert_eq!(info[derivative_expr].ref_count, 2);
1452        } else {
1453            assert_eq!(
1454                block_info,
1455                Err(FunctionError::NonUniformControlFlow(
1456                    UniformityRequirements::DERIVATIVE,
1457                    derivative_expr,
1458                    UniformityDisruptor::Expression(non_uniform_global_expr)
1459                )
1460                .with_span()),
1461            );
1462            assert_eq!(info[derivative_expr].ref_count, 1);
1463
1464            // Test that the same thing passes when we disable the `derivative_uniformity`
1465            let mut diagnostic_filters = Arena::new();
1466            let diagnostic_filter_leaf = diagnostic_filters.append(
1467                DiagnosticFilterNode {
1468                    inner: crate::diagnostic_filter::DiagnosticFilter {
1469                        new_severity: crate::diagnostic_filter::Severity::Off,
1470                        triggering_rule:
1471                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1472                                StandardFilterableTriggeringRule::DerivativeUniformity,
1473                            ),
1474                    },
1475                    parent: None,
1476                },
1477                crate::Span::default(),
1478            );
1479            let mut info = FunctionInfo {
1480                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1481                ..info.clone()
1482            };
1483
1484            let block_info = info.process_block(
1485                &vec![stmt_emit2, stmt_if_non_uniform].into(),
1486                &[],
1487                None,
1488                &expressions,
1489                &diagnostic_filters,
1490            );
1491            assert_eq!(
1492                block_info,
1493                Ok(FunctionUniformity {
1494                    result: Uniformity {
1495                        non_uniform_result: None,
1496                        requirements: UniformityRequirements::DERIVATIVE,
1497                    },
1498                    exit: ExitFlags::empty()
1499                }),
1500            );
1501            assert_eq!(info[derivative_expr].ref_count, 2);
1502        }
1503    }
1504    assert_eq!(info[non_uniform_global], GlobalUse::READ);
1505
1506    let stmt_emit3 = S::Emit(emit_range_globals);
1507    let stmt_return_non_uniform = S::Return {
1508        value: Some(non_uniform_global_expr),
1509    };
1510    assert_eq!(
1511        info.process_block(
1512            &vec![stmt_emit3, stmt_return_non_uniform].into(),
1513            &[],
1514            Some(UniformityDisruptor::Return),
1515            &expressions,
1516            &Arena::new(),
1517        ),
1518        Ok(FunctionUniformity {
1519            result: Uniformity {
1520                non_uniform_result: Some(non_uniform_global_expr),
1521                requirements: UniformityRequirements::empty(),
1522            },
1523            exit: ExitFlags::MAY_RETURN,
1524        }),
1525    );
1526    assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1527
1528    // Check that uniformity requirements reach through a pointer
1529    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1530    let stmt_assign = S::Store {
1531        pointer: access_expr,
1532        value: query_expr,
1533    };
1534    let stmt_return_pointer = S::Return {
1535        value: Some(access_expr),
1536    };
1537    let stmt_kill = S::Kill;
1538    assert_eq!(
1539        info.process_block(
1540            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1541            &[],
1542            Some(UniformityDisruptor::Discard),
1543            &expressions,
1544            &Arena::new(),
1545        ),
1546        Ok(FunctionUniformity {
1547            result: Uniformity {
1548                non_uniform_result: Some(non_uniform_global_expr),
1549                requirements: UniformityRequirements::empty(),
1550            },
1551            exit: ExitFlags::all(),
1552        }),
1553    );
1554    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1555}