naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{
12    ExpressionError, FunctionError, ImmediateSlots, ModuleInfo, ShaderStages, ValidationFlags,
13};
14use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
15use crate::span::{AddSpan as _, WithSpan};
16use crate::{
17    arena::{Arena, Handle},
18    proc::{ResolveContext, TypeResolution},
19};
20
/// Outcome of uniformity analysis for a single value.
///
/// `None` means no source of non-uniformity has been recorded. `Some(handle)`
/// names the expression to which the non-uniformity is attributed.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// While this is `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` bits of
/// [`UniformityRequirements`] are defined as zero, i.e. as empty flags.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
24
bitflags::bitflags! {
    /// Kinds of expressions that require uniform control flow.
    ///
    /// `DERIVATIVE` and `IMPLICIT_LEVEL` are defined as zero (empty) flags
    /// while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`; in that
    /// configuration, setting them records nothing.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        const WORK_GROUP_BARRIER = 0x1;
        // Zero-valued while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        // Zero-valued while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
        const COOP_OPS = 0x8;
    }
}
37
/// Uniform control flow characteristics.
///
/// Pairs the (optional) origin of a non-uniform value with the set of
/// reasons for which uniform control flow is required.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with a non-uniform result, if there is one.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such a value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operations that depend on non-uniform results also produce
    /// non-uniform results.
    pub non_uniform_result: NonUniformResult,
    /// If this expression requires uniform control flow, store the reason here.
    ///
    /// Empty when there is no such requirement.
    pub requirements: UniformityRequirements,
}
59
60impl Uniformity {
61    const fn new() -> Self {
62        Uniformity {
63            non_uniform_result: None,
64            requirements: UniformityRequirements::empty(),
65        }
66    }
67}
68
bitflags::bitflags! {
    /// Ways in which control flow may leave early, making what follows
    /// non-uniform.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Control flow may return from the function, which makes all the
        /// subsequent statements within the current function (only!)
        /// to be executed in a non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// Control flow may be killed. Anything after [`Statement::Kill`] is
        /// considered inside non-uniform context.
        ///
        /// [`Statement::Kill`]: crate::Statement::Kill
        const MAY_KILL = 0x2;
    }
}
83
/// Uniformity characteristics of a function.
///
/// Values of this type can be merged with `|` (see the `BitOr`
/// implementation).
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniform result origin and uniformity requirements.
    result: Uniformity,
    /// Ways in which control flow may exit early (return or kill).
    exit: ExitFlags,
}
90
91impl ops::BitOr for FunctionUniformity {
92    type Output = Self;
93    fn bitor(self, other: Self) -> Self {
94        FunctionUniformity {
95            result: Uniformity {
96                non_uniform_result: self
97                    .result
98                    .non_uniform_result
99                    .or(other.result.non_uniform_result),
100                requirements: self.result.requirements | other.result.requirements,
101            },
102            exit: self.exit | other.exit,
103        }
104    }
105}
106
107impl FunctionUniformity {
108    const fn new() -> Self {
109        FunctionUniformity {
110            result: Uniformity::new(),
111            exit: ExitFlags::empty(),
112        }
113    }
114
115    /// Returns a disruptor based on the stored exit flags, if any.
116    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
117        if self.exit.contains(ExitFlags::MAY_RETURN) {
118            Some(UniformityDisruptor::Return)
119        } else if self.exit.contains(ExitFlags::MAY_KILL) {
120            Some(UniformityDisruptor::Discard)
121        } else {
122            None
123        }
124    }
125}
126
bitflags::bitflags! {
    /// Indicates how a global variable is used.
    ///
    /// Uses are accumulated with `|=`, so several flags may be set for the
    /// same variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Data will be read from the variable.
        const READ = 0x1;
        /// Data will be written to the variable.
        const WRITE = 0x2;
        /// The information about the data is queried.
        const QUERY = 0x4;
        /// Atomic operations will be performed on the variable.
        const ATOMIC = 0x8;
    }
}
143
/// A pair of global variables, one image and one sampler, that may be used
/// together in a sampling operation.
///
/// These are the elements of [`FunctionInfo::sampling_set`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// The global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
151
/// Information about an expression in a function body.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Whether this expression is uniform, and why.
    ///
    /// If this expression's value is not uniform, the `non_uniform_result`
    /// field holds the handle of the expression from which this one's
    /// non-uniformity originates. Otherwise, that field is `None`.
    pub uniformity: Uniformity,

    /// The number of direct references to this expression in statements and
    /// other expressions.
    ///
    /// This is a _local_ reference count only, it may be non-zero for
    /// expressions that are ultimately unused.
    pub ref_count: usize,

    /// The global variable into which this expression produces a pointer.
    ///
    /// This is `None` unless this expression is either a
    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
    /// ultimately refers to some part of a global.
    ///
    /// [`Load`] expressions applied to pointer-typed arguments could
    /// refer to globals, but we leave this as `None` for them.
    ///
    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    /// [`Load`]: crate::Expression::Load
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The type of this expression.
    pub ty: TypeResolution,
}
189
190impl ExpressionInfo {
191    const fn new() -> Self {
192        ExpressionInfo {
193            uniformity: Uniformity::new(),
194            ref_count: 0,
195            assignable_global: None,
196            // this doesn't matter at this point, will be overwritten
197            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
198                kind: crate::ScalarKind::Bool,
199                width: 0,
200            })),
201        }
202    }
203}
204
/// Identifies a participant in a sampling operation: either a specific global
/// variable, or one of the enclosing function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A known global variable.
    Global(Handle<crate::GlobalVariable>),
    /// A function argument, identified by its index.
    Argument(u32),
}
212
213impl GlobalOrArgument {
214    fn from_expression(
215        expression_arena: &Arena<crate::Expression>,
216        expression: Handle<crate::Expression>,
217    ) -> Result<GlobalOrArgument, ExpressionError> {
218        Ok(match expression_arena[expression] {
219            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
220            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
221            crate::Expression::Access { base, .. }
222            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
223                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
224                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
225            },
226            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
227        })
228    }
229}
230
/// An (image, sampler) pairing in which each participant is either a global
/// variable or a function argument.
///
/// These are the elements of `FunctionInfo::sampling`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Where the image comes from.
    image: GlobalOrArgument,
    /// Where the sampler comes from.
    sampler: GlobalOrArgument,
}
238
/// Analysis results for a single function: its uniformity, its use of
/// global variables, per-expression information, and the texture/sampler
/// pairings it may sample with.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    pub may_kill: bool,

    /// All pairs of (texture, sampler) globals that may be used together in
    /// sampling operations by this function and its callees. This includes
    /// pairings that arise when this function passes textures and samplers as
    /// arguments to its callees.
    ///
    /// This table does not include uses of textures and samplers passed as
    /// arguments to this function itself, since we do not know which globals
    /// those will be. However, this table *is* exhaustive when computed for an
    /// entry point function: entry points never receive textures or samplers as
    /// arguments, so all of an entry point's sampling can be reported in terms
    /// of globals.
    ///
    /// The GLSL back end uses this table to construct reflection info that
    /// clients need to construct texture-combined sampler values.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function and its callees use this module's globals.
    ///
    /// This is indexed by `Handle<GlobalVariable>` indices. However,
    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
    /// so you can simply index this struct with a global handle to retrieve
    /// its usage information.
    pub global_uses: Box<[GlobalUse]>,

    /// Information about each expression in this function's body.
    ///
    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
    /// index this struct with an expression handle to retrieve its
    /// `ExpressionInfo`.
    expressions: Box<[ExpressionInfo]>,

    /// All (texture, sampler) pairs that may be used together in sampling
    /// operations by this function and its callees, whether they are accessed
    /// as globals or passed as arguments.
    ///
    /// Participants are represented by [`GlobalVariable`] handles whenever
    /// possible, and otherwise by indices of this function's arguments.
    ///
    /// When analyzing a function call, we combine this data about the callee
    /// with the actual arguments being passed to produce the caller's own
    /// `sampling_set` and `sampling` tables.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    sampling: crate::FastHashSet<Sampling>,

    /// Indicates that the function is using dual source blending.
    pub dual_source_blending: bool,

    /// The leaf of all module-wide diagnostic filter rules tree parsed from directives in this
    /// module.
    ///
    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
    /// validation.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,

    /// A bitmask, tracking which 4-byte slots this function (possibly transitively) reads.
    /// Used to determine the minimum set of slots that must be written via `set_immediates`.
    pub immediate_slots_used: ImmediateSlots,
}
312
313impl FunctionInfo {
314    pub const fn global_variable_count(&self) -> usize {
315        self.global_uses.len()
316    }
317    pub const fn expression_count(&self) -> usize {
318        self.expressions.len()
319    }
320    pub fn dominates_global_use(&self, other: &Self) -> bool {
321        for (self_global_uses, other_global_uses) in
322            self.global_uses.iter().zip(other.global_uses.iter())
323        {
324            if !self_global_uses.contains(*other_global_uses) {
325                return false;
326            }
327        }
328        true
329    }
330}
331
332impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
333    type Output = GlobalUse;
334    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
335        &self.global_uses[handle.index()]
336    }
337}
338
339impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
340    type Output = ExpressionInfo;
341    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
342        &self.expressions[handle.index()]
343    }
344}
345
/// Disruptor of the uniform control flow.
///
/// Names the reason why control flow is no longer considered uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on the non-uniform result of the given expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A return may have happened earlier in the function.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A discard (kill) may have happened earlier.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
357
358impl FunctionInfo {
359    /// Record a use of `expr` of the sort given by `global_use`.
360    ///
361    /// Bump `expr`'s reference count, and return its uniformity.
362    ///
363    /// If `expr` is a pointer to a global variable, or some part of
364    /// a global variable, add `global_use` to that global's set of
365    /// uses.
366    #[must_use]
367    fn add_ref_impl(
368        &mut self,
369        expr: Handle<crate::Expression>,
370        global_use: GlobalUse,
371    ) -> NonUniformResult {
372        let info = &mut self.expressions[expr.index()];
373        info.ref_count += 1;
374        // Record usage if this expression may access a global
375        if let Some(global) = info.assignable_global {
376            self.global_uses[global.index()] |= global_use;
377        }
378        info.uniformity.non_uniform_result
379    }
380
381    /// Note an entry point's use of `global` not recorded by [`ModuleInfo::process_function`].
382    ///
383    /// Most global variable usage should be recorded via [`add_ref_impl`] in the process
384    /// of expression behavior analysis by [`ModuleInfo::process_function`]. But that code
385    /// has no access to entrypoint-specific information, so interface analysis uses this
386    /// function to record global uses there (like task shader payloads).
387    ///
388    /// [`add_ref_impl`]: Self::add_ref_impl
389    pub(super) fn insert_global_use(
390        &mut self,
391        global_use: GlobalUse,
392        global: Handle<crate::GlobalVariable>,
393    ) {
394        self.global_uses[global.index()] |= global_use;
395    }
396
397    /// Record a use of `expr` for its value.
398    ///
399    /// This is used for almost all expression references. Anything
400    /// that writes to the value `expr` points to, or otherwise wants
401    /// contribute flags other than `GlobalUse::READ`, should use
402    /// `add_ref_impl` directly.
403    #[must_use]
404    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
405        self.add_ref_impl(expr, GlobalUse::READ)
406    }
407
408    /// Record a use of `expr`, and indicate which global variable it
409    /// refers to, if any.
410    ///
411    /// Bump `expr`'s reference count, and return its uniformity.
412    ///
413    /// If `expr` is a pointer to a global variable, or some part
414    /// thereof, store that global in `*assignable_global`. Leave the
415    /// global's uses unchanged.
416    ///
417    /// This is used to determine the [`assignable_global`] for
418    /// [`Access`] and [`AccessIndex`] expressions that ultimately
419    /// refer to a global variable. Those expressions don't contribute
420    /// any usage to the global themselves; that depends on how other
421    /// expressions use them.
422    ///
423    /// [`assignable_global`]: ExpressionInfo::assignable_global
424    /// [`Access`]: crate::Expression::Access
425    /// [`AccessIndex`]: crate::Expression::AccessIndex
426    #[must_use]
427    fn add_assignable_ref(
428        &mut self,
429        expr: Handle<crate::Expression>,
430        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
431    ) -> NonUniformResult {
432        let info = &mut self.expressions[expr.index()];
433        info.ref_count += 1;
434        // propagate the assignable global up the chain, till it either hits
435        // a value-type expression, or the assignment statement.
436        if let Some(global) = info.assignable_global {
437            if let Some(_old) = assignable_global.replace(global) {
438                unreachable!()
439            }
440        }
441        info.uniformity.non_uniform_result
442    }
443
444    /// Inherit information from a called function.
445    fn process_call(
446        &mut self,
447        callee: &Self,
448        arguments: &[Handle<crate::Expression>],
449        expression_arena: &Arena<crate::Expression>,
450    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
451        self.sampling_set
452            .extend(callee.sampling_set.iter().cloned());
453        for sampling in callee.sampling.iter() {
454            // If the callee was passed the texture or sampler as an argument,
455            // we may now be able to determine which globals those referred to.
456            let image_storage = match sampling.image {
457                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
458                GlobalOrArgument::Argument(i) => {
459                    let Some(handle) = arguments.get(i as usize).cloned() else {
460                        // Argument count mismatch, will be reported later by validate_call
461                        break;
462                    };
463                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
464                        |source| {
465                            FunctionError::Expression { handle, source }
466                                .with_span_handle(handle, expression_arena)
467                        },
468                    )?
469                }
470            };
471
472            let sampler_storage = match sampling.sampler {
473                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
474                GlobalOrArgument::Argument(i) => {
475                    let Some(handle) = arguments.get(i as usize).cloned() else {
476                        // Argument count mismatch, will be reported later by validate_call
477                        break;
478                    };
479                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
480                        |source| {
481                            FunctionError::Expression { handle, source }
482                                .with_span_handle(handle, expression_arena)
483                        },
484                    )?
485                }
486            };
487
488            // If we've managed to pin both the image and sampler down to
489            // specific globals, record that in our `sampling_set`. Otherwise,
490            // record as much as we do know in our own `sampling` table, for our
491            // callers to sort out.
492            match (image_storage, sampler_storage) {
493                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
494                    self.sampling_set.insert(SamplingKey { image, sampler });
495                }
496                (image, sampler) => {
497                    self.sampling.insert(Sampling { image, sampler });
498                }
499            }
500        }
501
502        // Inherit global use and immediate slot tracking from our callees.
503        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
504            *mine |= *other;
505        }
506        self.immediate_slots_used |= callee.immediate_slots_used;
507
508        Ok(FunctionUniformity {
509            result: callee.uniformity.clone(),
510            exit: if callee.may_kill {
511                ExitFlags::MAY_KILL
512            } else {
513                ExitFlags::empty()
514            },
515        })
516    }
517
518    /// Compute the [`ExpressionInfo`] for `handle`.
519    ///
520    /// Replace the dummy entry in [`self.expressions`] for `handle`
521    /// with a real `ExpressionInfo` value describing that expression.
522    ///
523    /// This function is called as part of a forward sweep through the
524    /// arena, so we can assume that all earlier expressions in the
525    /// arena already have valid info. Since expressions only depend
526    /// on earlier expressions, this includes all our subexpressions.
527    ///
528    /// Adjust the reference counts on all expressions we use.
529    ///
530    /// Also populate the [`sampling_set`], [`sampling`] and
531    /// [`global_uses`] fields of `self`.
532    ///
533    /// [`self.expressions`]: FunctionInfo::expressions
534    /// [`sampling_set`]: FunctionInfo::sampling_set
535    /// [`sampling`]: FunctionInfo::sampling
536    /// [`global_uses`]: FunctionInfo::global_uses
537    #[allow(clippy::or_fun_call)]
538    fn process_expression(
539        &mut self,
540        handle: Handle<crate::Expression>,
541        expression_arena: &Arena<crate::Expression>,
542        other_functions: &[FunctionInfo],
543        resolve_context: &ResolveContext,
544        capabilities: super::Capabilities,
545    ) -> Result<(), ExpressionError> {
546        use crate::{Expression as E, SampleLevel as Sl};
547
548        let expression = &expression_arena[handle];
549        let mut assignable_global = None;
550        let uniformity = match *expression {
551            E::Access { base, index } => {
552                let base_ty = self[base].ty.inner_with(resolve_context.types);
553
554                // build up the caps needed if this is indexed non-uniformly
555                let mut needed_caps = super::Capabilities::empty();
556                let is_binding_array = match *base_ty {
557                    crate::TypeInner::BindingArray {
558                        base: array_element_ty_handle,
559                        ..
560                    } => {
561                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
562                        let array_element_ty =
563                            &resolve_context.types[array_element_ty_handle].inner;
564
565                        needed_caps |= match *array_element_ty {
566                            // If we're an image, use the appropriate capability.
567                            crate::TypeInner::Image { class, .. } => match class {
568                                crate::ImageClass::Storage { .. } => {
569                                    super::Capabilities::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
570                                }
571                                _ => {
572                                    super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
573                                }
574                            },
575                            crate::TypeInner::Sampler { .. } => {
576                                super::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
577                            }
578                            // If we're anything but an image or sampler, assume we're a buffer and use the address space.
579                            _ => {
580                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
581                                    let global = &resolve_context.global_vars[global_handle];
582                                    match global.space {
583                                        crate::AddressSpace::Uniform => {
584                                            super::Capabilities::BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
585                                        }
586                                        crate::AddressSpace::Storage { .. } => {
587                                            super::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
588                                        }
589                                        _ => unreachable!(),
590                                    }
591                                } else {
592                                    unreachable!()
593                                }
594                            }
595                        };
596
597                        true
598                    }
599                    _ => false,
600                };
601
602                if self[index].uniformity.non_uniform_result.is_some()
603                    && !capabilities.contains(needed_caps)
604                    && is_binding_array
605                {
606                    return Err(ExpressionError::MissingCapabilities(needed_caps));
607                }
608
609                Uniformity {
610                    non_uniform_result: self
611                        .add_assignable_ref(base, &mut assignable_global)
612                        .or(self.add_ref(index)),
613                    requirements: UniformityRequirements::empty(),
614                }
615            }
616            E::AccessIndex { base, .. } => Uniformity {
617                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
618                requirements: UniformityRequirements::empty(),
619            },
620            // always uniform
621            E::Splat { size: _, value } => Uniformity {
622                non_uniform_result: self.add_ref(value),
623                requirements: UniformityRequirements::empty(),
624            },
625            E::Swizzle { vector, .. } => Uniformity {
626                non_uniform_result: self.add_ref(vector),
627                requirements: UniformityRequirements::empty(),
628            },
629            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
630            E::Compose { ref components, .. } => {
631                let non_uniform_result = components
632                    .iter()
633                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
634                Uniformity {
635                    non_uniform_result,
636                    requirements: UniformityRequirements::empty(),
637                }
638            }
639            // depends on the builtin
640            E::FunctionArgument(index) => {
641                let arg = &resolve_context.arguments[index as usize];
642                let uniform = match arg.binding {
643                    Some(crate::Binding::BuiltIn(
644                        // per-work-group built-ins are uniform
645                        crate::BuiltIn::WorkGroupId
646                        | crate::BuiltIn::WorkGroupSize
647                        | crate::BuiltIn::NumWorkGroups,
648                    )) => true,
649                    _ => false,
650                };
651                Uniformity {
652                    non_uniform_result: if uniform { None } else { Some(handle) },
653                    requirements: UniformityRequirements::empty(),
654                }
655            }
656            // depends on the address space
657            E::GlobalVariable(gh) => {
658                use crate::AddressSpace as As;
659                assignable_global = Some(gh);
660                let var = &resolve_context.global_vars[gh];
661                let uniform = match var.space {
662                    // local data is non-uniform
663                    As::Function | As::Private | As::RayPayload | As::IncomingRayPayload => false,
664                    // workgroup memory is exclusively accessed by the group
665                    // task payload memory is very similar to workgroup memory
666                    As::WorkGroup | As::TaskPayload => true,
667                    // uniform data
668                    As::Uniform | As::Immediate => true,
669                    // storage data is only uniform when read-only
670                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
671                    As::Handle => false,
672                };
673                Uniformity {
674                    non_uniform_result: if uniform { None } else { Some(handle) },
675                    requirements: UniformityRequirements::empty(),
676                }
677            }
678            E::LocalVariable(_) => Uniformity {
679                non_uniform_result: Some(handle),
680                requirements: UniformityRequirements::empty(),
681            },
682            E::Load { pointer } => {
683                let non_uniform_result = self.add_ref(pointer);
684                // Track which immediate slots this load touches.
685                if let Some(global) = self.expressions[pointer.index()].assignable_global {
686                    if resolve_context.global_vars[global].space == crate::AddressSpace::Immediate {
687                        self.immediate_slots_used |= ImmediateSlots::for_pointer(
688                            pointer,
689                            global,
690                            expression_arena,
691                            resolve_context.global_vars,
692                            resolve_context.types,
693                        );
694                    }
695                }
696                Uniformity {
697                    non_uniform_result,
698                    requirements: UniformityRequirements::empty(),
699                }
700            }
701            E::ImageSample {
702                image,
703                sampler,
704                gather: _,
705                coordinate,
706                array_index,
707                offset,
708                level,
709                depth_ref,
710                clamp_to_edge: _,
711            } => {
712                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
713                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
714
715                match (image_storage, sampler_storage) {
716                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
717                        self.sampling_set.insert(SamplingKey { image, sampler });
718                    }
719                    _ => {
720                        self.sampling.insert(Sampling {
721                            image: image_storage,
722                            sampler: sampler_storage,
723                        });
724                    }
725                }
726
727                // "nur" == "Non-Uniform Result"
728                let array_nur = array_index.and_then(|h| self.add_ref(h));
729                let level_nur = match level {
730                    Sl::Auto | Sl::Zero => None,
731                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
732                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
733                };
734                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
735                let offset_nur = offset.and_then(|h| self.add_ref(h));
736                Uniformity {
737                    non_uniform_result: self
738                        .add_ref(image)
739                        .or(self.add_ref(sampler))
740                        .or(self.add_ref(coordinate))
741                        .or(array_nur)
742                        .or(level_nur)
743                        .or(dref_nur)
744                        .or(offset_nur),
745                    requirements: if level.implicit_derivatives() {
746                        UniformityRequirements::IMPLICIT_LEVEL
747                    } else {
748                        UniformityRequirements::empty()
749                    },
750                }
751            }
752            E::ImageLoad {
753                image,
754                coordinate,
755                array_index,
756                sample,
757                level,
758            } => {
759                let array_nur = array_index.and_then(|h| self.add_ref(h));
760                let sample_nur = sample.and_then(|h| self.add_ref(h));
761                let level_nur = level.and_then(|h| self.add_ref(h));
762                Uniformity {
763                    non_uniform_result: self
764                        .add_ref(image)
765                        .or(self.add_ref(coordinate))
766                        .or(array_nur)
767                        .or(sample_nur)
768                        .or(level_nur),
769                    requirements: UniformityRequirements::empty(),
770                }
771            }
772            E::ImageQuery { image, query } => {
773                let query_nur = match query {
774                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
775                    _ => None,
776                };
777                Uniformity {
778                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
779                    requirements: UniformityRequirements::empty(),
780                }
781            }
782            E::Unary { expr, .. } => Uniformity {
783                non_uniform_result: self.add_ref(expr),
784                requirements: UniformityRequirements::empty(),
785            },
786            E::Binary { left, right, .. } => Uniformity {
787                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
788                requirements: UniformityRequirements::empty(),
789            },
790            E::Select {
791                condition,
792                accept,
793                reject,
794            } => Uniformity {
795                non_uniform_result: self
796                    .add_ref(condition)
797                    .or(self.add_ref(accept))
798                    .or(self.add_ref(reject)),
799                requirements: UniformityRequirements::empty(),
800            },
801            // explicit derivatives require uniform
802            E::Derivative { expr, .. } => Uniformity {
803                //Note: taking a derivative of a uniform doesn't make it non-uniform
804                non_uniform_result: self.add_ref(expr),
805                requirements: UniformityRequirements::DERIVATIVE,
806            },
807            E::Relational { argument, .. } => Uniformity {
808                non_uniform_result: self.add_ref(argument),
809                requirements: UniformityRequirements::empty(),
810            },
811            E::Math {
812                fun: _,
813                arg,
814                arg1,
815                arg2,
816                arg3,
817            } => {
818                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
819                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
820                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
821                Uniformity {
822                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
823                    requirements: UniformityRequirements::empty(),
824                }
825            }
826            E::As { expr, .. } => Uniformity {
827                non_uniform_result: self.add_ref(expr),
828                requirements: UniformityRequirements::empty(),
829            },
830            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
831            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
832                non_uniform_result: Some(handle),
833                requirements: UniformityRequirements::empty(),
834            },
835            E::WorkGroupUniformLoadResult { .. } => Uniformity {
836                // The result of WorkGroupUniformLoad is always uniform by definition
837                non_uniform_result: None,
838                // The call is what cares about uniformity, not the expression
839                // This expression is never emitted, so this requirement should never be used anyway?
840                requirements: UniformityRequirements::empty(),
841            },
842            E::ArrayLength(expr) => Uniformity {
843                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
844                requirements: UniformityRequirements::empty(),
845            },
846            E::RayQueryGetIntersection {
847                query,
848                committed: _,
849            } => Uniformity {
850                non_uniform_result: self.add_ref(query),
851                requirements: UniformityRequirements::empty(),
852            },
853            E::SubgroupBallotResult => Uniformity {
854                non_uniform_result: Some(handle),
855                requirements: UniformityRequirements::empty(),
856            },
857            E::SubgroupOperationResult { .. } => Uniformity {
858                non_uniform_result: Some(handle),
859                requirements: UniformityRequirements::empty(),
860            },
861            E::RayQueryVertexPositions {
862                query,
863                committed: _,
864            } => Uniformity {
865                non_uniform_result: self.add_ref(query),
866                requirements: UniformityRequirements::empty(),
867            },
868            E::CooperativeLoad { ref data, .. } => Uniformity {
869                non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)),
870                requirements: UniformityRequirements::COOP_OPS,
871            },
872            E::CooperativeMultiplyAdd { a, b, c } => Uniformity {
873                non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))),
874                requirements: UniformityRequirements::COOP_OPS,
875            },
876        };
877
878        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
879        self.expressions[handle.index()] = ExpressionInfo {
880            uniformity,
881            ref_count: 0,
882            assignable_global,
883            ty,
884        };
885        Ok(())
886    }
887
    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
    /// Returns the uniformity characteristics at the *function* level, i.e.
    /// whether or not the function requires to be called in uniform control flow,
    /// and whether the produced result is not disrupting the control flow.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`.
    ///
    /// Statements are visited in order. After each one, `disruptor` is merged
    /// with that statement's `exit_disruptor()`, so later statements in the same
    /// block are analyzed against the updated value. Nested blocks (`If`,
    /// `Switch`, `Loop`, `Block`) are handled by recursion, and the per-statement
    /// results are folded together with `|`.
    ///
    /// Besides computing uniformity, this walk calls `add_ref`/`add_ref_impl` on
    /// every expression handle a statement uses.
    /// NOTE(review): those helpers are defined outside this block; judging by the
    /// `ref_count` assertions in the unit test they maintain reference counts and
    /// global-usage flags — confirm against their definitions.
    ///
    /// # Parameters
    ///
    /// - `statements`: the block to analyze.
    /// - `other_functions`: analysis results indexed by function handle index;
    ///   used to look up callees of `Call` statements.
    /// - `disruptor`: the cause of non-uniform control flow on entry, if any.
    /// - `expression_arena`: the function's expressions, used for error spans
    ///   and passed on to `process_call`.
    /// - `diagnostic_filter_arena`: searched to find the severity applied to
    ///   uniformity violations.
    ///
    /// # Errors
    ///
    /// Reports a `NonUniformControlFlow` diagnostic if any of the expressions
    /// emitted in the block require uniformity, but the current flow is
    /// non-uniform. The check only happens when
    /// `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set, and whether the
    /// diagnostic becomes an `Err` is decided by `report_diag` for the severity
    /// found in the diagnostic filters. Errors from nested blocks and from
    /// `process_call` are propagated unchanged.
    // The `.or(self.add_ref(..))` chains below are deliberately eager: each
    // `add_ref` call must run even when an earlier operand already produced a
    // non-uniform result, which is why the lazy `or_else` form suggested by
    // clippy is not used.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        // Accumulates the result over all statements of this block.
        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Union of the requirements of every expression in the range.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        // Requirements were stored per expression by `process_expression`.
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            // A requirement is only violated when something has
                            // already made the control flow non-uniform.
                            if let Some(cause) = disruptor {
                                // The same `DerivativeUniformity` rule is looked up
                                // regardless of which requirement bits are set in `req`.
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // TODO: Yes, this isn't contextualized with source, because
                                    // the user is supposed to render what would normally be an
                                    // error here. Once we actually support warning-level
                                    // diagnostic items, then we won't need this non-compliant hack:
                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill is only recorded as an exit flag when it happens in
                // already non-uniform control flow.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Barriers add a requirement to the result; no check against
                // `disruptor` is performed here.
                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is still referenced, but its uniformity is ignored
                    // while the check below stays disabled.
                    let _condition_nur = self.add_ref(pointer);

                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.

                    /*
                    if self
                        .flags
                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                    {
                        let condition_nur = self.add_ref(pointer);
                        let this_disruptor =
                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                        if let Some(cause) = this_disruptor {
                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
                                .with_span_static(*span, "WorkGroupUniformLoad"));
                        }
                    } */
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // A plain nested block inherits the current disruptor unchanged.
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // Both branches are non-uniform if the parent flow already was,
                    // or if the condition itself is non-uniform.
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    // Disruptor that applies on entry to any case reached directly
                    // from the selector.
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case hands its own exit disruption on to the
                        // next case; otherwise the next case starts again from the
                        // selector-level disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block runs after the body, so it also sees any
                    // disruption caused by the body's exits.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The `break_if` condition is referenced, but its uniformity
                    // result is discarded.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // The returned value determines the function result's uniformity;
                // a return in non-uniform flow is recorded as an exit flag.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    // The callee must already have an entry in `other_functions`.
                    let info = &other_functions[function.index()];
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // The pointer is marked as both read and written.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    // Only an exchange with a comparison value has an extra operand.
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    // Reference whichever operands the specific ray query function has.
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    // Only the modes that carry an expression handle contribute
                    // an extra reference.
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            let _ = self.add_ref(index);
                        }
                        crate::GatherMode::QuadSwap(_) => {}
                    }
                    FunctionUniformity::new()
                }
                // Unlike the statements above, a cooperative store keeps the
                // non-uniformity of its operands in the result and adds the
                // `COOP_OPS` requirement.
                S::CooperativeStore { target, ref data } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: self
                            .add_ref(target)
                            .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE))
                            .or(self.add_ref(data.stride)),
                        requirements: UniformityRequirements::COOP_OPS,
                    },
                    exit: ExitFlags::empty(),
                },
                S::RayPipelineFunction(ref fun) => {
                    match *fun {
                        crate::RayPipelineFunction::TraceRay {
                            acceleration_structure,
                            descriptor,
                            payload,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                            let _ = self.add_ref(payload);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Any exit disruption caused by this statement applies to every
            // statement that follows it in the block.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1236}
1237
1238impl ModuleInfo {
1239    /// Populates `self.const_expression_types`
1240    pub(super) fn process_const_expression(
1241        &mut self,
1242        handle: Handle<crate::Expression>,
1243        resolve_context: &ResolveContext,
1244        gctx: crate::proc::GlobalCtx,
1245    ) -> Result<(), super::ConstExpressionError> {
1246        self.const_expression_types[handle.index()] =
1247            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1248        Ok(())
1249    }
1250
1251    /// Builds the `FunctionInfo` based on the function, and validates the
1252    /// uniform control flow if required by the expressions of this function.
1253    pub(super) fn process_function(
1254        &self,
1255        fun: &crate::Function,
1256        module: &crate::Module,
1257        flags: ValidationFlags,
1258        capabilities: super::Capabilities,
1259    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1260        let mut info = FunctionInfo {
1261            flags,
1262            available_stages: ShaderStages::all(),
1263            uniformity: Uniformity::new(),
1264            may_kill: false,
1265            sampling_set: crate::FastHashSet::default(),
1266            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1267            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1268            sampling: crate::FastHashSet::default(),
1269            dual_source_blending: false,
1270            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1271            immediate_slots_used: ImmediateSlots::default(),
1272        };
1273        let resolve_context =
1274            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1275
1276        for (handle, _) in fun.expressions.iter() {
1277            if let Err(source) = info.process_expression(
1278                handle,
1279                &fun.expressions,
1280                &self.functions,
1281                &resolve_context,
1282                capabilities,
1283            ) {
1284                return Err(FunctionError::Expression { handle, source }
1285                    .with_span_handle(handle, &fun.expressions));
1286            }
1287        }
1288
1289        for (_, expr) in fun.local_variables.iter() {
1290            if let Some(init) = expr.init {
1291                let _ = info.add_ref(init);
1292            }
1293        }
1294
1295        let uniformity = info.process_block(
1296            &fun.body,
1297            &self.functions,
1298            None,
1299            &fun.expressions,
1300            &module.diagnostic_filters,
1301        )?;
1302        info.uniformity = uniformity.result;
1303        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1304
1305        // If there are any globals referenced directly by a named expression,
1306        // ensure they are marked as used even if they are not referenced
1307        // anywhere else. An important case where this matters is phony
1308        // assignments used to include a global in the shader's resource
1309        // interface. https://www.w3.org/TR/WGSL/#phony-assignment-section
1310        for &handle in fun.named_expressions.keys() {
1311            if let Some(global) = info[handle].assignable_global {
1312                if info.global_uses[global.index()].is_empty() {
1313                    info.global_uses[global.index()] = GlobalUse::QUERY;
1314                }
1315            }
1316        }
1317
1318        Ok(info)
1319    }
1320
1321    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1322        &self.entry_points[index]
1323    }
1324}
1325
1326#[test]
1327fn uniform_control_flow() {
1328    use crate::{Expression as E, Statement as S};
1329
1330    let mut type_arena = crate::UniqueArena::new();
1331    let ty = type_arena.insert(
1332        crate::Type {
1333            name: None,
1334            inner: crate::TypeInner::Vector {
1335                size: crate::VectorSize::Bi,
1336                scalar: crate::Scalar::F32,
1337            },
1338        },
1339        Default::default(),
1340    );
1341    let mut global_var_arena = Arena::new();
1342    let non_uniform_global = global_var_arena.append(
1343        crate::GlobalVariable {
1344            name: None,
1345            init: None,
1346            ty,
1347            space: crate::AddressSpace::Handle,
1348            binding: None,
1349            memory_decorations: crate::MemoryDecorations::empty(),
1350        },
1351        Default::default(),
1352    );
1353    let uniform_global = global_var_arena.append(
1354        crate::GlobalVariable {
1355            name: None,
1356            init: None,
1357            ty,
1358            binding: None,
1359            space: crate::AddressSpace::Uniform,
1360            memory_decorations: crate::MemoryDecorations::empty(),
1361        },
1362        Default::default(),
1363    );
1364
1365    let mut expressions = Arena::new();
1366    // checks the uniform control flow
1367    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1368    // checks the non-uniform control flow
1369    let derivative_expr = expressions.append(
1370        E::Derivative {
1371            axis: crate::DerivativeAxis::X,
1372            ctrl: crate::DerivativeControl::None,
1373            expr: constant_expr,
1374        },
1375        Default::default(),
1376    );
1377    let emit_range_constant_derivative = expressions.range_from(0);
1378    let non_uniform_global_expr =
1379        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1380    let uniform_global_expr =
1381        expressions.append(E::GlobalVariable(uniform_global), Default::default());
1382    let emit_range_globals = expressions.range_from(2);
1383
1384    // checks the QUERY flag
1385    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1386    // checks the transitive WRITE flag
1387    let access_expr = expressions.append(
1388        E::AccessIndex {
1389            base: non_uniform_global_expr,
1390            index: 1,
1391        },
1392        Default::default(),
1393    );
1394    let emit_range_query_access_globals = expressions.range_from(2);
1395
1396    let mut info = FunctionInfo {
1397        flags: ValidationFlags::all(),
1398        available_stages: ShaderStages::all(),
1399        uniformity: Uniformity::new(),
1400        may_kill: false,
1401        sampling_set: crate::FastHashSet::default(),
1402        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1403        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1404        sampling: crate::FastHashSet::default(),
1405        dual_source_blending: false,
1406        diagnostic_filter_leaf: None,
1407        immediate_slots_used: ImmediateSlots::default(),
1408    };
1409    let resolve_context = ResolveContext {
1410        constants: &Arena::new(),
1411        overrides: &Arena::new(),
1412        types: &type_arena,
1413        special_types: &crate::SpecialTypes::default(),
1414        global_vars: &global_var_arena,
1415        local_vars: &Arena::new(),
1416        functions: &Arena::new(),
1417        arguments: &[],
1418    };
1419    for (handle, _) in expressions.iter() {
1420        info.process_expression(
1421            handle,
1422            &expressions,
1423            &[],
1424            &resolve_context,
1425            super::Capabilities::empty(),
1426        )
1427        .unwrap();
1428    }
1429    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1430    assert_eq!(info[uniform_global_expr].ref_count, 1);
1431    assert_eq!(info[query_expr].ref_count, 0);
1432    assert_eq!(info[access_expr].ref_count, 0);
1433    assert_eq!(info[non_uniform_global], GlobalUse::empty());
1434    assert_eq!(info[uniform_global], GlobalUse::QUERY);
1435
1436    let stmt_emit1 = S::Emit(emit_range_globals.clone());
1437    let stmt_if_uniform = S::If {
1438        condition: uniform_global_expr,
1439        accept: crate::Block::new(),
1440        reject: vec![
1441            S::Emit(emit_range_constant_derivative.clone()),
1442            S::Store {
1443                pointer: constant_expr,
1444                value: derivative_expr,
1445            },
1446        ]
1447        .into(),
1448    };
1449    assert_eq!(
1450        info.process_block(
1451            &vec![stmt_emit1, stmt_if_uniform].into(),
1452            &[],
1453            None,
1454            &expressions,
1455            &Arena::new(),
1456        ),
1457        Ok(FunctionUniformity {
1458            result: Uniformity {
1459                non_uniform_result: None,
1460                requirements: UniformityRequirements::DERIVATIVE,
1461            },
1462            exit: ExitFlags::empty(),
1463        }),
1464    );
1465    assert_eq!(info[constant_expr].ref_count, 2);
1466    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1467
1468    let stmt_emit2 = S::Emit(emit_range_globals.clone());
1469    let stmt_if_non_uniform = S::If {
1470        condition: non_uniform_global_expr,
1471        accept: vec![
1472            S::Emit(emit_range_constant_derivative),
1473            S::Store {
1474                pointer: constant_expr,
1475                value: derivative_expr,
1476            },
1477        ]
1478        .into(),
1479        reject: crate::Block::new(),
1480    };
1481    {
1482        let block_info = info.process_block(
1483            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1484            &[],
1485            None,
1486            &expressions,
1487            &Arena::new(),
1488        );
1489        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1490            assert_eq!(info[derivative_expr].ref_count, 2);
1491        } else {
1492            assert_eq!(
1493                block_info,
1494                Err(FunctionError::NonUniformControlFlow(
1495                    UniformityRequirements::DERIVATIVE,
1496                    derivative_expr,
1497                    UniformityDisruptor::Expression(non_uniform_global_expr)
1498                )
1499                .with_span()),
1500            );
1501            assert_eq!(info[derivative_expr].ref_count, 1);
1502
1503            // Test that the same thing passes when we disable the `derivative_uniformity`
1504            let mut diagnostic_filters = Arena::new();
1505            let diagnostic_filter_leaf = diagnostic_filters.append(
1506                DiagnosticFilterNode {
1507                    inner: crate::diagnostic_filter::DiagnosticFilter {
1508                        new_severity: crate::diagnostic_filter::Severity::Off,
1509                        triggering_rule:
1510                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1511                                StandardFilterableTriggeringRule::DerivativeUniformity,
1512                            ),
1513                    },
1514                    parent: None,
1515                },
1516                crate::Span::default(),
1517            );
1518            let mut info = FunctionInfo {
1519                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1520                ..info.clone()
1521            };
1522
1523            let block_info = info.process_block(
1524                &vec![stmt_emit2, stmt_if_non_uniform].into(),
1525                &[],
1526                None,
1527                &expressions,
1528                &diagnostic_filters,
1529            );
1530            assert_eq!(
1531                block_info,
1532                Ok(FunctionUniformity {
1533                    result: Uniformity {
1534                        non_uniform_result: None,
1535                        requirements: UniformityRequirements::DERIVATIVE,
1536                    },
1537                    exit: ExitFlags::empty()
1538                }),
1539            );
1540            assert_eq!(info[derivative_expr].ref_count, 2);
1541        }
1542    }
1543    assert_eq!(info[non_uniform_global], GlobalUse::READ);
1544
1545    let stmt_emit3 = S::Emit(emit_range_globals);
1546    let stmt_return_non_uniform = S::Return {
1547        value: Some(non_uniform_global_expr),
1548    };
1549    assert_eq!(
1550        info.process_block(
1551            &vec![stmt_emit3, stmt_return_non_uniform].into(),
1552            &[],
1553            Some(UniformityDisruptor::Return),
1554            &expressions,
1555            &Arena::new(),
1556        ),
1557        Ok(FunctionUniformity {
1558            result: Uniformity {
1559                non_uniform_result: Some(non_uniform_global_expr),
1560                requirements: UniformityRequirements::empty(),
1561            },
1562            exit: ExitFlags::MAY_RETURN,
1563        }),
1564    );
1565    assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1566
1567    // Check that uniformity requirements reach through a pointer
1568    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1569    let stmt_assign = S::Store {
1570        pointer: access_expr,
1571        value: query_expr,
1572    };
1573    let stmt_return_pointer = S::Return {
1574        value: Some(access_expr),
1575    };
1576    let stmt_kill = S::Kill;
1577    assert_eq!(
1578        info.process_block(
1579            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1580            &[],
1581            Some(UniformityDisruptor::Discard),
1582            &expressions,
1583            &Arena::new(),
1584        ),
1585        Ok(FunctionUniformity {
1586            result: Uniformity {
1587                non_uniform_result: Some(non_uniform_global_expr),
1588                requirements: UniformityRequirements::empty(),
1589            },
1590            exit: ExitFlags::all(),
1591        }),
1592    );
1593    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1594}