naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15    arena::{Arena, Handle},
16    proc::{ResolveContext, TypeResolution},
17};
18
19pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
/// Turns off the uniform-control-flow requirements of fragment-stage
/// operations.
///
/// While this is `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// `UniformityRequirements` are defined as zero. Explicit derivatives and
/// implicit-level sampling therefore add no requirement bits.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
23bitflags::bitflags! {
24    /// Kinds of expressions that require uniform control flow.
25    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
26    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
27    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
28    pub struct UniformityRequirements: u8 {
29        const WORK_GROUP_BARRIER = 0x1;
30        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
31        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
32    }
33}
34
35/// Uniform control flow characteristics.
36#[derive(Clone, Debug)]
37#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
38#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
39#[cfg_attr(test, derive(PartialEq))]
40pub struct Uniformity {
41    /// A child expression with non-uniform result.
42    ///
43    /// This means, when the relevant invocations are scheduled on a compute unit,
44    /// they have to use vector registers to store an individual value
45    /// per invocation.
46    ///
47    /// Whenever the control flow is conditioned on such value,
48    /// the hardware needs to keep track of the mask of invocations,
49    /// and process all branches of the control flow.
50    ///
51    /// Any operations that depend on non-uniform results also produce non-uniform.
52    pub non_uniform_result: NonUniformResult,
53    /// If this expression requires uniform control flow, store the reason here.
54    pub requirements: UniformityRequirements,
55}
56
57impl Uniformity {
58    const fn new() -> Self {
59        Uniformity {
60            non_uniform_result: None,
61            requirements: UniformityRequirements::empty(),
62        }
63    }
64}
65
66bitflags::bitflags! {
67    #[derive(Clone, Copy, Debug, PartialEq)]
68    struct ExitFlags: u8 {
69        /// Control flow may return from the function, which makes all the
70        /// subsequent statements within the current function (only!)
71        /// to be executed in a non-uniform control flow.
72        const MAY_RETURN = 0x1;
73        /// Control flow may be killed. Anything after [`Statement::Kill`] is
74        /// considered inside non-uniform context.
75        ///
76        /// [`Statement::Kill`]: crate::Statement::Kill
77        const MAY_KILL = 0x2;
78    }
79}
80
81/// Uniformity characteristics of a function.
82#[cfg_attr(test, derive(Debug, PartialEq))]
83struct FunctionUniformity {
84    result: Uniformity,
85    exit: ExitFlags,
86}
87
88impl ops::BitOr for FunctionUniformity {
89    type Output = Self;
90    fn bitor(self, other: Self) -> Self {
91        FunctionUniformity {
92            result: Uniformity {
93                non_uniform_result: self
94                    .result
95                    .non_uniform_result
96                    .or(other.result.non_uniform_result),
97                requirements: self.result.requirements | other.result.requirements,
98            },
99            exit: self.exit | other.exit,
100        }
101    }
102}
103
104impl FunctionUniformity {
105    const fn new() -> Self {
106        FunctionUniformity {
107            result: Uniformity::new(),
108            exit: ExitFlags::empty(),
109        }
110    }
111
112    /// Returns a disruptor based on the stored exit flags, if any.
113    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114        if self.exit.contains(ExitFlags::MAY_RETURN) {
115            Some(UniformityDisruptor::Return)
116        } else if self.exit.contains(ExitFlags::MAY_KILL) {
117            Some(UniformityDisruptor::Discard)
118        } else {
119            None
120        }
121    }
122}
123
124bitflags::bitflags! {
125    /// Indicates how a global variable is used.
126    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
127    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
128    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
129    pub struct GlobalUse: u8 {
130        /// Data will be read from the variable.
131        const READ = 0x1;
132        /// Data will be written to the variable.
133        const WRITE = 0x2;
134        /// The information about the data is queried.
135        const QUERY = 0x4;
136        /// Atomic operations will be performed on the variable.
137        const ATOMIC = 0x8;
138    }
139}
140
141#[derive(Clone, Debug, Eq, Hash, PartialEq)]
142#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
143#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
144pub struct SamplingKey {
145    pub image: Handle<crate::GlobalVariable>,
146    pub sampler: Handle<crate::GlobalVariable>,
147}
148
149#[derive(Clone, Debug)]
150#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
151#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
152/// Information about an expression in a function body.
153pub struct ExpressionInfo {
154    /// Whether this expression is uniform, and why.
155    ///
156    /// If this expression's value is not uniform, this is the handle
157    /// of the expression from which this one's non-uniformity
158    /// originates. Otherwise, this is `None`.
159    pub uniformity: Uniformity,
160
161    /// The number of direct references to this expression in statements and
162    /// other expressions.
163    ///
164    /// This is a _local_ reference count only, it may be non-zero for
165    /// expressions that are ultimately unused.
166    pub ref_count: usize,
167
168    /// The global variable into which this expression produces a pointer.
169    ///
170    /// This is `None` unless this expression is either a
171    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
172    /// ultimately refers to some part of a global.
173    ///
174    /// [`Load`] expressions applied to pointer-typed arguments could
175    /// refer to globals, but we leave this as `None` for them.
176    ///
177    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
178    /// [`Access`]: crate::Expression::Access
179    /// [`AccessIndex`]: crate::Expression::AccessIndex
180    /// [`Load`]: crate::Expression::Load
181    assignable_global: Option<Handle<crate::GlobalVariable>>,
182
183    /// The type of this expression.
184    pub ty: TypeResolution,
185}
186
187impl ExpressionInfo {
188    const fn new() -> Self {
189        ExpressionInfo {
190            uniformity: Uniformity::new(),
191            ref_count: 0,
192            assignable_global: None,
193            // this doesn't matter at this point, will be overwritten
194            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
195                kind: crate::ScalarKind::Bool,
196                width: 0,
197            })),
198        }
199    }
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
203#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
204#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
205enum GlobalOrArgument {
206    Global(Handle<crate::GlobalVariable>),
207    Argument(u32),
208}
209
210impl GlobalOrArgument {
211    fn from_expression(
212        expression_arena: &Arena<crate::Expression>,
213        expression: Handle<crate::Expression>,
214    ) -> Result<GlobalOrArgument, ExpressionError> {
215        Ok(match expression_arena[expression] {
216            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
217            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
218            crate::Expression::Access { base, .. }
219            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
220                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
221                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
222            },
223            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
224        })
225    }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq, Hash)]
229#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
230#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
231struct Sampling {
232    image: GlobalOrArgument,
233    sampler: GlobalOrArgument,
234}
235
236#[derive(Debug, Clone)]
237#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
238#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
239pub struct FunctionInfo {
240    /// Validation flags.
241    #[allow(dead_code)]
242    flags: ValidationFlags,
243    /// Set of shader stages where calling this function is valid.
244    pub available_stages: ShaderStages,
245    /// Uniformity characteristics.
246    pub uniformity: Uniformity,
247    /// Function may kill the invocation.
248    pub may_kill: bool,
249
250    /// All pairs of (texture, sampler) globals that may be used together in
251    /// sampling operations by this function and its callees. This includes
252    /// pairings that arise when this function passes textures and samplers as
253    /// arguments to its callees.
254    ///
255    /// This table does not include uses of textures and samplers passed as
256    /// arguments to this function itself, since we do not know which globals
257    /// those will be. However, this table *is* exhaustive when computed for an
258    /// entry point function: entry points never receive textures or samplers as
259    /// arguments, so all an entry point's sampling can be reported in terms of
260    /// globals.
261    ///
262    /// The GLSL back end uses this table to construct reflection info that
263    /// clients need to construct texture-combined sampler values.
264    pub sampling_set: crate::FastHashSet<SamplingKey>,
265
266    /// How this function and its callees use this module's globals.
267    ///
268    /// This is indexed by `Handle<GlobalVariable>` indices. However,
269    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
270    /// so you can simply index this struct with a global handle to retrieve
271    /// its usage information.
272    global_uses: Box<[GlobalUse]>,
273
274    /// Information about each expression in this function's body.
275    ///
276    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
277    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
278    /// index this struct with an expression handle to retrieve its
279    /// `ExpressionInfo`.
280    expressions: Box<[ExpressionInfo]>,
281
282    /// All (texture, sampler) pairs that may be used together in sampling
283    /// operations by this function and its callees, whether they are accessed
284    /// as globals or passed as arguments.
285    ///
286    /// Participants are represented by [`GlobalVariable`] handles whenever
287    /// possible, and otherwise by indices of this function's arguments.
288    ///
289    /// When analyzing a function call, we combine this data about the callee
290    /// with the actual arguments being passed to produce the callers' own
291    /// `sampling_set` and `sampling` tables.
292    ///
293    /// [`GlobalVariable`]: crate::GlobalVariable
294    sampling: crate::FastHashSet<Sampling>,
295
296    /// Indicates that the function is using dual source blending.
297    pub dual_source_blending: bool,
298
299    /// The leaf of all module-wide diagnostic filter rules tree parsed from directives in this
300    /// module.
301    ///
302    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
303    /// validation.
304    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
305}
306
307impl FunctionInfo {
308    pub const fn global_variable_count(&self) -> usize {
309        self.global_uses.len()
310    }
311    pub const fn expression_count(&self) -> usize {
312        self.expressions.len()
313    }
314    pub fn dominates_global_use(&self, other: &Self) -> bool {
315        for (self_global_uses, other_global_uses) in
316            self.global_uses.iter().zip(other.global_uses.iter())
317        {
318            if !self_global_uses.contains(*other_global_uses) {
319                return false;
320            }
321        }
322        true
323    }
324}
325
326impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
327    type Output = GlobalUse;
328    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
329        &self.global_uses[handle.index()]
330    }
331}
332
333impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
334    type Output = ExpressionInfo;
335    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
336        &self.expressions[handle.index()]
337    }
338}
339
340/// Disruptor of the uniform control flow.
341#[derive(Clone, Copy, Debug, thiserror::Error)]
342#[cfg_attr(test, derive(PartialEq))]
343pub enum UniformityDisruptor {
344    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
345    Expression(Handle<crate::Expression>),
346    #[error("There is a Return earlier in the control flow of the function")]
347    Return,
348    #[error("There is a Discard earlier in the entry point across all called functions")]
349    Discard,
350}
351
352impl FunctionInfo {
353    /// Record a use of `expr` of the sort given by `global_use`.
354    ///
355    /// Bump `expr`'s reference count, and return its uniformity.
356    ///
357    /// If `expr` is a pointer to a global variable, or some part of
358    /// a global variable, add `global_use` to that global's set of
359    /// uses.
360    #[must_use]
361    fn add_ref_impl(
362        &mut self,
363        expr: Handle<crate::Expression>,
364        global_use: GlobalUse,
365    ) -> NonUniformResult {
366        let info = &mut self.expressions[expr.index()];
367        info.ref_count += 1;
368        // Record usage if this expression may access a global
369        if let Some(global) = info.assignable_global {
370            self.global_uses[global.index()] |= global_use;
371        }
372        info.uniformity.non_uniform_result
373    }
374
375    /// Record a use of `expr` for its value.
376    ///
377    /// This is used for almost all expression references. Anything
378    /// that writes to the value `expr` points to, or otherwise wants
379    /// contribute flags other than `GlobalUse::READ`, should use
380    /// `add_ref_impl` directly.
381    #[must_use]
382    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
383        self.add_ref_impl(expr, GlobalUse::READ)
384    }
385
386    /// Record a use of `expr`, and indicate which global variable it
387    /// refers to, if any.
388    ///
389    /// Bump `expr`'s reference count, and return its uniformity.
390    ///
391    /// If `expr` is a pointer to a global variable, or some part
392    /// thereof, store that global in `*assignable_global`. Leave the
393    /// global's uses unchanged.
394    ///
395    /// This is used to determine the [`assignable_global`] for
396    /// [`Access`] and [`AccessIndex`] expressions that ultimately
397    /// refer to a global variable. Those expressions don't contribute
398    /// any usage to the global themselves; that depends on how other
399    /// expressions use them.
400    ///
401    /// [`assignable_global`]: ExpressionInfo::assignable_global
402    /// [`Access`]: crate::Expression::Access
403    /// [`AccessIndex`]: crate::Expression::AccessIndex
404    #[must_use]
405    fn add_assignable_ref(
406        &mut self,
407        expr: Handle<crate::Expression>,
408        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
409    ) -> NonUniformResult {
410        let info = &mut self.expressions[expr.index()];
411        info.ref_count += 1;
412        // propagate the assignable global up the chain, till it either hits
413        // a value-type expression, or the assignment statement.
414        if let Some(global) = info.assignable_global {
415            if let Some(_old) = assignable_global.replace(global) {
416                unreachable!()
417            }
418        }
419        info.uniformity.non_uniform_result
420    }
421
422    /// Inherit information from a called function.
423    fn process_call(
424        &mut self,
425        callee: &Self,
426        arguments: &[Handle<crate::Expression>],
427        expression_arena: &Arena<crate::Expression>,
428    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
429        self.sampling_set
430            .extend(callee.sampling_set.iter().cloned());
431        for sampling in callee.sampling.iter() {
432            // If the callee was passed the texture or sampler as an argument,
433            // we may now be able to determine which globals those referred to.
434            let image_storage = match sampling.image {
435                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
436                GlobalOrArgument::Argument(i) => {
437                    let Some(handle) = arguments.get(i as usize).cloned() else {
438                        // Argument count mismatch, will be reported later by validate_call
439                        break;
440                    };
441                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
442                        |source| {
443                            FunctionError::Expression { handle, source }
444                                .with_span_handle(handle, expression_arena)
445                        },
446                    )?
447                }
448            };
449
450            let sampler_storage = match sampling.sampler {
451                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
452                GlobalOrArgument::Argument(i) => {
453                    let Some(handle) = arguments.get(i as usize).cloned() else {
454                        // Argument count mismatch, will be reported later by validate_call
455                        break;
456                    };
457                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
458                        |source| {
459                            FunctionError::Expression { handle, source }
460                                .with_span_handle(handle, expression_arena)
461                        },
462                    )?
463                }
464            };
465
466            // If we've managed to pin both the image and sampler down to
467            // specific globals, record that in our `sampling_set`. Otherwise,
468            // record as much as we do know in our own `sampling` table, for our
469            // callers to sort out.
470            match (image_storage, sampler_storage) {
471                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
472                    self.sampling_set.insert(SamplingKey { image, sampler });
473                }
474                (image, sampler) => {
475                    self.sampling.insert(Sampling { image, sampler });
476                }
477            }
478        }
479
480        // Inherit global use from our callees.
481        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
482            *mine |= *other;
483        }
484
485        Ok(FunctionUniformity {
486            result: callee.uniformity.clone(),
487            exit: if callee.may_kill {
488                ExitFlags::MAY_KILL
489            } else {
490                ExitFlags::empty()
491            },
492        })
493    }
494
495    /// Compute the [`ExpressionInfo`] for `handle`.
496    ///
497    /// Replace the dummy entry in [`self.expressions`] for `handle`
498    /// with a real `ExpressionInfo` value describing that expression.
499    ///
500    /// This function is called as part of a forward sweep through the
501    /// arena, so we can assume that all earlier expressions in the
502    /// arena already have valid info. Since expressions only depend
503    /// on earlier expressions, this includes all our subexpressions.
504    ///
505    /// Adjust the reference counts on all expressions we use.
506    ///
507    /// Also populate the [`sampling_set`], [`sampling`] and
508    /// [`global_uses`] fields of `self`.
509    ///
510    /// [`self.expressions`]: FunctionInfo::expressions
511    /// [`sampling_set`]: FunctionInfo::sampling_set
512    /// [`sampling`]: FunctionInfo::sampling
513    /// [`global_uses`]: FunctionInfo::global_uses
514    #[allow(clippy::or_fun_call)]
515    fn process_expression(
516        &mut self,
517        handle: Handle<crate::Expression>,
518        expression_arena: &Arena<crate::Expression>,
519        other_functions: &[FunctionInfo],
520        resolve_context: &ResolveContext,
521        capabilities: super::Capabilities,
522    ) -> Result<(), ExpressionError> {
523        use crate::{Expression as E, SampleLevel as Sl};
524
525        let expression = &expression_arena[handle];
526        let mut assignable_global = None;
527        let uniformity = match *expression {
528            E::Access { base, index } => {
529                let base_ty = self[base].ty.inner_with(resolve_context.types);
530
531                // build up the caps needed if this is indexed non-uniformly
532                let mut needed_caps = super::Capabilities::empty();
533                let is_binding_array = match *base_ty {
534                    crate::TypeInner::BindingArray {
535                        base: array_element_ty_handle,
536                        ..
537                    } => {
538                        // these are nasty aliases, but these idents are too long and break rustfmt
539                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
540                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
541                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
542                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
543
544                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
545                        let array_element_ty =
546                            &resolve_context.types[array_element_ty_handle].inner;
547
548                        needed_caps |= match *array_element_ty {
549                            // If we're an image, use the appropriate limit.
550                            crate::TypeInner::Image { class, .. } => match class {
551                                crate::ImageClass::Storage { .. } => sto,
552                                _ => st_sb,
553                            },
554                            crate::TypeInner::Sampler { .. } => sampler,
555                            // If we're anything but an image, assume we're a buffer and use the address space.
556                            _ => {
557                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
558                                    let global = &resolve_context.global_vars[global_handle];
559                                    match global.space {
560                                        crate::AddressSpace::Uniform => uni,
561                                        crate::AddressSpace::Storage { .. } => st_sb,
562                                        _ => unreachable!(),
563                                    }
564                                } else {
565                                    unreachable!()
566                                }
567                            }
568                        };
569
570                        true
571                    }
572                    _ => false,
573                };
574
575                if self[index].uniformity.non_uniform_result.is_some()
576                    && !capabilities.contains(needed_caps)
577                    && is_binding_array
578                {
579                    return Err(ExpressionError::MissingCapabilities(needed_caps));
580                }
581
582                Uniformity {
583                    non_uniform_result: self
584                        .add_assignable_ref(base, &mut assignable_global)
585                        .or(self.add_ref(index)),
586                    requirements: UniformityRequirements::empty(),
587                }
588            }
589            E::AccessIndex { base, .. } => Uniformity {
590                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
591                requirements: UniformityRequirements::empty(),
592            },
593            // always uniform
594            E::Splat { size: _, value } => Uniformity {
595                non_uniform_result: self.add_ref(value),
596                requirements: UniformityRequirements::empty(),
597            },
598            E::Swizzle { vector, .. } => Uniformity {
599                non_uniform_result: self.add_ref(vector),
600                requirements: UniformityRequirements::empty(),
601            },
602            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
603            E::Compose { ref components, .. } => {
604                let non_uniform_result = components
605                    .iter()
606                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
607                Uniformity {
608                    non_uniform_result,
609                    requirements: UniformityRequirements::empty(),
610                }
611            }
612            // depends on the builtin
613            E::FunctionArgument(index) => {
614                let arg = &resolve_context.arguments[index as usize];
615                let uniform = match arg.binding {
616                    Some(crate::Binding::BuiltIn(
617                        // per-work-group built-ins are uniform
618                        crate::BuiltIn::WorkGroupId
619                        | crate::BuiltIn::WorkGroupSize
620                        | crate::BuiltIn::NumWorkGroups,
621                    )) => true,
622                    _ => false,
623                };
624                Uniformity {
625                    non_uniform_result: if uniform { None } else { Some(handle) },
626                    requirements: UniformityRequirements::empty(),
627                }
628            }
629            // depends on the address space
630            E::GlobalVariable(gh) => {
631                use crate::AddressSpace as As;
632                assignable_global = Some(gh);
633                let var = &resolve_context.global_vars[gh];
634                let uniform = match var.space {
635                    // local data is non-uniform
636                    As::Function | As::Private => false,
637                    // workgroup memory is exclusively accessed by the group
638                    As::WorkGroup => true,
639                    // uniform data
640                    As::Uniform | As::PushConstant => true,
641                    // storage data is only uniform when read-only
642                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
643                    As::Handle => false,
644                };
645                Uniformity {
646                    non_uniform_result: if uniform { None } else { Some(handle) },
647                    requirements: UniformityRequirements::empty(),
648                }
649            }
650            E::LocalVariable(_) => Uniformity {
651                non_uniform_result: Some(handle),
652                requirements: UniformityRequirements::empty(),
653            },
654            E::Load { pointer } => Uniformity {
655                non_uniform_result: self.add_ref(pointer),
656                requirements: UniformityRequirements::empty(),
657            },
658            E::ImageSample {
659                image,
660                sampler,
661                gather: _,
662                coordinate,
663                array_index,
664                offset,
665                level,
666                depth_ref,
667                clamp_to_edge: _,
668            } => {
669                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
670                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
671
672                match (image_storage, sampler_storage) {
673                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
674                        self.sampling_set.insert(SamplingKey { image, sampler });
675                    }
676                    _ => {
677                        self.sampling.insert(Sampling {
678                            image: image_storage,
679                            sampler: sampler_storage,
680                        });
681                    }
682                }
683
684                // "nur" == "Non-Uniform Result"
685                let array_nur = array_index.and_then(|h| self.add_ref(h));
686                let level_nur = match level {
687                    Sl::Auto | Sl::Zero => None,
688                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
689                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
690                };
691                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
692                let offset_nur = offset.and_then(|h| self.add_ref(h));
693                Uniformity {
694                    non_uniform_result: self
695                        .add_ref(image)
696                        .or(self.add_ref(sampler))
697                        .or(self.add_ref(coordinate))
698                        .or(array_nur)
699                        .or(level_nur)
700                        .or(dref_nur)
701                        .or(offset_nur),
702                    requirements: if level.implicit_derivatives() {
703                        UniformityRequirements::IMPLICIT_LEVEL
704                    } else {
705                        UniformityRequirements::empty()
706                    },
707                }
708            }
709            E::ImageLoad {
710                image,
711                coordinate,
712                array_index,
713                sample,
714                level,
715            } => {
716                let array_nur = array_index.and_then(|h| self.add_ref(h));
717                let sample_nur = sample.and_then(|h| self.add_ref(h));
718                let level_nur = level.and_then(|h| self.add_ref(h));
719                Uniformity {
720                    non_uniform_result: self
721                        .add_ref(image)
722                        .or(self.add_ref(coordinate))
723                        .or(array_nur)
724                        .or(sample_nur)
725                        .or(level_nur),
726                    requirements: UniformityRequirements::empty(),
727                }
728            }
729            E::ImageQuery { image, query } => {
730                let query_nur = match query {
731                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
732                    _ => None,
733                };
734                Uniformity {
735                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
736                    requirements: UniformityRequirements::empty(),
737                }
738            }
739            E::Unary { expr, .. } => Uniformity {
740                non_uniform_result: self.add_ref(expr),
741                requirements: UniformityRequirements::empty(),
742            },
743            E::Binary { left, right, .. } => Uniformity {
744                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
745                requirements: UniformityRequirements::empty(),
746            },
747            E::Select {
748                condition,
749                accept,
750                reject,
751            } => Uniformity {
752                non_uniform_result: self
753                    .add_ref(condition)
754                    .or(self.add_ref(accept))
755                    .or(self.add_ref(reject)),
756                requirements: UniformityRequirements::empty(),
757            },
758            // explicit derivatives require uniform
759            E::Derivative { expr, .. } => Uniformity {
760                //Note: taking a derivative of a uniform doesn't make it non-uniform
761                non_uniform_result: self.add_ref(expr),
762                requirements: UniformityRequirements::DERIVATIVE,
763            },
764            E::Relational { argument, .. } => Uniformity {
765                non_uniform_result: self.add_ref(argument),
766                requirements: UniformityRequirements::empty(),
767            },
768            E::Math {
769                fun: _,
770                arg,
771                arg1,
772                arg2,
773                arg3,
774            } => {
775                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
776                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
777                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
778                Uniformity {
779                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
780                    requirements: UniformityRequirements::empty(),
781                }
782            }
783            E::As { expr, .. } => Uniformity {
784                non_uniform_result: self.add_ref(expr),
785                requirements: UniformityRequirements::empty(),
786            },
787            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
788            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
789                non_uniform_result: Some(handle),
790                requirements: UniformityRequirements::empty(),
791            },
792            E::WorkGroupUniformLoadResult { .. } => Uniformity {
793                // The result of WorkGroupUniformLoad is always uniform by definition
794                non_uniform_result: None,
795                // The call is what cares about uniformity, not the expression
796                // This expression is never emitted, so this requirement should never be used anyway?
797                requirements: UniformityRequirements::empty(),
798            },
799            E::ArrayLength(expr) => Uniformity {
800                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
801                requirements: UniformityRequirements::empty(),
802            },
803            E::RayQueryGetIntersection {
804                query,
805                committed: _,
806            } => Uniformity {
807                non_uniform_result: self.add_ref(query),
808                requirements: UniformityRequirements::empty(),
809            },
810            E::SubgroupBallotResult => Uniformity {
811                non_uniform_result: Some(handle),
812                requirements: UniformityRequirements::empty(),
813            },
814            E::SubgroupOperationResult { .. } => Uniformity {
815                non_uniform_result: Some(handle),
816                requirements: UniformityRequirements::empty(),
817            },
818            E::RayQueryVertexPositions {
819                query,
820                committed: _,
821            } => Uniformity {
822                non_uniform_result: self.add_ref(query),
823                requirements: UniformityRequirements::empty(),
824            },
825        };
826
827        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
828        self.expressions[handle.index()] = ExpressionInfo {
829            uniformity,
830            ref_count: 0,
831            assignable_global,
832            ty,
833        };
834        Ok(())
835    }
836
    /// Analyzes the uniformity requirements of a block, statement by statement.
    ///
    /// Returns the uniformity characteristics at the *function* level:
    /// - the requirements (e.g. `WORK_GROUP_BARRIER`) that force the function
    ///   to be called in uniform control flow;
    /// - the non-uniformity of any value returned from the block;
    /// - exit flags (`MAY_RETURN`, `MAY_KILL`) recording whether the block may
    ///   leave early while the control flow is non-uniform.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`. While
    /// walking the block, `disruptor` is extended after each statement with
    /// that statement's `exit_disruptor()`. Presumably this is non-`None` when
    /// the statement may `return`/`kill` in non-uniform control flow, so the
    /// statements that follow are analyzed as non-uniform.
    ///
    /// Side effects: bumps the reference counts and global usage flags of the
    /// expressions used by the statements, via `add_ref`/`add_ref_impl`.
    ///
    /// # Errors
    ///
    /// When `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set, an emitted
    /// expression with non-empty uniformity requirements inside non-uniform
    /// control flow produces `FunctionError::NonUniformControlFlow`.
    /// Depending on the severity of the `derivative_uniformity` diagnostic
    /// filter, that diagnostic is either returned as an error or only passed
    /// to the logging callback.
    ///
    /// Errors from nested blocks and from `process_call` are propagated as-is.
    /// On error, reference counts updated before the failure are not rolled
    /// back.
    // NOTE(review): the allow is presumably needed because the `.or(...)` calls
    // below take eagerly computed arguments (e.g. `case_uniformity.exit_disruptor()`),
    // which clippy's `or_fun_call` would suggest rewriting with `or_else`.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Per-expression requirements were stored by `process_expression`.
                    // They are checked here because only now is the control flow
                    // context of the emission known.
                    // Note: while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is true,
                    // `DERIVATIVE` and `IMPLICIT_LEVEL` are empty flags. Derivative and
                    // implicit-level sampling expressions then carry no requirement and
                    // never trigger this check.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                // The diagnostic filter chain decides how severe the
                                // violation is; `report_diag` may return it as an error
                                // (propagated by `?`) or hand it to the logging closure.
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // TODO: Yes, this isn't contextualized with source, because
                                    // the user is supposed to render what would normally be an
                                    // error here. Once we actually support warning-level
                                    // diagnostic items, then we won't need this non-compliant hack:
                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        // Requirements are accumulated whether or not a diagnostic
                        // was reported above.
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // `break`/`continue` contribute no requirements and no exit flags here.
                S::Break | S::Continue => FunctionUniformity::new(),
                // A `kill` is only recorded as an early exit when the current control
                // flow is already non-uniform.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Barriers add the `WORK_GROUP_BARRIER` requirement to the result.
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // Only counts the reference; the pointer's non-uniformity is unused
                    // because the check below is disabled.
                    let _condition_nur = self.add_ref(pointer);

                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.

                    /*
                    if self
                        .flags
                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                    {
                        let condition_nur = self.add_ref(pointer);
                        let this_disruptor =
                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                        if let Some(cause) = this_disruptor {
                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
                                .with_span_static(*span, "WorkGroupUniformLoad"));
                        }
                    } */
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // A nested block inherits the current control flow unchanged.
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // A non-uniform condition makes both branches non-uniform, unless
                    // the flow was already disrupted (the existing disruptor wins).
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // When this case falls through, the next case is also reached
                        // after this one, so it inherits this case's exit disruptor.
                        // Otherwise the next case starts from the plain branch disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block runs after the body, so early exits in the
                    // body can make it non-uniform.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The `break_if` condition is only reference-counted; its
                    // non-uniformity is not used as a disruptor here.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // A returned non-uniform value makes the function result non-uniform.
                // The return is recorded as an early exit only when the control flow is
                // already non-uniform.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    // The callee's uniformity comes from its already computed info.
                    let info = &other_functions[function.index()];
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // Atomics both read and write the pointee.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    // Count the operands referenced by each ray query operation.
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    // Only the gather modes that carry an index expression reference one.
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
            };

            // Later statements in this block run after this one, so any early exit
            // it may take (see `exit_disruptor`) disrupts their control flow.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1161}
