naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15    arena::{Arena, Handle},
16    proc::{ResolveContext, TypeResolution},
17};
18
/// Where an expression's non-uniformity comes from, if anywhere.
///
/// `Some(handle)` names the expression from which the non-uniformity
/// originates; `None` means the value is uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// `UniformityRequirements` are defined with no bits set, so recording them
/// has no effect.
// NOTE(review): the name suggests this relaxes checks that only concern
// fragment-stage operations — confirm against the code that consumes
// `UniformityRequirements`.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Kinds of expressions that require uniform control flow.
    ///
    /// Two of the flags collapse to an empty bit pattern while
    /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Requirement kind for work-group barriers.
        const WORK_GROUP_BARRIER = 0x1;
        /// Requirement kind for derivatives.
        ///
        /// Has no bits set while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Requirement kind for image sampling whose level relies on implicit
        /// derivatives.
        ///
        /// Has no bits set while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniform control flow characteristics.
///
/// Pairs the source of a value's non-uniformity (if any) with the set of
/// reasons the value requires uniform control flow.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with non-uniform result, or `None` if the result
    /// is uniform.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such a value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operations that depend on non-uniform results also produce
    /// non-uniform results.
    pub non_uniform_result: NonUniformResult,
    /// If this expression requires uniform control flow, store the reason here.
    ///
    /// Empty when there is no such requirement.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58    const fn new() -> Self {
