naga/valid/
handles.rs

1//! Implementation of `Validator::validate_module_handles`.
2
3use core::{convert::TryInto, hash::Hash};
4
5use super::{TypeError, ValidationError};
6use crate::non_max_u32::NonMaxU32;
7use crate::{
8    arena::{BadHandle, BadRangeError},
9    diagnostic_filter::DiagnosticFilterNode,
10    EntryPoint, Handle,
11};
12use crate::{Arena, UniqueArena};
13
14use alloc::string::ToString;
15
16impl super::Validator {
17    /// Validates that all handles within `module` are:
18    ///
19    /// * Valid, in the sense that they contain indices within each arena structure inside the
20    ///   [`crate::Module`] type.
21    /// * No arena contents contain any items that have forward dependencies; that is, the value
22    ///   associated with a handle only may contain references to handles in the same arena that
23    ///   were constructed before it.
24    ///
25    /// By validating the above conditions, we free up subsequent logic to assume that handle
26    /// accesses are infallible.
27    ///
28    /// # Errors
29    ///
30    /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
31    /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
32    /// validation pass.
33    pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
34        let &crate::Module {
35            ref constants,
36            ref overrides,
37            ref entry_points,
38            ref functions,
39            ref global_variables,
40            ref types,
41            ref special_types,
42            ref global_expressions,
43            ref diagnostic_filters,
44            ref diagnostic_filter_leaf,
45            ref doc_comments,
46        } = module;
47
48        // Because types can refer to global expressions and vice versa, to
49        // ensure the overall structure is free of cycles, we must traverse them
50        // both in tandem.
51        //
52        // Try to visit all types and global expressions in an order such that
53        // each item refers only to previously visited items. If we succeed,
54        // that shows that there cannot be any cycles, since walking any edge
55        // advances you towards the beginning of the visiting order.
56        //
57        // Validate all the handles in types and expressions as we traverse the
58        // arenas.
59        let mut global_exprs_iter = global_expressions.iter().peekable();
60        for (th, t) in types.iter() {
61            // Imagine the `for` loop and `global_exprs_iter` as two fingers
62            // walking the type and global expression arenas. They don't visit
63            // elements at the same rate: sometimes one processes a bunch of
64            // elements while the other one stays still. But at each point, they
65            // check that the two ranges of elements they've visited only refer
66            // to other elements in those ranges.
67            //
68            // For brevity, we'll say 'handles behind `global_exprs_iter`' to
69            // mean handles that have already been produced by
70            // `global_exprs_iter`. Once `global_exprs_iter` returns `None`, all
71            // global expression handles are 'behind' it.
72            //
73            // At this point:
74            //
75            // - All types visited by prior iterations (that is, before
76            //   `th`/`t`) refer only to expressions behind `global_exprs_iter`.
77            //
78            //   On the first iteration, this is obviously true: there are no
79            //   prior iterations, and `global_exprs_iter` hasn't produced
80            //   anything yet. At the bottom of the loop, we'll claim that it's
81            //   true for `th`/`t` as well, so the condition remains true when
82            //   we advance to the next type.
83            //
84            // - All expressions behind `global_exprs_iter` refer only to
85            //   previously visited types.
86            //
87            //   Again, trivially true at the start, and we'll show it's true
88            //   about each expression that `global_exprs_iter` produces.
89            //
90            // Once we also check that arena elements only refer to prior
91            // elements in that arena, we can see that `th`/`t` does not
92            // participate in a cycle: it only refers to previously visited
93            // types and expressions behind `global_exprs_iter`, and none of
94            // those refer to `th`/`t`, because they passed the same checks
95            // before we reached `th`/`t`.
96            if let Some(max_expr) = Self::validate_type_handles((th, t), overrides)? {
97                max_expr.check_valid_for(global_expressions)?;
98                // Since `t` refers to `max_expr`, if we want our invariants to
99                // remain true, we must advance `global_exprs_iter` beyond
100                // `max_expr`.
101                while let Some((eh, e)) = global_exprs_iter.next_if(|&(eh, _)| eh <= max_expr) {
102                    if let Some(max_type) =
103                        Self::validate_const_expression_handles((eh, e), constants, overrides)?
104                    {
105                        // Show that `eh` refers only to previously visited types.
106                        th.check_dep(max_type)?;
107                    }
108                    // We've advanced `global_exprs_iter` past `eh` already. But
109                    // since we now know that `eh` refers only to previously
110                    // visited types, it is again true that all expressions
111                    // behind `global_exprs_iter` refer only to previously
112                    // visited types. So we can continue to the next expression.
113                }
114            }
115
116            // Here we know that if `th` refers to any expressions at all,
117            // `max_expr` is the latest one. And we know that `max_expr` is
118            // behind `global_exprs_iter`. So `th` refers only to expressions
119            // behind `global_exprs_iter`, and the invariants will still be
120            // true on the next iteration.
121        }
122
123        // Since we also enforced the usual intra-arena rules that expressions
124        // refer only to prior expressions, expressions can only form cycles if
125        // they include types. But we've shown that all types are acyclic, so
126        // all expressions must be acyclic as well.
127        //
128        // Validate the remaining expressions normally.
129        for handle_and_expr in global_exprs_iter {
130            Self::validate_const_expression_handles(handle_and_expr, constants, overrides)?;
131        }
132
133        let validate_type = |handle| Self::validate_type_handle(handle, types);
134        let validate_const_expr =
135            |handle| Self::validate_expression_handle(handle, global_expressions);
136
137        for (_handle, constant) in constants.iter() {
138            let &crate::Constant { name: _, ty, init } = constant;
139            validate_type(ty)?;
140            validate_const_expr(init)?;
141        }
142
143        for (_handle, r#override) in overrides.iter() {
144            let &crate::Override {
145                name: _,
146                id: _,
147                ty,
148                init,
149            } = r#override;
150            validate_type(ty)?;
151            if let Some(init_expr) = init {
152                validate_const_expr(init_expr)?;
153            }
154        }
155
156        for (_handle, global_variable) in global_variables.iter() {
157            let &crate::GlobalVariable {
158                name: _,
159                space: _,
160                binding: _,
161                ty,
162                init,
163            } = global_variable;
164            validate_type(ty)?;
165            if let Some(init_expr) = init {
166                validate_const_expr(init_expr)?;
167            }
168        }
169
170        let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
171            let &crate::Function {
172                name: _,
173                ref arguments,
174                ref result,
175                ref local_variables,
176                ref expressions,
177                ref named_expressions,
178                ref body,
179                ref diagnostic_filter_leaf,
180            } = function;
181
182            for arg in arguments.iter() {
183                let &crate::FunctionArgument {
184                    name: _,
185                    ty,
186                    binding: _,
187                } = arg;
188                validate_type(ty)?;
189            }
190
191            if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
192                validate_type(ty)?;
193            }
194
195            for (_handle, local_variable) in local_variables.iter() {
196                let &crate::LocalVariable { name: _, ty, init } = local_variable;
197                validate_type(ty)?;
198                if let Some(init) = init {
199                    Self::validate_expression_handle(init, expressions)?;
200                }
201            }
202
203            for handle in named_expressions.keys().copied() {
204                Self::validate_expression_handle(handle, expressions)?;
205            }
206
207            for handle_and_expr in expressions.iter() {
208                Self::validate_expression_handles(
209                    handle_and_expr,
210                    constants,
211                    overrides,
212                    types,
213                    local_variables,
214                    global_variables,
215                    functions,
216                    function_handle,
217                )?;
218            }
219
220            Self::validate_block_handles(body, expressions, functions)?;
221
222            if let Some(handle) = *diagnostic_filter_leaf {
223                handle.check_valid_for(diagnostic_filters)?;
224            }
225
226            Ok(())
227        };
228
229        for entry_point in entry_points.iter() {
230            validate_function(None, &entry_point.function)?;
231            if let Some(sizes) = entry_point.workgroup_size_overrides {
232                for size in sizes.iter().filter_map(|x| *x) {
233                    validate_const_expr(size)?;
234                }
235            }
236        }
237
238        for (function_handle, function) in functions.iter() {
239            validate_function(Some(function_handle), function)?;
240        }
241
242        if let Some(ty) = special_types.ray_desc {
243            validate_type(ty)?;
244        }
245        if let Some(ty) = special_types.ray_intersection {
246            validate_type(ty)?;
247        }
248        if let Some(ty) = special_types.ray_vertex_return {
249            validate_type(ty)?;
250        }
251
252        for (handle, _node) in diagnostic_filters.iter() {
253            let DiagnosticFilterNode { inner: _, parent } = diagnostic_filters[handle];
254            handle.check_dep_opt(parent)?;
255        }
256        if let Some(handle) = *diagnostic_filter_leaf {
257            handle.check_valid_for(diagnostic_filters)?;
258        }
259
260        if let Some(doc_comments) = doc_comments.as_ref() {
261            let crate::DocComments {
262                module: _,
263                types: ref doc_comments_for_types,
264                struct_members: ref doc_comments_for_struct_members,
265                entry_points: ref doc_comments_for_entry_points,
266                functions: ref doc_comments_for_functions,
267                constants: ref doc_comments_for_constants,
268                global_variables: ref doc_comments_for_global_variables,
269            } = **doc_comments;
270
271            for (&ty, _) in doc_comments_for_types.iter() {
272                validate_type(ty)?;
273            }
274
275            for (&(ty, struct_member_index), _) in doc_comments_for_struct_members.iter() {
276                validate_type(ty)?;
277                let struct_type = types.get_handle(ty).unwrap();
278                match struct_type.inner {
279                    crate::TypeInner::Struct {
280                        ref members,
281                        span: ref _span,
282                    } => {
283                        (0..members.len())
284                            .contains(&struct_member_index)
285                            .then_some(())
286                            // TODO: what errors should this be?
287                            .ok_or_else(|| ValidationError::Type {
288                                handle: ty,
289                                name: struct_type.name.as_ref().map_or_else(
290                                    || "members length incorrect".to_string(),
291                                    |name| name.to_string(),
292                                ),
293                                source: TypeError::InvalidData(ty),
294                            })?;
295                    }
296                    _ => {
297                        // TODO: internal error ? We should never get here.
298                        // If entering there, it's probably that we forgot to adjust a handle in the compact phase.
299                        return Err(ValidationError::Type {
300                            handle: ty,
301                            name: struct_type
302                                .name
303                                .as_ref()
304                                .map_or_else(|| "Unknown".to_string(), |name| name.to_string()),
305                            source: TypeError::InvalidData(ty),
306                        });
307                    }
308                }
309                for (&function, _) in doc_comments_for_functions.iter() {
310                    Self::validate_function_handle(function, functions)?;
311                }
312                for (&entry_point_index, _) in doc_comments_for_entry_points.iter() {
313                    Self::validate_entry_point_index(entry_point_index, entry_points)?;
314                }
315                for (&constant, _) in doc_comments_for_constants.iter() {
316                    Self::validate_constant_handle(constant, constants)?;
317                }
318                for (&global_variable, _) in doc_comments_for_global_variables.iter() {
319                    Self::validate_global_variable_handle(global_variable, global_variables)?;
320                }
321            }
322        }
323
324        Ok(())
325    }
326
327    fn validate_type_handle(
328        handle: Handle<crate::Type>,
329        types: &UniqueArena<crate::Type>,
330    ) -> Result<(), InvalidHandleError> {
331        handle.check_valid_for_uniq(types).map(|_| ())
332    }
333
334    fn validate_constant_handle(
335        handle: Handle<crate::Constant>,
336        constants: &Arena<crate::Constant>,
337    ) -> Result<(), InvalidHandleError> {
338        handle.check_valid_for(constants).map(|_| ())
339    }
340
341    fn validate_global_variable_handle(
342        handle: Handle<crate::GlobalVariable>,
343        global_variables: &Arena<crate::GlobalVariable>,
344    ) -> Result<(), InvalidHandleError> {
345        handle.check_valid_for(global_variables).map(|_| ())
346    }
347
348    fn validate_override_handle(
349        handle: Handle<crate::Override>,
350        overrides: &Arena<crate::Override>,
351    ) -> Result<(), InvalidHandleError> {
352        handle.check_valid_for(overrides).map(|_| ())
353    }
354
355    fn validate_expression_handle(
356        handle: Handle<crate::Expression>,
357        expressions: &Arena<crate::Expression>,
358    ) -> Result<(), InvalidHandleError> {
359        handle.check_valid_for(expressions).map(|_| ())
360    }
361
362    fn validate_function_handle(
363        handle: Handle<crate::Function>,
364        functions: &Arena<crate::Function>,
365    ) -> Result<(), InvalidHandleError> {
366        handle.check_valid_for(functions).map(|_| ())
367    }
368
369    /// Validate all handles that occur in `ty`, whose handle is `handle`.
370    ///
371    /// If `ty` refers to any expressions, return the highest-indexed expression
372    /// handle that it uses. This is used for detecting cycles between the
373    /// expression and type arenas.
374    fn validate_type_handles(
375        (handle, ty): (Handle<crate::Type>, &crate::Type),
376        overrides: &Arena<crate::Override>,
377    ) -> Result<Option<Handle<crate::Expression>>, InvalidHandleError> {
378        let max_expr = match ty.inner {
379            crate::TypeInner::Scalar { .. }
380            | crate::TypeInner::Vector { .. }
381            | crate::TypeInner::Matrix { .. }
382            | crate::TypeInner::ValuePointer { .. }
383            | crate::TypeInner::Atomic { .. }
384            | crate::TypeInner::Image { .. }
385            | crate::TypeInner::Sampler { .. }
386            | crate::TypeInner::AccelerationStructure { .. }
387            | crate::TypeInner::RayQuery { .. } => None,
388            crate::TypeInner::Pointer { base, space: _ } => {
389                handle.check_dep(base)?;
390                None
391            }
392            crate::TypeInner::Array { base, size, .. }
393            | crate::TypeInner::BindingArray { base, size, .. } => {
394                handle.check_dep(base)?;
395                match size {
396                    crate::ArraySize::Pending(h) => {
397                        Self::validate_override_handle(h, overrides)?;
398                        let r#override = &overrides[h];
399                        handle.check_dep(r#override.ty)?;
400                        r#override.init
401                    }
402                    crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
403                }
404            }
405            crate::TypeInner::Struct {
406                ref members,
407                span: _,
408            } => {
409                handle.check_dep_iter(members.iter().map(|m| m.ty))?;
410                None
411            }
412        };
413
414        Ok(max_expr)
415    }
416
417    fn validate_entry_point_index(
418        entry_point_index: usize,
419        entry_points: &[EntryPoint],
420    ) -> Result<(), InvalidHandleError> {
421        (0..entry_points.len())
422            .contains(&entry_point_index)
423            .then_some(())
424            .ok_or_else(|| {
425                BadHandle {
426                    kind: "EntryPoint",
427                    index: entry_point_index,
428                }
429                .into()
430            })
431    }
432
433    /// Validate all handles that occur in `expression`, whose handle is `handle`.
434    ///
435    /// If `expression` refers to any `Type`s, return the highest-indexed type
436    /// handle that it uses. This is used for detecting cycles between the
437    /// expression and type arenas.
438    fn validate_const_expression_handles(
439        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
440        constants: &Arena<crate::Constant>,
441        overrides: &Arena<crate::Override>,
442    ) -> Result<Option<Handle<crate::Type>>, InvalidHandleError> {
443        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
444        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
445
446        let max_type = match *expression {
447            crate::Expression::Literal(_) => None,
448            crate::Expression::Constant(constant) => {
449                validate_constant(constant)?;
450                handle.check_dep(constants[constant].init)?;
451                None
452            }
453            crate::Expression::Override(r#override) => {
454                validate_override(r#override)?;
455                if let Some(init) = overrides[r#override].init {
456                    handle.check_dep(init)?;
457                }
458                None
459            }
460            crate::Expression::ZeroValue(ty) => Some(ty),
461            crate::Expression::Compose { ty, ref components } => {
462                handle.check_dep_iter(components.iter().copied())?;
463                Some(ty)
464            }
465            _ => None,
466        };
467        Ok(max_type)
468    }
469
470    #[allow(clippy::too_many_arguments)]
471    fn validate_expression_handles(
472        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
473        constants: &Arena<crate::Constant>,
474        overrides: &Arena<crate::Override>,
475        types: &UniqueArena<crate::Type>,
476        local_variables: &Arena<crate::LocalVariable>,
477        global_variables: &Arena<crate::GlobalVariable>,
478        functions: &Arena<crate::Function>,
479        // The handle of the current function or `None` if it's an entry point
480        current_function: Option<Handle<crate::Function>>,
481    ) -> Result<(), InvalidHandleError> {
482        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
483        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
484        let validate_type = |handle| Self::validate_type_handle(handle, types);
485
486        match *expression {
487            crate::Expression::Access { base, index } => {
488                handle.check_dep(base)?.check_dep(index)?;
489            }
490            crate::Expression::AccessIndex { base, .. } => {
491                handle.check_dep(base)?;
492            }
493            crate::Expression::Splat { value, .. } => {
494                handle.check_dep(value)?;
495            }
496            crate::Expression::Swizzle { vector, .. } => {
497                handle.check_dep(vector)?;
498            }
499            crate::Expression::Literal(_) => {}
500            crate::Expression::Constant(constant) => {
501                validate_constant(constant)?;
502            }
503            crate::Expression::Override(r#override) => {
504                validate_override(r#override)?;
505            }
506            crate::Expression::ZeroValue(ty) => {
507                validate_type(ty)?;
508            }
509            crate::Expression::Compose { ty, ref components } => {
510                validate_type(ty)?;
511                handle.check_dep_iter(components.iter().copied())?;
512            }
513            crate::Expression::FunctionArgument(_arg_idx) => (),
514            crate::Expression::GlobalVariable(global_variable) => {
515                global_variable.check_valid_for(global_variables)?;
516            }
517            crate::Expression::LocalVariable(local_variable) => {
518                local_variable.check_valid_for(local_variables)?;
519            }
520            crate::Expression::Load { pointer } => {
521                handle.check_dep(pointer)?;
522            }
523            crate::Expression::ImageSample {
524                image,
525                sampler,
526                gather: _,
527                coordinate,
528                array_index,
529                offset,
530                level,
531                depth_ref,
532                clamp_to_edge: _,
533            } => {
534                handle
535                    .check_dep(image)?
536                    .check_dep(sampler)?
537                    .check_dep(coordinate)?
538                    .check_dep_opt(array_index)?
539                    .check_dep_opt(offset)?;
540
541                match level {
542                    crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
543                    crate::SampleLevel::Exact(expr) => {
544                        handle.check_dep(expr)?;
545                    }
546                    crate::SampleLevel::Bias(expr) => {
547                        handle.check_dep(expr)?;
548                    }
549                    crate::SampleLevel::Gradient { x, y } => {
550                        handle.check_dep(x)?.check_dep(y)?;
551                    }
552                };
553
554                handle.check_dep_opt(depth_ref)?;
555            }
556            crate::Expression::ImageLoad {
557                image,
558                coordinate,
559                array_index,
560                sample,
561                level,
562            } => {
563                handle
564                    .check_dep(image)?
565                    .check_dep(coordinate)?
566                    .check_dep_opt(array_index)?
567                    .check_dep_opt(sample)?
568                    .check_dep_opt(level)?;
569            }
570            crate::Expression::ImageQuery { image, query } => {
571                handle.check_dep(image)?;
572                match query {
573                    crate::ImageQuery::Size { level } => {
574                        handle.check_dep_opt(level)?;
575                    }
576                    crate::ImageQuery::NumLevels
577                    | crate::ImageQuery::NumLayers
578                    | crate::ImageQuery::NumSamples => (),
579                };
580            }
581            crate::Expression::Unary {
582                op: _,
583                expr: operand,
584            } => {
585                handle.check_dep(operand)?;
586            }
587            crate::Expression::Binary { op: _, left, right } => {
588                handle.check_dep(left)?.check_dep(right)?;
589            }
590            crate::Expression::Select {
591                condition,
592                accept,
593                reject,
594            } => {
595                handle
596                    .check_dep(condition)?
597                    .check_dep(accept)?
598                    .check_dep(reject)?;
599            }
600            crate::Expression::Derivative { expr: argument, .. } => {
601                handle.check_dep(argument)?;
602            }
603            crate::Expression::Relational { fun: _, argument } => {
604                handle.check_dep(argument)?;
605            }
606            crate::Expression::Math {
607                fun: _,
608                arg,
609                arg1,
610                arg2,
611                arg3,
612            } => {
613                handle
614                    .check_dep(arg)?
615                    .check_dep_opt(arg1)?
616                    .check_dep_opt(arg2)?
617                    .check_dep_opt(arg3)?;
618            }
619            crate::Expression::As {
620                expr: input,
621                kind: _,
622                convert: _,
623            } => {
624                handle.check_dep(input)?;
625            }
626            crate::Expression::CallResult(function) => {
627                Self::validate_function_handle(function, functions)?;
628                if let Some(handle) = current_function {
629                    handle.check_dep(function)?;
630                }
631            }
632            crate::Expression::AtomicResult { .. }
633            | crate::Expression::RayQueryProceedResult
634            | crate::Expression::SubgroupBallotResult
635            | crate::Expression::SubgroupOperationResult { .. }
636            | crate::Expression::WorkGroupUniformLoadResult { .. } => (),
637            crate::Expression::ArrayLength(array) => {
638                handle.check_dep(array)?;
639            }
640            crate::Expression::RayQueryGetIntersection {
641                query,
642                committed: _,
643            }
644            | crate::Expression::RayQueryVertexPositions {
645                query,
646                committed: _,
647            } => {
648                handle.check_dep(query)?;
649            }
650        }
651        Ok(())
652    }
653
654    fn validate_block_handles(
655        block: &crate::Block,
656        expressions: &Arena<crate::Expression>,
657        functions: &Arena<crate::Function>,
658    ) -> Result<(), InvalidHandleError> {
659        let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
660        let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
661        let validate_expr_opt = |handle_opt| {
662            if let Some(handle) = handle_opt {
663                validate_expr(handle)?;
664            }
665            Ok(())
666        };
667
668        block.iter().try_for_each(|stmt| match *stmt {
669            crate::Statement::Emit(ref expr_range) => {
670                expr_range.check_valid_for(expressions)?;
671                Ok(())
672            }
673            crate::Statement::Block(ref block) => {
674                validate_block(block)?;
675                Ok(())
676            }
677            crate::Statement::If {
678                condition,
679                ref accept,
680                ref reject,
681            } => {
682                validate_expr(condition)?;
683                validate_block(accept)?;
684                validate_block(reject)?;
685                Ok(())
686            }
687            crate::Statement::Switch {
688                selector,
689                ref cases,
690            } => {
691                validate_expr(selector)?;
692                for &crate::SwitchCase {
693                    value: _,
694                    ref body,
695                    fall_through: _,
696                } in cases
697                {
698                    validate_block(body)?;
699                }
700                Ok(())
701            }
702            crate::Statement::Loop {
703                ref body,
704                ref continuing,
705                break_if,
706            } => {
707                validate_block(body)?;
708                validate_block(continuing)?;
709                validate_expr_opt(break_if)?;
710                Ok(())
711            }
712            crate::Statement::Return { value } => validate_expr_opt(value),
713            crate::Statement::Store { pointer, value } => {
714                validate_expr(pointer)?;
715                validate_expr(value)?;
716                Ok(())
717            }
718            crate::Statement::ImageStore {
719                image,
720                coordinate,
721                array_index,
722                value,
723            } => {
724                validate_expr(image)?;
725                validate_expr(coordinate)?;
726                validate_expr_opt(array_index)?;
727                validate_expr(value)?;
728                Ok(())
729            }
730            crate::Statement::Atomic {
731                pointer,
732                fun,
733                value,
734                result,
735            } => {
736                validate_expr(pointer)?;
737                match fun {
738                    crate::AtomicFunction::Add
739                    | crate::AtomicFunction::Subtract
740                    | crate::AtomicFunction::And
741                    | crate::AtomicFunction::ExclusiveOr
742                    | crate::AtomicFunction::InclusiveOr
743                    | crate::AtomicFunction::Min
744                    | crate::AtomicFunction::Max => (),
745                    crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
746                };
747                validate_expr(value)?;
748                if let Some(result) = result {
749                    validate_expr(result)?;
750                }
751                Ok(())
752            }
753            crate::Statement::ImageAtomic {
754                image,
755                coordinate,
756                array_index,
757                fun: _,
758                value,
759            } => {
760                validate_expr(image)?;
761                validate_expr(coordinate)?;
762                validate_expr_opt(array_index)?;
763                validate_expr(value)?;
764                Ok(())
765            }
766            crate::Statement::WorkGroupUniformLoad { pointer, result } => {
767                validate_expr(pointer)?;
768                validate_expr(result)?;
769                Ok(())
770            }
771            crate::Statement::Call {
772                function,
773                ref arguments,
774                result,
775            } => {
776                Self::validate_function_handle(function, functions)?;
777                for arg in arguments.iter().copied() {
778                    validate_expr(arg)?;
779                }
780                validate_expr_opt(result)?;
781                Ok(())
782            }
783            crate::Statement::RayQuery { query, ref fun } => {
784                validate_expr(query)?;
785                match *fun {
786                    crate::RayQueryFunction::Initialize {
787                        acceleration_structure,
788                        descriptor,
789                    } => {
790                        validate_expr(acceleration_structure)?;
791                        validate_expr(descriptor)?;
792                    }
793                    crate::RayQueryFunction::Proceed { result } => {
794                        validate_expr(result)?;
795                    }
796                    crate::RayQueryFunction::GenerateIntersection { hit_t } => {
797                        validate_expr(hit_t)?;
798                    }
799                    crate::RayQueryFunction::ConfirmIntersection => {}
800                    crate::RayQueryFunction::Terminate => {}
801                }
802                Ok(())
803            }
804            crate::Statement::SubgroupBallot { result, predicate } => {
805                validate_expr_opt(predicate)?;
806                validate_expr(result)?;
807                Ok(())
808            }
809            crate::Statement::SubgroupCollectiveOperation {
810                op: _,
811                collective_op: _,
812                argument,
813                result,
814            } => {
815                validate_expr(argument)?;
816                validate_expr(result)?;
817                Ok(())
818            }
819            crate::Statement::SubgroupGather {
820                mode,
821                argument,
822                result,
823            } => {
824                validate_expr(argument)?;
825                match mode {
826                    crate::GatherMode::BroadcastFirst => {}
827                    crate::GatherMode::Broadcast(index)
828                    | crate::GatherMode::Shuffle(index)
829                    | crate::GatherMode::ShuffleDown(index)
830                    | crate::GatherMode::ShuffleUp(index)
831                    | crate::GatherMode::ShuffleXor(index)
832                    | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?,
833                    crate::GatherMode::QuadSwap(_) => {}
834                }
835                validate_expr(result)?;
836                Ok(())
837            }
838            crate::Statement::Break
839            | crate::Statement::Continue
840            | crate::Statement::Kill
841            | crate::Statement::ControlBarrier(_)
842            | crate::Statement::MemoryBarrier(_) => Ok(()),
843        })
844    }
845}
846
847impl From<BadHandle> for ValidationError {
848    fn from(source: BadHandle) -> Self {
849        Self::InvalidHandle(source.into())
850    }
851}
852
853impl From<FwdDepError> for ValidationError {
854    fn from(source: FwdDepError) -> Self {
855        Self::InvalidHandle(source.into())
856    }
857}
858
859impl From<BadRangeError> for ValidationError {
860    fn from(source: BadRangeError) -> Self {
861        Self::InvalidHandle(source.into())
862    }
863}
864
/// The ways a module's handles can fail handle validation
/// (see `Validator::validate_module_handles` at the top of this file).
///
/// Every variant wraps a lower-level error and, via `#[error(transparent)]`,
/// displays exactly that error's message.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
    /// A handle failed an arena's `check_contains_handle` check.
    #[error(transparent)]
    BadHandle(#[from] BadHandle),
    /// A handle refers to another handle in the same arena that was not
    /// constructed before it (produced by `Handle::check_dep` and friends).
    #[error(transparent)]
    ForwardDependency(#[from] FwdDepError),
    /// A handle range failed `Arena::check_contains_range`.
    #[error(transparent)]
    BadRange(#[from] BadRangeError),
}
875
876#[derive(Clone, Debug, thiserror::Error)]
877#[cfg_attr(test, derive(PartialEq))]
878#[error(
879    "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
880    which has not been processed yet"
881)]
882pub struct FwdDepError {
883    // This error is used for many `Handle` types, but there's no point in making this generic, so
884    // we just flatten them all to `Handle<()>` here.
885    subject: Handle<()>,
886    subject_kind: &'static str,
887    depends_on: Handle<()>,
888    depends_on_kind: &'static str,
889}
890
891impl<T> Handle<T> {
892    /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
893    pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
894        arena.check_contains_handle(self)?;
895        Ok(())
896    }
897
898    /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
899    pub(self) fn check_valid_for_uniq(
900        self,
901        arena: &UniqueArena<T>,
902    ) -> Result<(), InvalidHandleError>
903    where
904        T: Eq + Hash,
905    {
906        arena.check_contains_handle(self)?;
907        Ok(())
908    }
909
910    /// Check that `depends_on` was constructed before `self` by comparing handle indices.
911    ///
912    /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
913    /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
914    /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
915    /// recursive definitions of arena-based values in linear time.
916    ///
917    /// # Errors
918    ///
919    /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
920    /// than `self`'s, this function returns an error.
921    pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
922        if depends_on < self {
923            Ok(self)
924        } else {
925            let erase_handle_type = |handle: Handle<_>| {
926                Handle::new(NonMaxU32::new((handle.index()).try_into().unwrap()).unwrap())
927            };
928            Err(FwdDepError {
929                subject: erase_handle_type(self),
930                subject_kind: core::any::type_name::<T>(),
931                depends_on: erase_handle_type(depends_on),
932                depends_on_kind: core::any::type_name::<T>(),
933            })
934        }
935    }
936
937    /// Like [`Self::check_dep`], except for [`Option`]al handle values.
938    pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
939        self.check_dep_iter(depends_on.into_iter())
940    }
941
942    /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
943    pub(self) fn check_dep_iter(
944        self,
945        depends_on: impl Iterator<Item = Self>,
946    ) -> Result<Self, FwdDepError> {
947        for handle in depends_on {
948            self.check_dep(handle)?;
949        }
950        Ok(self)
951    }
952}
953
954impl<T> crate::arena::Range<T> {
955    pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
956        arena.check_contains_range(self)
957    }
958}
959
#[test]
fn constant_deps() {
    use crate::{Constant, Expression, Literal, Span, Type, TypeInner};

    let span = Span::default();

    let mut type_arena = UniqueArena::new();
    let mut global_exprs = Arena::new();
    let mut local_exprs = Arena::new();
    let mut constant_arena = Arena::new();
    let override_arena = Arena::new();

    let ty_i32 = type_arena.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        span,
    );

    // Give a constant an initializer handle taken from the wrong arena
    // (`local_exprs`). Both arenas hold `Expression`s, so this type-checks, and
    // since each arena's first element has the same index, the constant's
    // initializer aliases the global expression that refers back to the
    // constant, making it self-referential.
    let local_literal = local_exprs.append(Expression::Literal(Literal::I32(42)), span);
    let cyclic_const = constant_arena.append(
        Constant {
            name: None,
            ty: ty_i32,
            init: local_literal,
        },
        span,
    );
    let _cyclic_expr = global_exprs.append(Expression::Constant(cyclic_const), span);

    // Every global expression must be rejected.
    let all_rejected = global_exprs.iter().all(|entry| {
        super::Validator::validate_const_expression_handles(
            entry,
            &constant_arena,
            &override_arena,
        )
        .is_err()
    });
    assert!(all_rejected);
}
1003
#[test]
fn array_size_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );
    let zero = module
        .global_expressions
        .append(Expression::ZeroValue(u32_ty), span);
    // An override initialized from `zero`, used below as an array length.
    let size_override = module.overrides.append(
        Override {
            name: None,
            id: None,
            ty: u32_ty,
            init: Some(zero),
        },
        span,
    );
    let array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        span,
    );

    // The chain array type -> override -> `zero` -> u32 has no cycle yet.
    assert!(Validator::validate_module_handles(&module).is_ok());

    // Retarget `zero` at the array type, closing the loop
    // `zero` -> array type -> override -> `zero`; validation must reject it.
    module.global_expressions[zero] = Expression::ZeroValue(array_ty);
    assert!(Validator::validate_module_handles(&module).is_err());
}
1052
#[test]
fn array_size_override() {
    use super::Validator;
    use crate::{ArraySize, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );

    // No override is ever added to `module`, so this handle is dangling.
    let dangling: Handle<Override> = Handle::new(NonMaxU32::new(1000).unwrap());
    let _array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(dangling),
                stride: 4,
            },
        },
        span,
    );

    // An array length naming a nonexistent override must be rejected.
    assert!(Validator::validate_module_handles(&module).is_err());
}
1085
#[test]
fn override_init_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );
    let zero = module
        .global_expressions
        .append(Expression::ZeroValue(u32_ty), span);
    let size_override = module.overrides.append(
        Override {
            name: Some("bad_override".into()),
            id: None,
            ty: u32_ty,
            init: Some(zero),
        },
        span,
    );
    let array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        span,
    );
    let zero_array = module
        .global_expressions
        .append(Expression::ZeroValue(array_ty), span);

    // With the override initialized from `zero`, nothing is cyclic yet.
    assert!(Validator::validate_module_handles(&module).is_ok());

    // Re-initialize the override from `zero_array`, closing the loop
    // override -> `zero_array` -> array type -> override; validation must reject it.
    module.overrides[size_override].init = Some(zero_array);
    assert!(Validator::validate_module_handles(&module).is_err());
}