1162
1163impl ModuleInfo {
1164    /// Populates `self.const_expression_types`
1165    pub(super) fn process_const_expression(
1166        &mut self,
1167        handle: Handle<crate::Expression>,
1168        resolve_context: &ResolveContext,
1169        gctx: crate::proc::GlobalCtx,
1170    ) -> Result<(), super::ConstExpressionError> {
1171        self.const_expression_types[handle.index()] =
1172            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1173        Ok(())
1174    }
1175
1176    /// Builds the `FunctionInfo` based on the function, and validates the
1177    /// uniform control flow if required by the expressions of this function.
1178    pub(super) fn process_function(
1179        &self,
1180        fun: &crate::Function,
1181        module: &crate::Module,
1182        flags: ValidationFlags,
1183        capabilities: super::Capabilities,
1184    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1185        let mut info = FunctionInfo {
1186            flags,
1187            available_stages: ShaderStages::all(),
1188            uniformity: Uniformity::new(),
1189            may_kill: false,
1190            sampling_set: crate::FastHashSet::default(),
1191            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1192            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1193            sampling: crate::FastHashSet::default(),
1194            dual_source_blending: false,
1195            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1196        };
1197        let resolve_context =
1198            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1199
1200        for (handle, _) in fun.expressions.iter() {
1201            if let Err(source) = info.process_expression(
1202                handle,
1203                &fun.expressions,
1204                &self.functions,
1205                &resolve_context,
1206                capabilities,
1207            ) {
1208                return Err(FunctionError::Expression { handle, source }
1209                    .with_span_handle(handle, &fun.expressions));
1210            }
1211        }
1212
1213        for (_, expr) in fun.local_variables.iter() {
1214            if let Some(init) = expr.init {
1215                let _ = info.add_ref(init);
1216            }
1217        }
1218
1219        let uniformity = info.process_block(
1220            &fun.body,
1221            &self.functions,
1222            None,
1223            &fun.expressions,
1224            &module.diagnostic_filters,
1225        )?;
1226        info.uniformity = uniformity.result;
1227        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1228
1229        // If there are any globals referenced directly by a named expression,
1230        // ensure they are marked as used even if they are not referenced
1231        // anywhere else. An important case where this matters is phony
1232        // assignments used to include a global in the shader's resource
1233        // interface. https://www.w3.org/TR/WGSL/#phony-assignment-section
1234        for &handle in fun.named_expressions.keys() {
1235            if let Some(global) = info[handle].assignable_global {
1236                if info.global_uses[global.index()].is_empty() {
1237                    info.global_uses[global.index()] = GlobalUse::QUERY;
1238                }
1239            }
1240        }
1241
1242        Ok(info)
1243    }
1244
1245    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1246        &self.entry_points[index]
1247    }
1248}
1249
1250#[test]
1251fn uniform_control_flow() {
1252    use crate::{Expression as E, Statement as S};
1253
1254    let mut type_arena = crate::UniqueArena::new();
1255    let ty = type_arena.insert(
1256        crate::Type {
1257            name: None,
1258            inner: crate::TypeInner::Vector {
1259                size: crate::VectorSize::Bi,
1260                scalar: crate::Scalar::F32,
1261            },
1262        },
1263        Default::default(),
1264    );
1265    let mut global_var_arena = Arena::new();
1266    let non_uniform_global = global_var_arena.append(
1267        crate::GlobalVariable {
1268            name: None,
1269            init: None,
1270            ty,
1271            space: crate::AddressSpace::Handle,
1272            binding: None,
1273        },
1274        Default::default(),
1275    );
1276    let uniform_global = global_var_arena.append(
1277        crate::GlobalVariable {
1278            name: None,
1279            init: None,
1280            ty,
1281            binding: None,
1282            space: crate::AddressSpace::Uniform,
1283        },
1284        Default::default(),
1285    );
1286
1287    let mut expressions = Arena::new();
1288    // checks the uniform control flow
1289    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1290    // checks the non-uniform control flow
1291    let derivative_expr = expressions.append(
1292        E::Derivative {
1293            axis: crate::DerivativeAxis::X,
1294            ctrl: crate::DerivativeControl::None,
1295            expr: constant_expr,
1296        },
1297        Default::default(),
1298    );
1299    let emit_range_constant_derivative = expressions.range_from(0);
1300    let non_uniform_global_expr =
1301        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1302    let uniform_global_expr =
1303        expressions.append(E::GlobalVariable(uniform_global), Default::default());
1304    let emit_range_globals = expressions.range_from(2);
1305
1306    // checks the QUERY flag
1307    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1308    // checks the transitive WRITE flag
1309    let access_expr = expressions.append(
1310        E::AccessIndex {
1311            base: non_uniform_global_expr,
1312            index: 1,
1313        },
1314        Default::default(),
1315    );
1316    let emit_range_query_access_globals = expressions.range_from(2);
1317
1318    let mut info = FunctionInfo {
1319        flags: ValidationFlags::all(),
1320        available_stages: ShaderStages::all(),
1321        uniformity: Uniformity::new(),
1322        may_kill: false,
1323        sampling_set: crate::FastHashSet::default(),
1324        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1325        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1326        sampling: crate::FastHashSet::default(),
1327        dual_source_blending: false,
1328        diagnostic_filter_leaf: None,
1329    };
1330    let resolve_context = ResolveContext {
1331        constants: &Arena::new(),
1332        overrides: &Arena::new(),
1333        types: &type_arena,
1334        special_types: &crate::SpecialTypes::default(),
1335        global_vars: &global_var_arena,
1336        local_vars: &Arena::new(),
1337        functions: &Arena::new(),
1338        arguments: &[],
1339    };
1340    for (handle, _) in expressions.iter() {
1341        info.process_expression(
1342            handle,
1343            &expressions,
1344            &[],
1345            &resolve_context,
1346            super::Capabilities::empty(),
1347        )
1348        .unwrap();
1349    }
1350    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1351    assert_eq!(info[uniform_global_expr].ref_count, 1);
1352    assert_eq!(info[query_expr].ref_count, 0);
1353    assert_eq!(info[access_expr].ref_count, 0);
1354    assert_eq!(info[non_uniform_global], GlobalUse::empty());
1355    assert_eq!(info[uniform_global], GlobalUse::QUERY);
1356
1357    let stmt_emit1 = S::Emit(emit_range_globals.clone());
1358    let stmt_if_uniform = S::If {
1359        condition: uniform_global_expr,
1360        accept: crate::Block::new(),
1361        reject: vec![
1362            S::Emit(emit_range_constant_derivative.clone()),
1363            S::Store {
1364                pointer: constant_expr,
1365                value: derivative_expr,
1366            },
1367        ]
1368        .into(),
1369    };
1370    assert_eq!(
1371        info.process_block(
1372            &vec![stmt_emit1, stmt_if_uniform].into(),
1373            &[],
1374            None,
1375            &expressions,
1376            &Arena::new(),
1377        ),
1378        Ok(FunctionUniformity {
1379            result: Uniformity {
1380                non_uniform_result: None,
1381                requirements: UniformityRequirements::DERIVATIVE,
1382            },
1383            exit: ExitFlags::empty(),
1384        }),
1385    );
1386    assert_eq!(info[constant_expr].ref_count, 2);
1387    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1388
1389    let stmt_emit2 = S::Emit(emit_range_globals.clone());
1390    let stmt_if_non_uniform = S::If {
1391        condition: non_uniform_global_expr,
1392        accept: vec![
1393            S::Emit(emit_range_constant_derivative),
1394            S::Store {
1395                pointer: constant_expr,
1396                value: derivative_expr,
1397            },
1398        ]
1399        .into(),
1400        reject: crate::Block::new(),
1401    };
1402    {
1403        let block_info = info.process_block(
1404            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1405            &[],
1406            None,
1407            &expressions,
1408            &Arena::new(),
1409        );
1410        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1411            assert_eq!(info[derivative_expr].ref_count, 2);
1412        } else {
1413            assert_eq!(
1414                block_info,
1415                Err(FunctionError::NonUniformControlFlow(
1416                    UniformityRequirements::DERIVATIVE,
1417                    derivative_expr,
1418                    UniformityDisruptor::Expression(non_uniform_global_expr)
1419                )
1420                .with_span()),
1421            );
1422            assert_eq!(info[derivative_expr].ref_count, 1);
1423
1424            // Test that the same thing passes when we disable the `derivative_uniformity`
1425            let mut diagnostic_filters = Arena::new();
1426            let diagnostic_filter_leaf = diagnostic_filters.append(
1427                DiagnosticFilterNode {
1428                    inner: crate::diagnostic_filter::DiagnosticFilter {
1429                        new_severity: crate::diagnostic_filter::Severity::Off,
1430                        triggering_rule:
1431                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1432                                StandardFilterableTriggeringRule::DerivativeUniformity,
1433                            ),
1434                    },
1435                    parent: None,
1436                },
1437                crate::Span::default(),
1438            );
1439            let mut info = FunctionInfo {
1440                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1441                ..info.clone()
1442            };
1443
1444            let block_info = info.process_block(
1445                &vec![stmt_emit2, stmt_if_non_uniform].into(),
1446                &[],
1447                None,
1448                &expressions,
1449                &diagnostic_filters,
1450            );
1451            assert_eq!(
1452                block_info,
1453                Ok(FunctionUniformity {
1454                    result: Uniformity {
1455                        non_uniform_result: None,
1456                        requirements: UniformityRequirements::DERIVATIVE,
1457                    },
1458                    exit: ExitFlags::empty()
1459                }),
1460            );
1461            assert_eq!(info[derivative_expr].ref_count, 2);
1462        }
1463    }
1464    assert_eq!(info[non_uniform_global], GlobalUse::READ);
1465
1466    let stmt_emit3 = S::Emit(emit_range_globals);
1467    let stmt_return_non_uniform = S::Return {
1468        value: Some(non_uniform_global_expr),
1469    };
1470    assert_eq!(
1471        info.process_block(
1472            &vec![stmt_emit3, stmt_return_non_uniform].into(),
1473            &[],
1474            Some(UniformityDisruptor::Return),
1475            &expressions,
1476            &Arena::new(),
1477        ),
1478        Ok(FunctionUniformity {
1479            result: Uniformity {
1480                non_uniform_result: Some(non_uniform_global_expr),
1481                requirements: UniformityRequirements::empty(),
1482            },
1483            exit: ExitFlags::MAY_RETURN,
1484        }),
1485    );
1486    assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1487
1488    // Check that uniformity requirements reach through a pointer
1489    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1490    let stmt_assign = S::Store {
1491        pointer: access_expr,
1492        value: query_expr,
1493    };
1494    let stmt_return_pointer = S::Return {
1495        value: Some(access_expr),
1496    };
1497    let stmt_kill = S::Kill;
1498    assert_eq!(
1499        info.process_block(
1500            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1501            &[],
1502            Some(UniformityDisruptor::Discard),
1503            &expressions,
1504            &Arena::new(),
1505        ),
1506        Ok(FunctionUniformity {
1507            result: Uniformity {
1508                non_uniform_result: Some(non_uniform_global_expr),
1509                requirements: UniformityRequirements::empty(),
1510            },
1511            exit: ExitFlags::all(),
1512        }),
1513    );
1514    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1515}