59        Uniformity {
60            non_uniform_result: None,
61            requirements: UniformityRequirements::empty(),
62        }
63    }
64}
65
bitflags::bitflags! {
    /// Ways in which control flow may exit: by returning from the function,
    /// or by being killed.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Control flow may return from the function, which makes all the
        /// subsequent statements within the current function (only!)
        /// to be executed in a non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// Control flow may be killed. Anything after [`Statement::Kill`] is
        /// considered inside non-uniform context.
        ///
        /// [`Statement::Kill`]: crate::Statement::Kill
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity characteristics of a function.
///
/// Two values can be merged with `|`, which combines both fields.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// The non-uniform result (if any) and the uniformity requirements.
    result: Uniformity,
    /// How control flow may exit (return and/or kill).
    exit: ExitFlags,
}
87
/// Mesh shader related characteristics of a function.
///
/// Both fields are `None` in the value produced by the derived `Default`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct FunctionMeshShaderInfo {
    /// The type of value this function passes to [`SetVertex`], and the
    /// expression that first established it.
    ///
    /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex
    pub vertex_type: Option<(Handle<crate::Type>, Handle<crate::Expression>)>,

    /// The type of value this function passes to [`SetPrimitive`], and the
    /// expression that first established it.
    ///
    /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive
    pub primitive_type: Option<(Handle<crate::Type>, Handle<crate::Expression>)>,
}
106
107impl ops::BitOr for FunctionUniformity {
108    type Output = Self;
109    fn bitor(self, other: Self) -> Self {
110        FunctionUniformity {
111            result: Uniformity {
112                non_uniform_result: self
113                    .result
114                    .non_uniform_result
115                    .or(other.result.non_uniform_result),
116                requirements: self.result.requirements | other.result.requirements,
117            },
118            exit: self.exit | other.exit,
119        }
120    }
121}
122
123impl FunctionUniformity {
124    const fn new() -> Self {
125        FunctionUniformity {
126            result: Uniformity::new(),
127            exit: ExitFlags::empty(),
128        }
129    }
130
131    /// Returns a disruptor based on the stored exit flags, if any.
132    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
133        if self.exit.contains(ExitFlags::MAY_RETURN) {
134            Some(UniformityDisruptor::Return)
135        } else if self.exit.contains(ExitFlags::MAY_KILL) {
136            Some(UniformityDisruptor::Discard)
137        } else {
138            None
139        }
140    }
141}
142
bitflags::bitflags! {
    /// Indicates how a global variable is used.
    ///
    /// Several kinds of use may be set at once; uses are accumulated by
    /// OR-ing flags into a per-global entry.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Data will be read from the variable.
        const READ = 0x1;
        /// Data will be written to the variable.
        const WRITE = 0x2;
        /// The information about the data is queried.
        const QUERY = 0x4;
        /// Atomic operations will be performed on the variable.
        const ATOMIC = 0x8;
    }
}
159
/// A (texture, sampler) pair of global variables that may be used together
/// in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The global variable holding the image being sampled.
    pub image: Handle<crate::GlobalVariable>,
    /// The global variable holding the sampler used with `image`.
    pub sampler: Handle<crate::GlobalVariable>,
}
167
/// Information about an expression in a function body.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Whether this expression is uniform, and why.
    ///
    /// If this expression's value is not uniform, `non_uniform_result` is the
    /// handle of the expression from which this one's non-uniformity
    /// originates. Otherwise, it is `None`.
    pub uniformity: Uniformity,

    /// The number of direct references to this expression in statements and
    /// other expressions.
    ///
    /// This is a _local_ reference count only, it may be non-zero for
    /// expressions that are ultimately unused.
    pub ref_count: usize,

    /// The global variable into which this expression produces a pointer.
    ///
    /// This is `None` unless this expression is either a
    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
    /// ultimately refers to some part of a global.
    ///
    /// [`Load`] expressions applied to pointer-typed arguments could
    /// refer to globals, but we leave this as `None` for them.
    ///
    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    /// [`Load`]: crate::Expression::Load
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The type of this expression.
    pub ty: TypeResolution,
}
205
206impl ExpressionInfo {
207    const fn new() -> Self {
208        ExpressionInfo {
209            uniformity: Uniformity::new(),
210            ref_count: 0,
211            assignable_global: None,
212            // this doesn't matter at this point, will be overwritten
213            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
214                kind: crate::ScalarKind::Bool,
215                width: 0,
216            })),
217        }
218    }
219}
220
/// A participant in a sampling operation: either a known global variable, or
/// one of the enclosing function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The participant is this global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The participant is the function argument with this index.
    Argument(u32),
}
228
229impl GlobalOrArgument {
230    fn from_expression(
231        expression_arena: &Arena<crate::Expression>,
232        expression: Handle<crate::Expression>,
233    ) -> Result<GlobalOrArgument, ExpressionError> {
234        Ok(match expression_arena[expression] {
235            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
236            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
237            crate::Expression::Access { base, .. }
238            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
239                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
240                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
241            },
242            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
243        })
244    }
245}
246
/// A (texture, sampler) pair used together in a sampling operation, where
/// each participant is either a global variable or a function argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Where the sampled image comes from.
    image: GlobalOrArgument,
    /// Where the sampler comes from.
    sampler: GlobalOrArgument,
}
254
/// Information gathered about a single function and, where noted on the
/// individual fields, its callees.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    // NOTE(review): marked `dead_code`; the reason is not visible in this
    // part of the file — confirm whether the field is still needed.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    ///
    /// When this function is called, the caller treats the call as something
    /// that may kill control flow.
    pub may_kill: bool,

    /// All pairs of (texture, sampler) globals that may be used together in
    /// sampling operations by this function and its callees. This includes
    /// pairings that arise when this function passes textures and samplers as
    /// arguments to its callees.
    ///
    /// This table does not include uses of textures and samplers passed as
    /// arguments to this function itself, since we do not know which globals
    /// those will be. However, this table *is* exhaustive when computed for an
    /// entry point function: entry points never receive textures or samplers as
    /// arguments, so all an entry point's sampling can be reported in terms of
    /// globals.
    ///
    /// The GLSL back end uses this table to construct reflection info that
    /// clients need to construct texture-combined sampler values.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function and its callees use this module's globals.
    ///
    /// This is indexed by `Handle<GlobalVariable>` indices. However,
    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
    /// so you can simply index this struct with a global handle to retrieve
    /// its usage information.
    global_uses: Box<[GlobalUse]>,

    /// Information about each expression in this function's body.
    ///
    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
    /// index this struct with an expression handle to retrieve its
    /// `ExpressionInfo`.
    expressions: Box<[ExpressionInfo]>,

    /// All (texture, sampler) pairs that may be used together in sampling
    /// operations by this function and its callees, whether they are accessed
    /// as globals or passed as arguments.
    ///
    /// Participants are represented by [`GlobalVariable`] handles whenever
    /// possible, and otherwise by indices of this function's arguments.
    ///
    /// When analyzing a function call, we combine this data about the callee
    /// with the actual arguments being passed to produce the caller's own
    /// `sampling_set` and `sampling` tables.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    sampling: crate::FastHashSet<Sampling>,

    /// Indicates that the function is using dual source blending.
    pub dual_source_blending: bool,

    /// The leaf of all module-wide diagnostic filter rules tree parsed from directives in this
    /// module.
    ///
    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
    /// validation.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,

    /// Mesh shader info for this function and its callees.
    pub mesh_shader_info: FunctionMeshShaderInfo,
}
328
329impl FunctionInfo {
330    pub const fn global_variable_count(&self) -> usize {
331        self.global_uses.len()
332    }
333    pub const fn expression_count(&self) -> usize {
334        self.expressions.len()
335    }
336    pub fn dominates_global_use(&self, other: &Self) -> bool {
337        for (self_global_uses, other_global_uses) in
338            self.global_uses.iter().zip(other.global_uses.iter())
339        {
340            if !self_global_uses.contains(*other_global_uses) {
341                return false;
342            }
343        }
344        true
345    }
346}
347
348impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
349    type Output = GlobalUse;
350    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
351        &self.global_uses[handle.index()]
352    }
353}
354
355impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
356    type Output = ExpressionInfo;
357    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
358        &self.expressions[handle.index()]
359    }
360}
361
/// Disruptor of the uniform control flow.
///
/// Each variant names one reason why control flow can no longer be
/// considered uniform; the `Display` text comes from the `#[error]` attributes.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on this expression, whose result is non-uniform.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// The function may already have returned earlier in its control flow.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// The invocation may already have been discarded earlier.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
373
374impl FunctionInfo {
375    /// Record a use of `expr` of the sort given by `global_use`.
376    ///
377    /// Bump `expr`'s reference count, and return its uniformity.
378    ///
379    /// If `expr` is a pointer to a global variable, or some part of
380    /// a global variable, add `global_use` to that global's set of
381    /// uses.
382    #[must_use]
383    fn add_ref_impl(
384        &mut self,
385        expr: Handle<crate::Expression>,
386        global_use: GlobalUse,
387    ) -> NonUniformResult {
388        let info = &mut self.expressions[expr.index()];
389        info.ref_count += 1;
390        // Record usage if this expression may access a global
391        if let Some(global) = info.assignable_global {
392            self.global_uses[global.index()] |= global_use;
393        }
394        info.uniformity.non_uniform_result
395    }
396
397    /// Note an entry point's use of `global` not recorded by [`ModuleInfo::process_function`].
398    ///
399    /// Most global variable usage should be recorded via [`add_ref_impl`] in the process
400    /// of expression behavior analysis by [`ModuleInfo::process_function`]. But that code
401    /// has no access to entrypoint-specific information, so interface analysis uses this
402    /// function to record global uses there (like task shader payloads).
403    ///
404    /// [`add_ref_impl`]: Self::add_ref_impl
405    pub(super) fn insert_global_use(
406        &mut self,
407        global_use: GlobalUse,
408        global: Handle<crate::GlobalVariable>,
409    ) {
410        self.global_uses[global.index()] |= global_use;
411    }
412
    /// Record a use of `expr` for its value.
    ///
    /// This is used for almost all expression references. Anything
    /// that writes to the value `expr` points to, or otherwise wants
    /// to contribute flags other than `GlobalUse::READ`, should use
    /// `add_ref_impl` directly.
    ///
    /// Returns whatever `add_ref_impl` returns for `expr`: its non-uniform
    /// result, if any.
    #[must_use]
    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
        self.add_ref_impl(expr, GlobalUse::READ)
    }
