naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15    arena::{Arena, Handle},
16    proc::{ResolveContext, TypeResolution},
17};
18
/// Handle of the expression a non-uniform result originates from, or `None`
/// when the result is uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// While this is `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// [`UniformityRequirements`] are defined with no bits set.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
23bitflags::bitflags! {
24    /// Kinds of expressions that require uniform control flow.
25    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
26    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
27    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
28    pub struct UniformityRequirements: u8 {
29        const WORK_GROUP_BARRIER = 0x1;
30        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
31        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
32    }
33}
34
35/// Uniform control flow characteristics.
36#[derive(Clone, Debug)]
37#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
38#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
39#[cfg_attr(test, derive(PartialEq))]
40pub struct Uniformity {
41    /// A child expression with non-uniform result.
42    ///
43    /// This means, when the relevant invocations are scheduled on a compute unit,
44    /// they have to use vector registers to store an individual value
45    /// per invocation.
46    ///
47    /// Whenever the control flow is conditioned on such value,
48    /// the hardware needs to keep track of the mask of invocations,
49    /// and process all branches of the control flow.
50    ///
51    /// Any operations that depend on non-uniform results also produce non-uniform.
52    pub non_uniform_result: NonUniformResult,
53    /// If this expression requires uniform control flow, store the reason here.
54    pub requirements: UniformityRequirements,
55}
56
57impl Uniformity {
58    const fn new() -> Self {
59        Uniformity {
60            non_uniform_result: None,
61            requirements: UniformityRequirements::empty(),
62        }
63    }
64}
65
66bitflags::bitflags! {
67    #[derive(Clone, Copy, Debug, PartialEq)]
68    struct ExitFlags: u8 {
69        /// Control flow may return from the function, which makes all the
70        /// subsequent statements within the current function (only!)
71        /// to be executed in a non-uniform control flow.
72        const MAY_RETURN = 0x1;
73        /// Control flow may be killed. Anything after [`Statement::Kill`] is
74        /// considered inside non-uniform context.
75        ///
76        /// [`Statement::Kill`]: crate::Statement::Kill
77        const MAY_KILL = 0x2;
78    }
79}
80
81/// Uniformity characteristics of a function.
82#[cfg_attr(test, derive(Debug, PartialEq))]
83struct FunctionUniformity {
84    result: Uniformity,
85    exit: ExitFlags,
86}
87
88impl ops::BitOr for FunctionUniformity {
89    type Output = Self;
90    fn bitor(self, other: Self) -> Self {
91        FunctionUniformity {
92            result: Uniformity {
93                non_uniform_result: self
94                    .result
95                    .non_uniform_result
96                    .or(other.result.non_uniform_result),
97                requirements: self.result.requirements | other.result.requirements,
98            },
99            exit: self.exit | other.exit,
100        }
101    }
102}
103
104impl FunctionUniformity {
105    const fn new() -> Self {
106        FunctionUniformity {
107            result: Uniformity::new(),
108            exit: ExitFlags::empty(),
109        }
110    }
111
112    /// Returns a disruptor based on the stored exit flags, if any.
113    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114        if self.exit.contains(ExitFlags::MAY_RETURN) {
115            Some(UniformityDisruptor::Return)
116        } else if self.exit.contains(ExitFlags::MAY_KILL) {
117            Some(UniformityDisruptor::Discard)
118        } else {
119            None
120        }
121    }
122}
123
124bitflags::bitflags! {
125    /// Indicates how a global variable is used.
126    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
127    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
128    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
129    pub struct GlobalUse: u8 {
130        /// Data will be read from the variable.
131        const READ = 0x1;
132        /// Data will be written to the variable.
133        const WRITE = 0x2;
134        /// The information about the data is queried.
135        const QUERY = 0x4;
136        /// Atomic operations will be performed on the variable.
137        const ATOMIC = 0x8;
138    }
139}
140
/// A pair of global variables, an image and a sampler, that are used
/// together in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// The global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
149#[derive(Clone, Debug)]
150#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
151#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
152/// Information about an expression in a function body.
153pub struct ExpressionInfo {
154    /// Whether this expression is uniform, and why.
155    ///
156    /// If this expression's value is not uniform, this is the handle
157    /// of the expression from which this one's non-uniformity
158    /// originates. Otherwise, this is `None`.
159    pub uniformity: Uniformity,
160
161    /// The number of direct references to this expression in statements and
162    /// other expressions.
163    ///
164    /// This is a _local_ reference count only, it may be non-zero for
165    /// expressions that are ultimately unused.
166    pub ref_count: usize,
167
168    /// The global variable into which this expression produces a pointer.
169    ///
170    /// This is `None` unless this expression is either a
171    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
172    /// ultimately refers to some part of a global.
173    ///
174    /// [`Load`] expressions applied to pointer-typed arguments could
175    /// refer to globals, but we leave this as `None` for them.
176    ///
177    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
178    /// [`Access`]: crate::Expression::Access
179    /// [`AccessIndex`]: crate::Expression::AccessIndex
180    /// [`Load`]: crate::Expression::Load
181    assignable_global: Option<Handle<crate::GlobalVariable>>,
182
183    /// The type of this expression.
184    pub ty: TypeResolution,
185}
186
187impl ExpressionInfo {
188    const fn new() -> Self {
189        ExpressionInfo {
190            uniformity: Uniformity::new(),
191            ref_count: 0,
192            assignable_global: None,
193            // this doesn't matter at this point, will be overwritten
194            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
195                kind: crate::ScalarKind::Bool,
196                width: 0,
197            })),
198        }
199    }
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
203#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
204#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
205enum GlobalOrArgument {
206    Global(Handle<crate::GlobalVariable>),
207    Argument(u32),
208}
209
210impl GlobalOrArgument {
211    fn from_expression(
212        expression_arena: &Arena<crate::Expression>,
213        expression: Handle<crate::Expression>,
214    ) -> Result<GlobalOrArgument, ExpressionError> {
215        Ok(match expression_arena[expression] {
216            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
217            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
218            crate::Expression::Access { base, .. }
219            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
220                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
221                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
222            },
223            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
224        })
225    }
226}
227
/// An (image, sampler) pair used together in a sampling operation, where each
/// participant is either a global variable or a function argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Where the image comes from.
    image: GlobalOrArgument,
    /// Where the sampler comes from.
    sampler: GlobalOrArgument,
}
235
/// Analysis results for a single function: its uniformity, its use of
/// globals, per-expression information, and the texture/sampler pairs it
/// samples with.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    // NOTE(review): nothing in the visible part of this file reads this
    // field, hence the `dead_code` allowance — confirm whether it is still
    // needed.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    pub may_kill: bool,

    /// All pairs of (texture, sampler) globals that may be used together in
    /// sampling operations by this function and its callees. This includes
    /// pairings that arise when this function passes textures and samplers as
    /// arguments to its callees.
    ///
    /// This table does not include uses of textures and samplers passed as
    /// arguments to this function itself, since we do not know which globals
    /// those will be. However, this table *is* exhaustive when computed for an
    /// entry point function: entry points never receive textures or samplers as
    /// arguments, so all an entry point's sampling can be reported in terms of
    /// globals.
    ///
    /// The GLSL back end uses this table to construct reflection info that
    /// clients need to construct texture-combined sampler values.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function and its callees use this module's globals.
    ///
    /// This is indexed by `Handle<GlobalVariable>` indices. However,
    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
    /// so you can simply index this struct with a global handle to retrieve
    /// its usage information.
    global_uses: Box<[GlobalUse]>,

    /// Information about each expression in this function's body.
    ///
    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
    /// index this struct with an expression handle to retrieve its
    /// `ExpressionInfo`.
    expressions: Box<[ExpressionInfo]>,

    /// All (texture, sampler) pairs that may be used together in sampling
    /// operations by this function and its callees, whether they are accessed
    /// as globals or passed as arguments.
    ///
    /// Participants are represented by [`GlobalVariable`] handles whenever
    /// possible, and otherwise by indices of this function's arguments.
    ///
    /// When analyzing a function call, we combine this data about the callee
    /// with the actual arguments being passed to produce the callers' own
    /// `sampling_set` and `sampling` tables.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    sampling: crate::FastHashSet<Sampling>,

    /// Indicates that the function is using dual source blending.
    pub dual_source_blending: bool,

    /// The leaf of all module-wide diagnostic filter rules tree parsed from directives in this
    /// module.
    ///
    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
    /// validation.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