423
424    /// Record a use of `expr`, and indicate which global variable it
425    /// refers to, if any.
426    ///
427    /// Bump `expr`'s reference count, and return its uniformity.
428    ///
429    /// If `expr` is a pointer to a global variable, or some part
430    /// thereof, store that global in `*assignable_global`. Leave the
431    /// global's uses unchanged.
432    ///
433    /// This is used to determine the [`assignable_global`] for
434    /// [`Access`] and [`AccessIndex`] expressions that ultimately
435    /// refer to a global variable. Those expressions don't contribute
436    /// any usage to the global themselves; that depends on how other
437    /// expressions use them.
438    ///
439    /// [`assignable_global`]: ExpressionInfo::assignable_global
440    /// [`Access`]: crate::Expression::Access
441    /// [`AccessIndex`]: crate::Expression::AccessIndex
442    #[must_use]
443    fn add_assignable_ref(
444        &mut self,
445        expr: Handle<crate::Expression>,
446        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
447    ) -> NonUniformResult {
448        let info = &mut self.expressions[expr.index()];
449        info.ref_count += 1;
450        // propagate the assignable global up the chain, till it either hits
451        // a value-type expression, or the assignment statement.
452        if let Some(global) = info.assignable_global {
453            if let Some(_old) = assignable_global.replace(global) {
454                unreachable!()
455            }
456        }
457        info.uniformity.non_uniform_result
458    }
459
460    /// Inherit information from a called function.
461    fn process_call(
462        &mut self,
463        callee: &Self,
464        arguments: &[Handle<crate::Expression>],
465        expression_arena: &Arena<crate::Expression>,
466    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
467        self.sampling_set
468            .extend(callee.sampling_set.iter().cloned());
469        for sampling in callee.sampling.iter() {
470            // If the callee was passed the texture or sampler as an argument,
471            // we may now be able to determine which globals those referred to.
472            let image_storage = match sampling.image {
473                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
474                GlobalOrArgument::Argument(i) => {
475                    let Some(handle) = arguments.get(i as usize).cloned() else {
476                        // Argument count mismatch, will be reported later by validate_call
477                        break;
478                    };
479                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
480                        |source| {
481                            FunctionError::Expression { handle, source }
482                                .with_span_handle(handle, expression_arena)
483                        },
484                    )?
485                }
486            };
487
488            let sampler_storage = match sampling.sampler {
489                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
490                GlobalOrArgument::Argument(i) => {
491                    let Some(handle) = arguments.get(i as usize).cloned() else {
492                        // Argument count mismatch, will be reported later by validate_call
493                        break;
494                    };
495                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
496                        |source| {
497                            FunctionError::Expression { handle, source }
498                                .with_span_handle(handle, expression_arena)
499                        },
500                    )?
501                }
502            };
503
504            // If we've managed to pin both the image and sampler down to
505            // specific globals, record that in our `sampling_set`. Otherwise,
506            // record as much as we do know in our own `sampling` table, for our
507            // callers to sort out.
508            match (image_storage, sampler_storage) {
509                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
510                    self.sampling_set.insert(SamplingKey { image, sampler });
511                }
512                (image, sampler) => {
513                    self.sampling.insert(Sampling { image, sampler });
514                }
515            }
516        }
517
518        // Inherit global use from our callees.
519        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
520            *mine |= *other;
521        }
522
523        // Inherit mesh output types from our callees.
524        self.try_update_mesh_info(&callee.mesh_shader_info)?;
525
526        Ok(FunctionUniformity {
527            result: callee.uniformity.clone(),
528            exit: if callee.may_kill {
529                ExitFlags::MAY_KILL
530            } else {
531                ExitFlags::empty()
532            },
533        })
534    }
535
536    /// Compute the [`ExpressionInfo`] for `handle`.
537    ///
538    /// Replace the dummy entry in [`self.expressions`] for `handle`
539    /// with a real `ExpressionInfo` value describing that expression.
540    ///
541    /// This function is called as part of a forward sweep through the
542    /// arena, so we can assume that all earlier expressions in the
543    /// arena already have valid info. Since expressions only depend
544    /// on earlier expressions, this includes all our subexpressions.
545    ///
546    /// Adjust the reference counts on all expressions we use.
547    ///
548    /// Also populate the [`sampling_set`], [`sampling`] and
549    /// [`global_uses`] fields of `self`.
550    ///
551    /// [`self.expressions`]: FunctionInfo::expressions
552    /// [`sampling_set`]: FunctionInfo::sampling_set
553    /// [`sampling`]: FunctionInfo::sampling
554    /// [`global_uses`]: FunctionInfo::global_uses
555    #[allow(clippy::or_fun_call)]
556    fn process_expression(
557        &mut self,
558        handle: Handle<crate::Expression>,
559        expression_arena: &Arena<crate::Expression>,
560        other_functions: &[FunctionInfo],
561        resolve_context: &ResolveContext,
562        capabilities: super::Capabilities,
563    ) -> Result<(), ExpressionError> {
564        use crate::{Expression as E, SampleLevel as Sl};
565
566        let expression = &expression_arena[handle];
567        let mut assignable_global = None;
568        let uniformity = match *expression {
569            E::Access { base, index } => {
570                let base_ty = self[base].ty.inner_with(resolve_context.types);
571
572                // build up the caps needed if this is indexed non-uniformly
573                let mut needed_caps = super::Capabilities::empty();
574                let is_binding_array = match *base_ty {
575                    crate::TypeInner::BindingArray {
576                        base: array_element_ty_handle,
577                        ..
578                    } => {
579                        // these are nasty aliases, but these idents are too long and break rustfmt
580                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
581                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
582                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
583                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
584
585                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
586                        let array_element_ty =
587                            &resolve_context.types[array_element_ty_handle].inner;
588
589                        needed_caps |= match *array_element_ty {
590                            // If we're an image, use the appropriate limit.
591                            crate::TypeInner::Image { class, .. } => match class {
592                                crate::ImageClass::Storage { .. } => sto,
593                                _ => st_sb,
594                            },
595                            crate::TypeInner::Sampler { .. } => sampler,
596                            // If we're anything but an image, assume we're a buffer and use the address space.
597                            _ => {
598                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
599                                    let global = &resolve_context.global_vars[global_handle];
600                                    match global.space {
601                                        crate::AddressSpace::Uniform => uni,
602                                        crate::AddressSpace::Storage { .. } => st_sb,
603                                        _ => unreachable!(),
604                                    }
605                                } else {
606                                    unreachable!()
607                                }
608                            }
609                        };
610
611                        true
612                    }
613                    _ => false,
614                };
615
616                if self[index].uniformity.non_uniform_result.is_some()
617                    && !capabilities.contains(needed_caps)
618                    && is_binding_array
619                {
620                    return Err(ExpressionError::MissingCapabilities(needed_caps));
621                }
622
623                Uniformity {
624                    non_uniform_result: self
625                        .add_assignable_ref(base, &mut assignable_global)
626                        .or(self.add_ref(index)),
627                    requirements: UniformityRequirements::empty(),
628                }
629            }
630            E::AccessIndex { base, .. } => Uniformity {
631                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
632                requirements: UniformityRequirements::empty(),
633            },
634            // always uniform
635            E::Splat { size: _, value } => Uniformity {
636                non_uniform_result: self.add_ref(value),
637                requirements: UniformityRequirements::empty(),
638            },
639            E::Swizzle { vector, .. } => Uniformity {
640                non_uniform_result: self.add_ref(vector),
641                requirements: UniformityRequirements::empty(),
642            },
643            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
644            E::Compose { ref components, .. } => {
645                let non_uniform_result = components
646                    .iter()
647                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
648                Uniformity {
649                    non_uniform_result,
650                    requirements: UniformityRequirements::empty(),
651                }
652            }
653            // depends on the builtin
654            E::FunctionArgument(index) => {
655                let arg = &resolve_context.arguments[index as usize];
656                let uniform = match arg.binding {
657                    Some(crate::Binding::BuiltIn(
658                        // per-work-group built-ins are uniform
659                        crate::BuiltIn::WorkGroupId
660                        | crate::BuiltIn::WorkGroupSize
661                        | crate::BuiltIn::NumWorkGroups,
662                    )) => true,
663                    _ => false,
664                };
665                Uniformity {
666                    non_uniform_result: if uniform { None } else { Some(handle) },
667                    requirements: UniformityRequirements::empty(),
668                }
669            }
670            // depends on the address space
671            E::GlobalVariable(gh) => {
672                use crate::AddressSpace as As;
673                assignable_global = Some(gh);
674                let var = &resolve_context.global_vars[gh];
675                let uniform = match var.space {
676                    // local data is non-uniform
677                    As::Function | As::Private => false,
678                    // workgroup memory is exclusively accessed by the group
679                    // task payload memory is very similar to workgroup memory
680                    As::WorkGroup | As::TaskPayload => true,
681                    // uniform data
682                    As::Uniform | As::PushConstant => true,
683                    // storage data is only uniform when read-only
684                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
685                    As::Handle => false,
686                };
687                Uniformity {
688                    non_uniform_result: if uniform { None } else { Some(handle) },
689                    requirements: UniformityRequirements::empty(),
690                }
691            }
692            E::LocalVariable(_) => Uniformity {
693                non_uniform_result: Some(handle),
694                requirements: UniformityRequirements::empty(),
695            },
696            E::Load { pointer } => Uniformity {
697                non_uniform_result: self.add_ref(pointer),
698                requirements: UniformityRequirements::empty(),
699            },
700            E::ImageSample {
701                image,
702                sampler,
703                gather: _,
704                coordinate,
705                array_index,
706                offset,
707                level,
708                depth_ref,
709                clamp_to_edge: _,
710            } => {
711                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
712                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
713
714                match (image_storage, sampler_storage) {
715                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
716                        self.sampling_set.insert(SamplingKey { image, sampler });
717                    }
718                    _ => {
719                        self.sampling.insert(Sampling {
720                            image: image_storage,
721                            sampler: sampler_storage,
722                        });
723                    }
724                }
725
726                // "nur" == "Non-Uniform Result"
727                let array_nur = array_index.and_then(|h| self.add_ref(h));
728                let level_nur = match level {
729                    Sl::Auto | Sl::Zero => None,
730                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
731                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
732                };
733                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
734                let offset_nur = offset.and_then(|h| self.add_ref(h));
735                Uniformity {
736                    non_uniform_result: self
737                        .add_ref(image)
738                        .or(self.add_ref(sampler))
739                        .or(self.add_ref(coordinate))
740                        .or(array_nur)
741                        .or(level_nur)
742                        .or(dref_nur)
743                        .or(offset_nur),
744                    requirements: if level.implicit_derivatives() {
745                        UniformityRequirements::IMPLICIT_LEVEL
746                    } else {
747                        UniformityRequirements::empty()
748                    },
749                }
750            }
751            E::ImageLoad {
752                image,
753                coordinate,
754                array_index,
755                sample,
756                level,
757            } => {
758                let array_nur = array_index.and_then(|h| self.add_ref(h));
759                let sample_nur = sample.and_then(|h| self.add_ref(h));
760                let level_nur = level.and_then(|h| self.add_ref(h));
761                Uniformity {
762                    non_uniform_result: self
763                        .add_ref(image)
764                        .or(self.add_ref(coordinate))
765                        .or(array_nur)
766                        .or(sample_nur)
767                        .or(level_nur),
768                    requirements: UniformityRequirements::empty(),
769                }
770            }
771            E::ImageQuery { image, query } => {
772                let query_nur = match query {
773                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
774                    _ => None,
775                };
776                Uniformity {
777                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
778                    requirements: UniformityRequirements::empty(),
779                }
780            }
781            E::Unary { expr, .. } => Uniformity {
782                non_uniform_result: self.add_ref(expr),
783                requirements: UniformityRequirements::empty(),
784            },
785            E::Binary { left, right, .. } => Uniformity {
786                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
787                requirements: UniformityRequirements::empty(),
788            },
789            E::Select {
790                condition,
791                accept,
792                reject,
793            } => Uniformity {
794                non_uniform_result: self
795                    .add_ref(condition)
796                    .or(self.add_ref(accept))
797                    .or(self.add_ref(reject)),
798                requirements: UniformityRequirements::empty(),
799            },
800            // explicit derivatives require uniform
801            E::Derivative { expr, .. } => Uniformity {
802                //Note: taking a derivative of a uniform doesn't make it non-uniform
803                non_uniform_result: self.add_ref(expr),
804                requirements: UniformityRequirements::DERIVATIVE,
805            },
806            E::Relational { argument, .. } => Uniformity {
807                non_uniform_result: self.add_ref(argument),
808                requirements: UniformityRequirements::empty(),
809            },
810            E::Math {
811                fun: _,
812                arg,
813                arg1,
814                arg2,
815                arg3,
816            } => {
817                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
818                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
819                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
820                Uniformity {
821                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
822                    requirements: UniformityRequirements::empty(),
823                }
824            }
825            E::As { expr, .. } => Uniformity {
826                non_uniform_result: self.add_ref(expr),
827                requirements: UniformityRequirements::empty(),
828            },
829            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
830            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
831                non_uniform_result: Some(handle),
832                requirements: UniformityRequirements::empty(),
833            },
834            E::WorkGroupUniformLoadResult { .. } => Uniformity {
835                // The result of WorkGroupUniformLoad is always uniform by definition
836                non_uniform_result: None,
837                // The call is what cares about uniformity, not the expression
838                // This expression is never emitted, so this requirement should never be used anyway?
839                requirements: UniformityRequirements::empty(),
840            },
841            E::ArrayLength(expr) => Uniformity {
842                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
843                requirements: UniformityRequirements::empty(),
844            },
845            E::RayQueryGetIntersection {
846                query,
847                committed: _,
848            } => Uniformity {
849                non_uniform_result: self.add_ref(query),
850                requirements: UniformityRequirements::empty(),
851            },
852            E::SubgroupBallotResult => Uniformity {
853                non_uniform_result: Some(handle),
854                requirements: UniformityRequirements::empty(),
855            },
856            E::SubgroupOperationResult { .. } => Uniformity {
857                non_uniform_result: Some(handle),
858                requirements: UniformityRequirements::empty(),
859            },
860            E::RayQueryVertexPositions {
861                query,
862                committed: _,
863            } => Uniformity {
864                non_uniform_result: self.add_ref(query),
865                requirements: UniformityRequirements::empty(),
866            },
867        };
868
869        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
870        self.expressions[handle.index()] = ExpressionInfo {
871            uniformity,
872            ref_count: 0,
873            assignable_global,
874            ty,
875        };
876        Ok(())
877    }
878
879    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
880    /// Returns the uniformity characteristics at the *function* level, i.e.
881    /// whether or not the function requires to be called in uniform control flow,
882    /// and whether the produced result is not disrupting the control flow.
883    ///
884    /// The parent control flow is uniform if `disruptor.is_none()`.
885    ///
886    /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
887    /// require uniformity, but the current flow is non-uniform.
888    #[allow(clippy::or_fun_call)]
889    fn process_block(
890        &mut self,
891        statements: &crate::Block,
892        other_functions: &[FunctionInfo],
893        mut disruptor: Option<UniformityDisruptor>,
894        expression_arena: &Arena<crate::Expression>,
895        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
896    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
897        use crate::Statement as S;
898
899        let mut combined_uniformity = FunctionUniformity::new();
900        for statement in statements {
901            let uniformity = match *statement {
902                S::Emit(ref range) => {
903                    let mut requirements = UniformityRequirements::empty();
904                    for expr in range.clone() {
905                        let req = self.expressions[expr.index()].uniformity.requirements;
906                        if self
907                            .flags
908                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
909                            && !req.is_empty()
910                        {
911                            if let Some(cause) = disruptor {
912                                let severity = DiagnosticFilterNode::search(
913                                    self.diagnostic_filter_leaf,
914                                    diagnostic_filter_arena,
915                                    StandardFilterableTriggeringRule::DerivativeUniformity,
916                                );
917                                severity.report_diag(
918                                    FunctionError::NonUniformControlFlow(req, expr, cause)
919                                        .with_span_handle(expr, expression_arena),
920                                    // TODO: Yes, this isn't contextualized with source, because
921                                    // the user is supposed to render what would normally be an
922                                    // error here. Once we actually support warning-level
923                                    // diagnostic items, then we won't need this non-compliant hack:
924                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
925                                    |e, level| log::log!(level, "{e}"),
926                                )?;
927                            }
928                        }
929                        requirements |= req;
930                    }
931                    FunctionUniformity {
932                        result: Uniformity {
933                            non_uniform_result: None,
934                            requirements,
935                        },
936                        exit: ExitFlags::empty(),
937                    }
938                }
939                S::Break | S::Continue => FunctionUniformity::new(),
940                S::Kill => FunctionUniformity {
941                    result: Uniformity::new(),
942                    exit: if disruptor.is_some() {
943                        ExitFlags::MAY_KILL
944                    } else {
945                        ExitFlags::empty()
946                    },
947                },
948                S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
949                    result: Uniformity {
950                        non_uniform_result: None,
951                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
952                    },
953                    exit: ExitFlags::empty(),
954                },
955                S::WorkGroupUniformLoad { pointer, .. } => {
956                    let _condition_nur = self.add_ref(pointer);
957
958                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
959                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
960                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
961                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.
962
963                    /*
964                    if self
965                        .flags
966                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
967                    {
968                        let condition_nur = self.add_ref(pointer);
969                        let this_disruptor =
970                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
971                        if let Some(cause) = this_disruptor {
972                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
973                                .with_span_static(*span, "WorkGroupUniformLoad"));
974                        }
975                    } */
976                    FunctionUniformity {
977                        result: Uniformity {
978                            non_uniform_result: None,
979                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
980                        },
981                        exit: ExitFlags::empty(),
982                    }
983                }
984                S::Block(ref b) => self.process_block(
985                    b,
986                    other_functions,
987                    disruptor,
988                    expression_arena,
989                    diagnostic_filter_arena,
990                )?,
991                S::If {
992                    condition,
993                    ref accept,
994                    ref reject,
995                } => {
996                    let condition_nur = self.add_ref(condition);
997                    let branch_disruptor =
998                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
999                    let accept_uniformity = self.process_block(
1000                        accept,
1001                        other_functions,
1002                        branch_disruptor,
1003                        expression_arena,
1004                        diagnostic_filter_arena,
1005                    )?;
1006                    let reject_uniformity = self.process_block(
1007                        reject,
1008                        other_functions,
1009                        branch_disruptor,
1010                        expression_arena,
1011                        diagnostic_filter_arena,
1012                    )?;
1013                    accept_uniformity | reject_uniformity
1014                }
1015                S::Switch {
1016                    selector,
1017                    ref cases,
1018                } => {
1019                    let selector_nur = self.add_ref(selector);
1020                    let branch_disruptor =
1021                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
1022                    let mut uniformity = FunctionUniformity::new();
1023                    let mut case_disruptor = branch_disruptor;
1024                    for case in cases.iter() {
1025                        let case_uniformity = self.process_block(
1026                            &case.body,
1027                            other_functions,
1028                            case_disruptor,
1029                            expression_arena,
1030                            diagnostic_filter_arena,
1031                        )?;
1032                        case_disruptor = if case.fall_through {
1033                            case_disruptor.or(case_uniformity.exit_disruptor())
1034                        } else {
1035                            branch_disruptor
1036                        };
1037                        uniformity = uniformity | case_uniformity;
1038                    }
1039                    uniformity
1040                }
1041                S::Loop {
1042                    ref body,
1043                    ref continuing,
1044                    break_if,
1045                } => {
1046                    let body_uniformity = self.process_block(
1047                        body,
1048                        other_functions,
1049                        disruptor,
1050                        expression_arena,
1051                        diagnostic_filter_arena,
1052                    )?;
1053                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
1054                    let continuing_uniformity = self.process_block(
1055                        continuing,
1056                        other_functions,
1057                        continuing_disruptor,
1058                        expression_arena,
1059                        diagnostic_filter_arena,
1060                    )?;
1061                    if let Some(expr) = break_if {
1062                        let _ = self.add_ref(expr);
1063                    }
1064                    body_uniformity | continuing_uniformity
1065                }
1066                S::Return { value } => FunctionUniformity {
1067                    result: Uniformity {
1068                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
1069                        requirements: UniformityRequirements::empty(),
1070                    },
1071                    exit: if disruptor.is_some() {
1072                        ExitFlags::MAY_RETURN
1073                    } else {
1074                        ExitFlags::empty()
1075                    },
1076                },
1077                // Here and below, the used expressions are already emitted,
1078                // and their results do not affect the function return value,
1079                // so we can ignore their non-uniformity.
1080                S::Store { pointer, value } => {
1081                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
1082                    let _ = self.add_ref(value);
1083                    FunctionUniformity::new()
1084                }
1085                S::ImageStore {
1086                    image,
1087                    coordinate,
1088                    array_index,
1089                    value,
1090                } => {
1091                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
1092                    if let Some(expr) = array_index {
1093                        let _ = self.add_ref(expr);
1094                    }
1095                    let _ = self.add_ref(coordinate);
1096                    let _ = self.add_ref(value);
1097                    FunctionUniformity::new()
1098                }
1099                S::Call {
1100                    function,
1101                    ref arguments,
1102                    result: _,
1103                } => {
1104                    for &argument in arguments {
1105                        let _ = self.add_ref(argument);
1106                    }
1107                    let info = &other_functions[function.index()];
1108                    //Note: the result is validated by the Validator, not here
1109                    self.process_call(info, arguments, expression_arena)?
1110                }
1111                S::Atomic {
1112                    pointer,
1113                    ref fun,
1114                    value,
1115                    result: _,
1116                } => {
1117                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
1118                    let _ = self.add_ref(value);
1119                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
1120                        let _ = self.add_ref(cmp);
1121                    }
1122                    FunctionUniformity::new()
1123                }
1124                S::ImageAtomic {
1125                    image,
1126                    coordinate,
1127                    array_index,
1128                    fun: _,
1129                    value,
1130                } => {
1131                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
1132                    let _ = self.add_ref(coordinate);
1133                    if let Some(expr) = array_index {
1134                        let _ = self.add_ref(expr);
1135                    }
1136                    let _ = self.add_ref(value);
1137                    FunctionUniformity::new()
1138                }
1139                S::RayQuery { query, ref fun } => {
1140                    let _ = self.add_ref(query);
1141                    match *fun {
1142                        crate::RayQueryFunction::Initialize {
1143                            acceleration_structure,
1144                            descriptor,
1145                        } => {
1146                            let _ = self.add_ref(acceleration_structure);
1147                            let _ = self.add_ref(descriptor);
1148                        }
1149                        crate::RayQueryFunction::Proceed { result: _ } => {}
1150                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1151                            let _ = self.add_ref(hit_t);
1152                        }
1153                        crate::RayQueryFunction::ConfirmIntersection => {}
1154                        crate::RayQueryFunction::Terminate => {}
1155                    }
1156                    FunctionUniformity::new()
1157                }
1158                S::MeshFunction(func) => {
1159                    self.available_stages |= ShaderStages::MESH;
1160                    match &func {
1161                        // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it.
1162                        &crate::MeshFunction::SetMeshOutputs {
1163                            vertex_count,
1164                            primitive_count,
1165                        } => {
1166                            let _ = self.add_ref(vertex_count);
1167                            let _ = self.add_ref(primitive_count);
1168                            FunctionUniformity::new()
1169                        }
1170                        &crate::MeshFunction::SetVertex { index, value }
1171                        | &crate::MeshFunction::SetPrimitive { index, value } => {
1172                            let _ = self.add_ref(index);
1173                            let _ = self.add_ref(value);
1174                            let ty = self.expressions[value.index()].ty.handle().ok_or(
1175                                FunctionError::InvalidMeshShaderOutputType(value).with_span(),
1176                            )?;
1177
1178                            if matches!(func, crate::MeshFunction::SetVertex { .. }) {
1179                                self.try_update_mesh_vertex_type(ty, value)?;
1180                            } else {
1181                                self.try_update_mesh_primitive_type(ty, value)?;
1182                            };
1183
1184                            FunctionUniformity::new()
1185                        }
1186                    }
1187                }
1188                S::SubgroupBallot {
1189                    result: _,
1190                    predicate,
1191                } => {
1192                    if let Some(predicate) = predicate {
1193                        let _ = self.add_ref(predicate);
1194                    }
1195                    FunctionUniformity::new()
1196                }
1197                S::SubgroupCollectiveOperation {
1198                    op: _,
1199                    collective_op: _,
1200                    argument,
1201                    result: _,
1202                } => {
1203                    let _ = self.add_ref(argument);
1204                    FunctionUniformity::new()
1205                }
1206                S::SubgroupGather {
1207                    mode,
1208                    argument,
1209                    result: _,
1210                } => {
1211                    let _ = self.add_ref(argument);
1212                    match mode {
1213                        crate::GatherMode::BroadcastFirst => {}
1214                        crate::GatherMode::Broadcast(index)
1215                        | crate::GatherMode::Shuffle(index)
1216                        | crate::GatherMode::ShuffleDown(index)
1217                        | crate::GatherMode::ShuffleUp(index)
1218                        | crate::GatherMode::ShuffleXor(index)
1219                        | crate::GatherMode::QuadBroadcast(index) => {
1220                            let _ = self.add_ref(index);
1221                        }
1222                        crate::GatherMode::QuadSwap(_) => {}
1223                    }
1224                    FunctionUniformity::new()
1225                }
1226            };
1227
1228            disruptor = disruptor.or(uniformity.exit_disruptor());
1229            combined_uniformity = combined_uniformity | uniformity;
1230        }
1231        Ok(combined_uniformity)
1232    }
1233
1234    /// Note the type of value passed to [`SetVertex`].
1235    ///
1236    /// Record that this function passed a value of type `ty` as the second
1237    /// argument to the [`SetVertex`] builtin function. All calls to
1238    /// `SetVertex` must pass the same type, and this must match the
1239    /// function's [`vertex_output_type`].
1240    ///
1241    /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex
1242    /// [`vertex_output_type`]: crate::ir::MeshStageInfo::vertex_output_type
1243    fn try_update_mesh_vertex_type(
1244        &mut self,
1245        ty: Handle<crate::Type>,
1246        value: Handle<crate::Expression>,
1247    ) -> Result<(), WithSpan<FunctionError>> {
1248        if let &Some(ref existing) = &self.mesh_shader_info.vertex_type {
1249            if existing.0 != ty {
1250                return Err(
1251                    FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span()
1252                );
1253            }
1254        } else {
1255            self.mesh_shader_info.vertex_type = Some((ty, value));
1256        }
1257        Ok(())
1258    }
1259
1260    /// Note the type of value passed to [`SetPrimitive`].
1261    ///
1262    /// Record that this function passed a value of type `ty` as the second
1263    /// argument to the [`SetPrimitive`] builtin function. All calls to
1264    /// `SetPrimitive` must pass the same type, and this must match the
1265    /// function's [`primitive_output_type`].
1266    ///
1267    /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive
1268    /// [`primitive_output_type`]: crate::ir::MeshStageInfo::primitive_output_type
1269    fn try_update_mesh_primitive_type(
1270        &mut self,
1271        ty: Handle<crate::Type>,
1272        value: Handle<crate::Expression>,
1273    ) -> Result<(), WithSpan<FunctionError>> {
1274        if let &Some(ref existing) = &self.mesh_shader_info.primitive_type {
1275            if existing.0 != ty {
1276                return Err(
1277                    FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span()
1278                );
1279            }
1280        } else {
1281            self.mesh_shader_info.primitive_type = Some((ty, value));
1282        }
1283        Ok(())
1284    }
1285
1286    /// Update this function's mesh shader info, given that it calls `callee`.
1287    fn try_update_mesh_info(
1288        &mut self,
1289        callee: &FunctionMeshShaderInfo,
1290    ) -> Result<(), WithSpan<FunctionError>> {
1291        if let &Some(ref other_vertex) = &callee.vertex_type {
1292            self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?;
1293        }
1294        if let &Some(ref other_primitive) = &callee.primitive_type {
1295            self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?;
1296        }
1297        Ok(())
1298    }
1299}
1300
1301impl ModuleInfo {
1302    /// Populates `self.const_expression_types`
1303    pub(super) fn process_const_expression(
1304        &mut self,
1305        handle: Handle<crate::Expression>,
1306        resolve_context: &ResolveContext,
1307        gctx: crate::proc::GlobalCtx,
1308    ) -> Result<(), super::ConstExpressionError> {
1309        self.const_expression_types[handle.index()] =
1310            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1311        Ok(())
1312    }
1313
1314    /// Builds the `FunctionInfo` based on the function, and validates the
1315    /// uniform control flow if required by the expressions of this function.
1316    pub(super) fn process_function(
1317        &self,
1318        fun: &crate::Function,
1319        module: &crate::Module,
1320        flags: ValidationFlags,
1321        capabilities: super::Capabilities,
1322    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1323        let mut info = FunctionInfo {
1324            flags,
1325            available_stages: ShaderStages::all(),
1326            uniformity: Uniformity::new(),
1327            may_kill: false,
1328            sampling_set: crate::FastHashSet::default(),
1329            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1330            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1331            sampling: crate::FastHashSet::default(),
1332            dual_source_blending: false,
1333            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1334            mesh_shader_info: FunctionMeshShaderInfo::default(),
1335        };
1336        let resolve_context =
1337            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1338
1339        for (handle, _) in fun.expressions.iter() {
1340            if let Err(source) = info.process_expression(
1341                handle,
1342                &fun.expressions,
1343                &self.functions,
1344                &resolve_context,
1345                capabilities,
1346            ) {
1347                return Err(FunctionError::Expression { handle, source }
1348                    .with_span_handle(handle, &fun.expressions));
1349            }
1350        }
1351
1352        for (_, expr) in fun.local_variables.iter() {
1353            if let Some(init) = expr.init {
1354                let _ = info.add_ref(init);
1355            }
1356        }
1357
1358        let uniformity = info.process_block(
1359            &fun.body,
1360            &self.functions,
1361            None,
1362            &fun.expressions,
1363            &module.diagnostic_filters,
1364        )?;
1365        info.uniformity = uniformity.result;
1366        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1367
1368        // If there are any globals referenced directly by a named expression,
1369        // ensure they are marked as used even if they are not referenced
1370        // anywhere else. An important case where this matters is phony
1371        // assignments used to include a global in the shader's resource
1372        // interface. https://www.w3.org/TR/WGSL/#phony-assignment-section
1373        for &handle in fun.named_expressions.keys() {
1374            if let Some(global) = info[handle].assignable_global {
1375                if info.global_uses[global.index()].is_empty() {
1376                    info.global_uses[global.index()] = GlobalUse::QUERY;
1377                }
1378            }
1379        }
1380
1381        Ok(info)
1382    }
1383
1384    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1385        &self.entry_points[index]
1386    }
1387}
1388
/// Drives `FunctionInfo::process_expression` and `FunctionInfo::process_block`
/// over a small hand-built set of globals, expressions and statements, and
/// checks the resulting reference counts, `GlobalUse` flags and
/// `FunctionUniformity` values.
///
/// The same `info` value is fed to every phase, so each asserted reference
/// count and usage flag reflects everything processed up to that point; the
/// phases below must therefore stay in this order.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    // A single two-component `f32` vector type, shared by both globals.
    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Two globals that differ only in address space: `Handle` for the one the
    // test treats as non-uniform, `Uniform` for the one it treats as uniform.
    let mut global_var_arena = Arena::new();
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    // Expressions are appended in a fixed order: literal, derivative,
    // non-uniform global, uniform global, array length, access index.
    // Each `range_from` call is made right after the expressions it is meant
    // to cover have been appended (presumably spanning from the given index to
    // the arena's end at that moment).
    let mut expressions = Arena::new();
    // checks the uniform control flow
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // checks the non-uniform control flow
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // checks the transitive WRITE flag
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    // Fresh analysis state: every validation flag and shader stage enabled,
    // with one empty slot per global and per expression.
    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
        mesh_shader_info: FunctionMeshShaderInfo::default(),
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };

    // Phase 0: analyze every expression on its own, before any statement is
    // processed. Only expression-to-expression references exist at this point.
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Each global expression is referenced once (by `access_expr` and
    // `query_expr` respectively); nothing references those two yet.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Phase 1: an `If` conditioned on the uniform global, with the derivative
    // stored in the `reject` branch. Expected to succeed and to report the
    // `DERIVATIVE` requirement with no non-uniform result.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Phase 2: the same store, but now in the `accept` branch of an `If`
    // conditioned on the non-uniform global.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the requirement compiled out, only the reference count is
            // checked; `block_info` itself is not inspected in this branch.
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            // With the requirement active, the block must be rejected, naming
            // the derivative as the offender and the non-uniform global
            // expression as the disruptor.
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // Test that the same block passes when the `derivative_uniformity`
            // diagnostic is switched off through a diagnostic filter.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            // Shadows `info` with a clone that carries the filter; the outer
            // `info` is not touched by this second run.
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Phase 3: returning the non-uniform global expression. The block result
    // must name that expression as non-uniform and flag `MAY_RETURN`.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Check that uniformity requirements reach through a pointer
    // Phase 4: store through `access_expr` (derived from the non-uniform
    // global), then `Kill`, then return that same pointer expression.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    // The store through the access chain must be recorded as a write on the
    // underlying global.
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}