306
307impl FunctionInfo {
308    pub const fn global_variable_count(&self) -> usize {
309        self.global_uses.len()
310    }
311    pub const fn expression_count(&self) -> usize {
312        self.expressions.len()
313    }
314    pub fn dominates_global_use(&self, other: &Self) -> bool {
315        for (self_global_uses, other_global_uses) in
316            self.global_uses.iter().zip(other.global_uses.iter())
317        {
318            if !self_global_uses.contains(*other_global_uses) {
319                return false;
320            }
321        }
322        true
323    }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327    type Output = GlobalUse;
328    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329        &self.global_uses[handle.index()]
330    }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334    type Output = ExpressionInfo;
335    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336        &self.expressions[handle.index()]
337    }
338}
339
/// Disruptor of the uniform control flow.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on the non-uniform result of the given expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A return may have happened earlier in the function's control flow.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A discard may have happened earlier, in the entry point or any
    /// function it called.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
351
352impl FunctionInfo {
353    /// Record a use of `expr` of the sort given by `global_use`.
354    ///
355    /// Bump `expr`'s reference count, and return its uniformity.
356    ///
357    /// If `expr` is a pointer to a global variable, or some part of
358    /// a global variable, add `global_use` to that global's set of
359    /// uses.
360    #[must_use]
361    fn add_ref_impl(
362        &mut self,
363        expr: Handle<crate::Expression>,
364        global_use: GlobalUse,
365    ) -> NonUniformResult {
366        let info = &mut self.expressions[expr.index()];
367        info.ref_count += 1;
368        // Record usage if this expression may access a global
369        if let Some(global) = info.assignable_global {
370            self.global_uses[global.index()] |= global_use;
371        }
372        info.uniformity.non_uniform_result
373    }
374
375    /// Note an entry point's use of `global` not recorded by [`ModuleInfo::process_function`].
376    ///
377    /// Most global variable usage should be recorded via [`add_ref_impl`] in the process
378    /// of expression behavior analysis by [`ModuleInfo::process_function`]. But that code
379    /// has no access to entrypoint-specific information, so interface analysis uses this
380    /// function to record global uses there (like task shader payloads).
381    ///
382    /// [`add_ref_impl`]: Self::add_ref_impl
383    pub(super) fn insert_global_use(
384        &mut self,
385        global_use: GlobalUse,
386        global: Handle<crate::GlobalVariable>,
387    ) {
388        self.global_uses[global.index()] |= global_use;
389    }
390
391    /// Record a use of `expr` for its value.
392    ///
393    /// This is used for almost all expression references. Anything
394    /// that writes to the value `expr` points to, or otherwise wants
395    /// contribute flags other than `GlobalUse::READ`, should use
396    /// `add_ref_impl` directly.
397    #[must_use]
398    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
399        self.add_ref_impl(expr, GlobalUse::READ)
400    }
401
402    /// Record a use of `expr`, and indicate which global variable it
403    /// refers to, if any.
404    ///
405    /// Bump `expr`'s reference count, and return its uniformity.
406    ///
407    /// If `expr` is a pointer to a global variable, or some part
408    /// thereof, store that global in `*assignable_global`. Leave the
409    /// global's uses unchanged.
410    ///
411    /// This is used to determine the [`assignable_global`] for
412    /// [`Access`] and [`AccessIndex`] expressions that ultimately
413    /// refer to a global variable. Those expressions don't contribute
414    /// any usage to the global themselves; that depends on how other
415    /// expressions use them.
416    ///
417    /// [`assignable_global`]: ExpressionInfo::assignable_global
418    /// [`Access`]: crate::Expression::Access
419    /// [`AccessIndex`]: crate::Expression::AccessIndex
420    #[must_use]
421    fn add_assignable_ref(
422        &mut self,
423        expr: Handle<crate::Expression>,
424        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
425    ) -> NonUniformResult {
426        let info = &mut self.expressions[expr.index()];
427        info.ref_count += 1;
428        // propagate the assignable global up the chain, till it either hits
429        // a value-type expression, or the assignment statement.
430        if let Some(global) = info.assignable_global {
431            if let Some(_old) = assignable_global.replace(global) {
432                unreachable!()
433            }
434        }
435        info.uniformity.non_uniform_result
436    }
437
438    /// Inherit information from a called function.
439    fn process_call(
440        &mut self,
441        callee: &Self,
442        arguments: &[Handle<crate::Expression>],
443        expression_arena: &Arena<crate::Expression>,
444    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
445        self.sampling_set
446            .extend(callee.sampling_set.iter().cloned());
447        for sampling in callee.sampling.iter() {
448            // If the callee was passed the texture or sampler as an argument,
449            // we may now be able to determine which globals those referred to.
450            let image_storage = match sampling.image {
451                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
452                GlobalOrArgument::Argument(i) => {
453                    let Some(handle) = arguments.get(i as usize).cloned() else {
454                        // Argument count mismatch, will be reported later by validate_call
455                        break;
456                    };
457                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
458                        |source| {
459                            FunctionError::Expression { handle, source }
460                                .with_span_handle(handle, expression_arena)
461                        },
462                    )?
463                }
464            };
465
466            let sampler_storage = match sampling.sampler {
467                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
468                GlobalOrArgument::Argument(i) => {
469                    let Some(handle) = arguments.get(i as usize).cloned() else {
470                        // Argument count mismatch, will be reported later by validate_call
471                        break;
472                    };
473                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
474                        |source| {
475                            FunctionError::Expression { handle, source }
476                                .with_span_handle(handle, expression_arena)
477                        },
478                    )?
479                }
480            };
481
482            // If we've managed to pin both the image and sampler down to
483            // specific globals, record that in our `sampling_set`. Otherwise,
484            // record as much as we do know in our own `sampling` table, for our
485            // callers to sort out.
486            match (image_storage, sampler_storage) {
487                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
488                    self.sampling_set.insert(SamplingKey { image, sampler });
489                }
490                (image, sampler) => {
491                    self.sampling.insert(Sampling { image, sampler });
492                }
493            }
494        }
495
496        // Inherit global use from our callees.
497        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
498            *mine |= *other;
499        }
500
501        Ok(FunctionUniformity {
502            result: callee.uniformity.clone(),
503            exit: if callee.may_kill {
504                ExitFlags::MAY_KILL
505            } else {
506                ExitFlags::empty()
507            },
508        })
509    }
510
511    /// Compute the [`ExpressionInfo`] for `handle`.
512    ///
513    /// Replace the dummy entry in [`self.expressions`] for `handle`
514    /// with a real `ExpressionInfo` value describing that expression.
515    ///
516    /// This function is called as part of a forward sweep through the
517    /// arena, so we can assume that all earlier expressions in the
518    /// arena already have valid info. Since expressions only depend
519    /// on earlier expressions, this includes all our subexpressions.
520    ///
521    /// Adjust the reference counts on all expressions we use.
522    ///
523    /// Also populate the [`sampling_set`], [`sampling`] and
524    /// [`global_uses`] fields of `self`.
525    ///
526    /// [`self.expressions`]: FunctionInfo::expressions
527    /// [`sampling_set`]: FunctionInfo::sampling_set
528    /// [`sampling`]: FunctionInfo::sampling
529    /// [`global_uses`]: FunctionInfo::global_uses
530    #[allow(clippy::or_fun_call)]
531    fn process_expression(
532        &mut self,
533        handle: Handle<crate::Expression>,
534        expression_arena: &Arena<crate::Expression>,
535        other_functions: &[FunctionInfo],
536        resolve_context: &ResolveContext,
537        capabilities: super::Capabilities,
538    ) -> Result<(), ExpressionError> {
539        use crate::{Expression as E, SampleLevel as Sl};
540
541        let expression = &expression_arena[handle];
542        let mut assignable_global = None;
543        let uniformity = match *expression {
544            E::Access { base, index } => {
545                let base_ty = self[base].ty.inner_with(resolve_context.types);
546
547                // build up the caps needed if this is indexed non-uniformly
548                let mut needed_caps = super::Capabilities::empty();
549                let is_binding_array = match *base_ty {
550                    crate::TypeInner::BindingArray {
551                        base: array_element_ty_handle,
552                        ..
553                    } => {
554                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
555                        let array_element_ty =
556                            &resolve_context.types[array_element_ty_handle].inner;
557
558                        needed_caps |= match *array_element_ty {
559                            // If we're an image, use the appropriate capability.
560                            crate::TypeInner::Image { class, .. } => match class {
561                                crate::ImageClass::Storage { .. } => {
562                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
563                                }
564                                _ => {
565                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
566                                }
567                            },
568                            crate::TypeInner::Sampler { .. } => {
569                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
570                            }
571                            // If we're anything but an image or sampler, assume we're a buffer and use the address space.
572                            _ => {
573                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
574                                    let global = &resolve_context.global_vars[global_handle];
575                                    match global.space {
576                                        crate::AddressSpace::Uniform => {
577                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
578                                        }
579                                        crate::AddressSpace::Storage { .. } => {
580                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
581                                        }
582                                        _ => unreachable!(),
583                                    }
584                                } else {
585                                    unreachable!()
586                                }
587                            }
588                        };
589
590                        true
591                    }
592                    _ => false,
593                };
594
595                if self[index].uniformity.non_uniform_result.is_some()
596                    && !capabilities.contains(needed_caps)
597                    && is_binding_array
598                {
599                    return Err(ExpressionError::MissingCapabilities(needed_caps));
600                }
601
602                Uniformity {
603                    non_uniform_result: self
604                        .add_assignable_ref(base, &mut assignable_global)
605                        .or(self.add_ref(index)),
606                    requirements: UniformityRequirements::empty(),
607                }
608            }
609            E::AccessIndex { base, .. } => Uniformity {
610                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
611                requirements: UniformityRequirements::empty(),
612            },
613            // always uniform
614            E::Splat { size: _, value } => Uniformity {
615                non_uniform_result: self.add_ref(value),
616                requirements: UniformityRequirements::empty(),
617            },
618            E::Swizzle { vector, .. } => Uniformity {
619                non_uniform_result: self.add_ref(vector),
620                requirements: UniformityRequirements::empty(),
621            },
622            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
623            E::Compose { ref components, .. } => {
624                let non_uniform_result = components
625                    .iter()
626                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
627                Uniformity {
628                    non_uniform_result,
629                    requirements: UniformityRequirements::empty(),
630                }
631            }
632            // depends on the builtin
633            E::FunctionArgument(index) => {
634                let arg = &resolve_context.arguments[index as usize];
635                let uniform = match arg.binding {
636                    Some(crate::Binding::BuiltIn(
637                        // per-work-group built-ins are uniform
638                        crate::BuiltIn::WorkGroupId
639                        | crate::BuiltIn::WorkGroupSize
640                        | crate::BuiltIn::NumWorkGroups,
641                    )) => true,
642                    _ => false,
643                };
644                Uniformity {
645                    non_uniform_result: if uniform { None } else { Some(handle) },
646                    requirements: UniformityRequirements::empty(),
647                }
648            }
649            // depends on the address space
650            E::GlobalVariable(gh) => {
651                use crate::AddressSpace as As;
652                assignable_global = Some(gh);
653                let var = &resolve_context.global_vars[gh];
654                let uniform = match var.space {
655                    // local data is non-uniform
656                    As::Function | As::Private => false,
657                    // workgroup memory is exclusively accessed by the group
658                    // task payload memory is very similar to workgroup memory
659                    As::WorkGroup | As::TaskPayload => true,
660                    // uniform data
661                    As::Uniform | As::Immediate => true,
662                    // storage data is only uniform when read-only
663                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
664                    As::Handle => false,
665                };
666                Uniformity {
667                    non_uniform_result: if uniform { None } else { Some(handle) },
668                    requirements: UniformityRequirements::empty(),
669                }
670            }
671            E::LocalVariable(_) => Uniformity {
672                non_uniform_result: Some(handle),
673                requirements: UniformityRequirements::empty(),
674            },
675            E::Load { pointer } => Uniformity {
676                non_uniform_result: self.add_ref(pointer),
677                requirements: UniformityRequirements::empty(),
678            },
679            E::ImageSample {
680                image,
681                sampler,
682                gather: _,
683                coordinate,
684                array_index,
685                offset,
686                level,
687                depth_ref,
688                clamp_to_edge: _,
689            } => {
690                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
691                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
692
693                match (image_storage, sampler_storage) {
694                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
695                        self.sampling_set.insert(SamplingKey { image, sampler });
696                    }
697                    _ => {
698                        self.sampling.insert(Sampling {
699                            image: image_storage,
700                            sampler: sampler_storage,
701                        });
702                    }
703                }
704
705                // "nur" == "Non-Uniform Result"
706                let array_nur = array_index.and_then(|h| self.add_ref(h));
707                let level_nur = match level {
708                    Sl::Auto | Sl::Zero => None,
709                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
710                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
711                };
712                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
713                let offset_nur = offset.and_then(|h| self.add_ref(h));
714                Uniformity {
715                    non_uniform_result: self
716                        .add_ref(image)
717                        .or(self.add_ref(sampler))
718                        .or(self.add_ref(coordinate))
719                        .or(array_nur)
720                        .or(level_nur)
721                        .or(dref_nur)
722                        .or(offset_nur),
723                    requirements: if level.implicit_derivatives() {
724                        UniformityRequirements::IMPLICIT_LEVEL
725                    } else {
726                        UniformityRequirements::empty()
727                    },
728                }
729            }
730            E::ImageLoad {
731                image,
732                coordinate,
733                array_index,
734                sample,
735                level,
736            } => {
737                let array_nur = array_index.and_then(|h| self.add_ref(h));
738                let sample_nur = sample.and_then(|h| self.add_ref(h));
739                let level_nur = level.and_then(|h| self.add_ref(h));
740                Uniformity {
741                    non_uniform_result: self
742                        .add_ref(image)
743                        .or(self.add_ref(coordinate))
744                        .or(array_nur)
745                        .or(sample_nur)
746                        .or(level_nur),
747                    requirements: UniformityRequirements::empty(),
748                }
749            }
750            E::ImageQuery { image, query } => {
751                let query_nur = match query {
752                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
753                    _ => None,
754                };
755                Uniformity {
756                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
757                    requirements: UniformityRequirements::empty(),
758                }
759            }
760            E::Unary { expr, .. } => Uniformity {
761                non_uniform_result: self.add_ref(expr),
762                requirements: UniformityRequirements::empty(),
763            },
764            E::Binary { left, right, .. } => Uniformity {
765                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
766                requirements: UniformityRequirements::empty(),
767            },
768            E::Select {
769                condition,
770                accept,
771                reject,
772            } => Uniformity {
773                non_uniform_result: self
774                    .add_ref(condition)
775                    .or(self.add_ref(accept))
776                    .or(self.add_ref(reject)),
777                requirements: UniformityRequirements::empty(),
778            },
779            // explicit derivatives require uniform
780            E::Derivative { expr, .. } => Uniformity {
781                //Note: taking a derivative of a uniform doesn't make it non-uniform
782                non_uniform_result: self.add_ref(expr),
783                requirements: UniformityRequirements::DERIVATIVE,
784            },
785            E::Relational { argument, .. } => Uniformity {
786                non_uniform_result: self.add_ref(argument),
787                requirements: UniformityRequirements::empty(),
788            },
789            E::Math {
790                fun: _,
791                arg,
792                arg1,
793                arg2,
794                arg3,
795            } => {
796                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
797                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
798                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
799                Uniformity {
800                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
801                    requirements: UniformityRequirements::empty(),
802                }
803            }
804            E::As { expr, .. } => Uniformity {
805                non_uniform_result: self.add_ref(expr),
806                requirements: UniformityRequirements::empty(),
807            },
808            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
809            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
810                non_uniform_result: Some(handle),
811                requirements: UniformityRequirements::empty(),
812            },
813            E::WorkGroupUniformLoadResult { .. } => Uniformity {
814                // The result of WorkGroupUniformLoad is always uniform by definition
815                non_uniform_result: None,
816                // The call is what cares about uniformity, not the expression
817                // This expression is never emitted, so this requirement should never be used anyway?
818                requirements: UniformityRequirements::empty(),
819            },
820            E::ArrayLength(expr) => Uniformity {
821                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
822                requirements: UniformityRequirements::empty(),
823            },
824            E::RayQueryGetIntersection {
825                query,
826                committed: _,
827            } => Uniformity {
828                non_uniform_result: self.add_ref(query),
829                requirements: UniformityRequirements::empty(),
830            },
831            E::SubgroupBallotResult => Uniformity {
832                non_uniform_result: Some(handle),
833                requirements: UniformityRequirements::empty(),
834            },
835            E::SubgroupOperationResult { .. } => Uniformity {
836                non_uniform_result: Some(handle),
837                requirements: UniformityRequirements::empty(),
838            },
839            E::RayQueryVertexPositions {
840                query,
841                committed: _,
842            } => Uniformity {
843                non_uniform_result: self.add_ref(query),
844                requirements: UniformityRequirements::empty(),
845            },
846        };
847
848        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
849        self.expressions[handle.index()] = ExpressionInfo {
850            uniformity,
851            ref_count: 0,
852            assignable_global,
853            ty,
854        };
855        Ok(())
856    }
857
858    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
859    /// Returns the uniformity characteristics at the *function* level, i.e.
860    /// whether or not the function requires to be called in uniform control flow,
861    /// and whether the produced result is not disrupting the control flow.
862    ///
863    /// The parent control flow is uniform if `disruptor.is_none()`.
864    ///
865    /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
866    /// require uniformity, but the current flow is non-uniform.
867    #[allow(clippy::or_fun_call)]
868    fn process_block(
869        &mut self,
870        statements: &crate::Block,
871        other_functions: &[FunctionInfo],
872        mut disruptor: Option<UniformityDisruptor>,
873        expression_arena: &Arena<crate::Expression>,
874        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
875    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
876        use crate::Statement as S;
877
878        let mut combined_uniformity = FunctionUniformity::new();
879        for statement in statements {
880            let uniformity = match *statement {
881                S::Emit(ref range) => {
882                    let mut requirements = UniformityRequirements::empty();
883                    for expr in range.clone() {
884                        let req = self.expressions[expr.index()].uniformity.requirements;
885                        if self
886                            .flags
887                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
888                            && !req.is_empty()
889                        {
890                            if let Some(cause) = disruptor {
891                                let severity = DiagnosticFilterNode::search(
892                                    self.diagnostic_filter_leaf,
893                                    diagnostic_filter_arena,
894                                    StandardFilterableTriggeringRule::DerivativeUniformity,
895                                );
896                                severity.report_diag(
897                                    FunctionError::NonUniformControlFlow(req, expr, cause)
898                                        .with_span_handle(expr, expression_arena),
899                                    // TODO: Yes, this isn't contextualized with source, because
900                                    // the user is supposed to render what would normally be an
901                                    // error here. Once we actually support warning-level
902                                    // diagnostic items, then we won't need this non-compliant hack:
903                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
904                                    |e, level| log::log!(level, "{e}"),
905                                )?;
906                            }
907                        }
908                        requirements |= req;
909                    }
910                    FunctionUniformity {
911                        result: Uniformity {
912                            non_uniform_result: None,
913                            requirements,
914                        },
915                        exit: ExitFlags::empty(),
916                    }
917                }
918                S::Break | S::Continue => FunctionUniformity::new(),
919                S::Kill => FunctionUniformity {
920                    result: Uniformity::new(),
921                    exit: if disruptor.is_some() {
922                        ExitFlags::MAY_KILL
923                    } else {
924                        ExitFlags::empty()
925                    },
926                },
927                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
928                    result: Uniformity {
929                        non_uniform_result: None,
930                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
931                    },
932                    exit: ExitFlags::empty(),
933                },
934                S::WorkGroupUniformLoad { pointer, .. } => {
935                    let _condition_nur = self.add_ref(pointer);
936
937                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
938                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
939                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
940                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.
941
942                    /*
943                    if self
944                        .flags
945                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
946                    {
947                        let condition_nur = self.add_ref(pointer);
948                        let this_disruptor =
949                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
950                        if let Some(cause) = this_disruptor {
951                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
952                                .with_span_static(*span, "WorkGroupUniformLoad"));
953                        }
954                    } */
955                    FunctionUniformity {
956                        result: Uniformity {
957                            non_uniform_result: None,
958                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
959                        },
960                        exit: ExitFlags::empty(),
961                    }
962                }
963                S::Block(ref b) => self.process_block(
964                    b,
965                    other_functions,
966                    disruptor,
967                    expression_arena,
968                    diagnostic_filter_arena,
969                )?,
970                S::If {
971                    condition,
972                    ref accept,
973                    ref reject,
974                } => {
975                    let condition_nur = self.add_ref(condition);
976                    let branch_disruptor =
977                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
978                    let accept_uniformity = self.process_block(
979                        accept,
980                        other_functions,
981                        branch_disruptor,
982                        expression_arena,
983                        diagnostic_filter_arena,
984                    )?;
985                    let reject_uniformity = self.process_block(
986                        reject,
987                        other_functions,
988                        branch_disruptor,
989                        expression_arena,
990                        diagnostic_filter_arena,
991                    )?;
992                    accept_uniformity | reject_uniformity
993                }
994                S::Switch {
995                    selector,
996                    ref cases,
997                } => {
998                    let selector_nur = self.add_ref(selector);
999                    let branch_disruptor =
1000                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
1001                    let mut uniformity = FunctionUniformity::new();
1002                    let mut case_disruptor = branch_disruptor;
1003                    for case in cases.iter() {
1004                        let case_uniformity = self.process_block(
1005                            &case.body,
1006                            other_functions,
1007                            case_disruptor,
1008                            expression_arena,
1009                            diagnostic_filter_arena,
1010                        )?;
1011                        case_disruptor = if case.fall_through {
1012                            case_disruptor.or(case_uniformity.exit_disruptor())
1013                        } else {
1014                            branch_disruptor
1015                        };
1016                        uniformity = uniformity | case_uniformity;
1017                    }
1018                    uniformity
1019                }
1020                S::Loop {
1021                    ref body,
1022                    ref continuing,
1023                    break_if,
1024                } => {
1025                    let body_uniformity = self.process_block(
1026                        body,
1027                        other_functions,
1028                        disruptor,
1029                        expression_arena,
1030                        diagnostic_filter_arena,
1031                    )?;
1032                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
1033                    let continuing_uniformity = self.process_block(
1034                        continuing,
1035                        other_functions,
1036                        continuing_disruptor,
1037                        expression_arena,
1038                        diagnostic_filter_arena,
1039                    )?;
1040                    if let Some(expr) = break_if {
1041                        let _ = self.add_ref(expr);
1042                    }
1043                    body_uniformity | continuing_uniformity
1044                }
1045                S::Return { value } => FunctionUniformity {
1046                    result: Uniformity {
1047                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
1048                        requirements: UniformityRequirements::empty(),
1049                    },
1050                    exit: if disruptor.is_some() {
1051                        ExitFlags::MAY_RETURN
1052                    } else {
1053                        ExitFlags::empty()
1054                    },
1055                },
1056                // Here and below, the used expressions are already emitted,
1057                // and their results do not affect the function return value,
1058                // so we can ignore their non-uniformity.
1059                S::Store { pointer, value } => {
1060                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
1061                    let _ = self.add_ref(value);
1062                    FunctionUniformity::new()
1063                }
1064                S::ImageStore {
1065                    image,
1066                    coordinate,
1067                    array_index,
1068                    value,
1069                } => {
1070                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
1071                    if let Some(expr) = array_index {
1072                        let _ = self.add_ref(expr);
1073                    }
1074                    let _ = self.add_ref(coordinate);
1075                    let _ = self.add_ref(value);
1076                    FunctionUniformity::new()
1077                }
1078                S::Call {
1079                    function,
1080                    ref arguments,
1081                    result: _,
1082                } => {
1083                    for &argument in arguments {
1084                        let _ = self.add_ref(argument);
1085                    }
1086                    let info = &other_functions[function.index()];
1087                    //Note: the result is validated by the Validator, not here
1088                    self.process_call(info, arguments, expression_arena)?
1089                }
1090                S::Atomic {
1091                    pointer,
1092                    ref fun,
1093                    value,
1094                    result: _,
1095                } => {
1096                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
1097                    let _ = self.add_ref(value);
1098                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
1099                        let _ = self.add_ref(cmp);
1100                    }
1101                    FunctionUniformity::new()
1102                }
1103                S::ImageAtomic {
1104                    image,
1105                    coordinate,
1106                    array_index,
1107                    fun: _,
1108                    value,
1109                } => {
1110                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
1111                    let _ = self.add_ref(coordinate);
1112                    if let Some(expr) = array_index {
1113                        let _ = self.add_ref(expr);
1114                    }
1115                    let _ = self.add_ref(value);
1116                    FunctionUniformity::new()
1117                }
1118                S::RayQuery { query, ref fun } => {
1119                    let _ = self.add_ref(query);
1120                    match *fun {
1121                        crate::RayQueryFunction::Initialize {
1122                            acceleration_structure,
1123                            descriptor,
1124                        } => {
1125                            let _ = self.add_ref(acceleration_structure);
1126                            let _ = self.add_ref(descriptor);
1127                        }
1128                        crate::RayQueryFunction::Proceed { result: _ } => {}
1129                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1130                            let _ = self.add_ref(hit_t);
1131                        }
1132                        crate::RayQueryFunction::ConfirmIntersection => {}
1133                        crate::RayQueryFunction::Terminate => {}
1134                    }
1135                    FunctionUniformity::new()
1136                }
1137                S::SubgroupBallot {
1138                    result: _,
1139                    predicate,
1140                } => {
1141                    if let Some(predicate) = predicate {
1142                        let _ = self.add_ref(predicate);
1143                    }
1144                    FunctionUniformity::new()
1145                }
1146                S::SubgroupCollectiveOperation {
1147                    op: _,
1148                    collective_op: _,
1149                    argument,
1150                    result: _,
1151                } => {
1152                    let _ = self.add_ref(argument);
1153                    FunctionUniformity::new()
1154                }
1155                S::SubgroupGather {
1156                    mode,
1157                    argument,
1158                    result: _,
1159                } => {
1160                    let _ = self.add_ref(argument);
1161                    match mode {
1162                        crate::GatherMode::BroadcastFirst => {}
1163                        crate::GatherMode::Broadcast(index)
1164                        | crate::GatherMode::Shuffle(index)
1165                        | crate::GatherMode::ShuffleDown(index)
1166                        | crate::GatherMode::ShuffleUp(index)
1167                        | crate::GatherMode::ShuffleXor(index)
1168                        | crate::GatherMode::QuadBroadcast(index) => {
1169                            let _ = self.add_ref(index);
1170                        }
1171                        crate::GatherMode::QuadSwap(_) => {}
1172                    }
1173                    FunctionUniformity::new()
1174                }
1175            };
1176
1177            disruptor = disruptor.or(uniformity.exit_disruptor());
1178            combined_uniformity = combined_uniformity | uniformity;
1179        }
1180        Ok(combined_uniformity)
1181    }
1182}
1183
1184impl ModuleInfo {
1185    /// Populates `self.const_expression_types`
1186    pub(super) fn process_const_expression(
1187        &mut self,
1188        handle: Handle<crate::Expression>,
1189        resolve_context: &ResolveContext,
1190        gctx: crate::proc::GlobalCtx,
1191    ) -> Result<(), super::ConstExpressionError> {
1192        self.const_expression_types[handle.index()] =
1193            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1194        Ok(())
1195    }
1196
1197    /// Builds the `FunctionInfo` based on the function, and validates the
1198    /// uniform control flow if required by the expressions of this function.
1199    pub(super) fn process_function(
1200        &self,
1201        fun: &crate::Function,
1202        module: &crate::Module,
1203        flags: ValidationFlags,
1204        capabilities: super::Capabilities,
1205    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1206        let mut info = FunctionInfo {
1207            flags,
1208            available_stages: ShaderStages::all(),
1209            uniformity: Uniformity::new(),
1210            may_kill: false,
1211            sampling_set: crate::FastHashSet::default(),
1212            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1213            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1214            sampling: crate::FastHashSet::default(),
1215            dual_source_blending: false,
1216            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1217        };
1218        let resolve_context =
1219            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1220
1221        for (handle, _) in fun.expressions.iter() {
1222            if let Err(source) = info.process_expression(
1223                handle,
1224                &fun.expressions,
1225                &self.functions,
1226                &resolve_context,
1227                capabilities,
1228            ) {
1229                return Err(FunctionError::Expression { handle, source }
1230                    .with_span_handle(handle, &fun.expressions));
1231            }
1232        }
1233
1234        for (_, expr) in fun.local_variables.iter() {
1235            if let Some(init) = expr.init {
1236                let _ = info.add_ref(init);
1237            }
1238        }
1239
1240        let uniformity = info.process_block(
1241            &fun.body,
1242            &self.functions,
1243            None,
1244            &fun.expressions,
1245            &module.diagnostic_filters,
1246        )?;
1247        info.uniformity = uniformity.result;
1248        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1249
1250        // If there are any globals referenced directly by a named expression,
1251        // ensure they are marked as used even if they are not referenced
1252        // anywhere else. An important case where this matters is phony
1253        // assignments used to include a global in the shader's resource
1254        // interface. https://www.w3.org/TR/WGSL/#phony-assignment-section
1255        for &handle in fun.named_expressions.keys() {
1256            if let Some(global) = info[handle].assignable_global {
1257                if info.global_uses[global.index()].is_empty() {
1258                    info.global_uses[global.index()] = GlobalUse::QUERY;
1259                }
1260            }
1261        }
1262
1263        Ok(info)
1264    }
1265
1266    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1267        &self.entry_points[index]
1268    }
1269}
1270
1271#[test]
1272fn uniform_control_flow() {
1273    use crate::{Expression as E, Statement as S};
1274
1275    let mut type_arena = crate::UniqueArena::new();
1276    let ty = type_arena.insert(
1277        crate::Type {
1278            name: None,
1279            inner: crate::TypeInner::Vector {
1280                size: crate::VectorSize::Bi,
1281                scalar: crate::Scalar::F32,
1282            },
1283        },
1284        Default::default(),
1285    );
1286    let mut global_var_arena = Arena::new();
1287    let non_uniform_global = global_var_arena.append(
1288        crate::GlobalVariable {
1289            name: None,
1290            init: None,
1291            ty,
1292            space: crate::AddressSpace::Handle,
1293            binding: None,
1294        },
1295        Default::default(),
1296    );
1297    let uniform_global = global_var_arena.append(
1298        crate::GlobalVariable {
1299            name: None,
1300            init: None,
1301            ty,
1302            binding: None,
1303            space: crate::AddressSpace::Uniform,
1304        },
1305        Default::default(),
1306    );
1307
1308    let mut expressions = Arena::new();
1309    // checks the uniform control flow
1310    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1311    // checks the non-uniform control flow
1312    let derivative_expr = expressions.append(
1313        E::Derivative {
1314            axis: crate::DerivativeAxis::X,
1315            ctrl: crate::DerivativeControl::None,
1316            expr: constant_expr,
1317        },
1318        Default::default(),
1319    );
1320    let emit_range_constant_derivative = expressions.range_from(0);
1321    let non_uniform_global_expr =
1322        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1323    let uniform_global_expr =
1324        expressions.append(E::GlobalVariable(uniform_global), Default::default());
1325    let emit_range_globals = expressions.range_from(2);
1326
1327    // checks the QUERY flag
1328    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1329    // checks the transitive WRITE flag
1330    let access_expr = expressions.append(
1331        E::AccessIndex {
1332            base: non_uniform_global_expr,
1333            index: 1,
1334        },
1335        Default::default(),
1336    );
1337    let emit_range_query_access_globals = expressions.range_from(2);
1338
1339    let mut info = FunctionInfo {
1340        flags: ValidationFlags::all(),
1341        available_stages: ShaderStages::all(),
1342        uniformity: Uniformity::new(),
1343        may_kill: false,
1344        sampling_set: crate::FastHashSet::default(),
1345        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1346        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1347        sampling: crate::FastHashSet::default(),
1348        dual_source_blending: false,
1349        diagnostic_filter_leaf: None,
1350    };
1351    let resolve_context = ResolveContext {
1352        constants: &Arena::new(),
1353        overrides: &Arena::new(),
1354        types: &type_arena,
1355        special_types: &crate::SpecialTypes::default(),
1356        global_vars: &global_var_arena,
1357        local_vars: &Arena::new(),
1358        functions: &Arena::new(),
1359        arguments: &[],
1360    };
1361    for (handle, _) in expressions.iter() {
1362        info.process_expression(
1363            handle,
1364            &expressions,
1365            &[],
1366            &resolve_context,
1367            super::Capabilities::empty(),
1368        )
1369        .unwrap();
1370    }
1371    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1372    assert_eq!(info[uniform_global_expr].ref_count, 1);
1373    assert_eq!(info[query_expr].ref_count, 0);
1374    assert_eq!(info[access_expr].ref_count, 0);
1375    assert_eq!(info[non_uniform_global], GlobalUse::empty());
1376    assert_eq!(info[uniform_global], GlobalUse::QUERY);
1377
1378    let stmt_emit1 = S::Emit(emit_range_globals.clone());
1379    let stmt_if_uniform = S::If {
1380        condition: uniform_global_expr,
1381        accept: crate::Block::new(),
1382        reject: vec![
1383            S::Emit(emit_range_constant_derivative.clone()),
1384            S::Store {
1385                pointer: constant_expr,
1386                value: derivative_expr,
1387            },
1388        ]
1389        .into(),
1390    };
1391    assert_eq!(
1392        info.process_block(
1393            &vec![stmt_emit1, stmt_if_uniform].into(),
1394            &[],
1395            None,
1396            &expressions,
1397            &Arena::new(),
1398        ),
1399        Ok(FunctionUniformity {
1400            result: Uniformity {
1401                non_uniform_result: None,
1402                requirements: UniformityRequirements::DERIVATIVE,
1403            },
1404            exit: ExitFlags::empty(),
1405        }),
1406    );
1407    assert_eq!(info[constant_expr].ref_count, 2);
1408    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1409
1410    let stmt_emit2 = S::Emit(emit_range_globals.clone());
1411    let stmt_if_non_uniform = S::If {
1412        condition: non_uniform_global_expr,
1413        accept: vec![
1414            S::Emit(emit_range_constant_derivative),
1415            S::Store {
1416                pointer: constant_expr,
1417                value: derivative_expr,
1418            },
1419        ]
1420        .into(),
1421        reject: crate::Block::new(),
1422    };
1423    {
1424        let block_info = info.process_block(
1425            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1426            &[],
1427            None,
1428            &expressions,
1429            &Arena::new(),
1430        );
1431        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1432            assert_eq!(info[derivative_expr].ref_count, 2);
1433        } else {
1434            assert_eq!(
1435                block_info,
1436                Err(FunctionError::NonUniformControlFlow(
1437                    UniformityRequirements::DERIVATIVE,
1438                    derivative_expr,
1439                    UniformityDisruptor::Expression(non_uniform_global_expr)
1440                )
1441                .with_span()),
1442            );
1443            assert_eq!(info[derivative_expr].ref_count, 1);
1444
1445            // Test that the same thing passes when the `derivative_uniformity` diagnostic is off.
1446            let mut diagnostic_filters = Arena::new();
1447            let diagnostic_filter_leaf = diagnostic_filters.append(
1448                DiagnosticFilterNode {
1449                    inner: crate::diagnostic_filter::DiagnosticFilter {
1450                        new_severity: crate::diagnostic_filter::Severity::Off,
1451                        triggering_rule:
1452                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1453                                StandardFilterableTriggeringRule::DerivativeUniformity,
1454                            ),
1455                    },
1456                    parent: None,
1457                },
1458                crate::Span::default(),
1459            );
1460            let mut info = FunctionInfo {
1461                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1462                ..info.clone()
1463            };
1464
1465            let block_info = info.process_block(
1466                &vec![stmt_emit2, stmt_if_non_uniform].into(),
1467                &[],
1468                None,
1469                &expressions,
1470                &diagnostic_filters,
1471            );
1472            assert_eq!(
1473                block_info,
1474                Ok(FunctionUniformity {
1475                    result: Uniformity {
1476                        non_uniform_result: None,
1477                        requirements: UniformityRequirements::DERIVATIVE,
1478                    },
1479                    exit: ExitFlags::empty()
1480                }),
1481            );
1482            assert_eq!(info[derivative_expr].ref_count, 2);
1483        }
1484    }
1485    assert_eq!(info[non_uniform_global], GlobalUse::READ);
1486
1487    let stmt_emit3 = S::Emit(emit_range_globals);
1488    let stmt_return_non_uniform = S::Return {
1489        value: Some(non_uniform_global_expr),
1490    };
1491    assert_eq!(
1492        info.process_block(
1493            &vec![stmt_emit3, stmt_return_non_uniform].into(),
1494            &[],
1495            Some(UniformityDisruptor::Return),
1496            &expressions,
1497            &Arena::new(),
1498        ),
1499        Ok(FunctionUniformity {
1500            result: Uniformity {
1501                non_uniform_result: Some(non_uniform_global_expr),
1502                requirements: UniformityRequirements::empty(),
1503            },
1504            exit: ExitFlags::MAY_RETURN,
1505        }),
1506    );
1507    assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1508
1509    // Check that uniformity requirements reach through a pointer
1510    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1511    let stmt_assign = S::Store {
1512        pointer: access_expr,
1513        value: query_expr,
1514    };
1515    let stmt_return_pointer = S::Return {
1516        value: Some(access_expr),
1517    };
1518    let stmt_kill = S::Kill;
1519    assert_eq!(
1520        info.process_block(
1521            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1522            &[],
1523            Some(UniformityDisruptor::Discard),
1524            &expressions,
1525            &Arena::new(),
1526        ),
1527        Ok(FunctionUniformity {
1528            result: Uniformity {
1529                non_uniform_result: Some(non_uniform_global_expr),
1530                requirements: UniformityRequirements::empty(),
1531            },
1532            exit: ExitFlags::all(),
1533        }),
1534    );
1535    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1536}