naga/valid/
handles.rs

1//! Implementation of `Validator::validate_module_handles`.
2
3use core::{convert::TryInto, hash::Hash};
4
5use super::{TypeError, ValidationError};
6use crate::non_max_u32::NonMaxU32;
7use crate::{
8    arena::{BadHandle, BadRangeError},
9    diagnostic_filter::DiagnosticFilterNode,
10    EntryPoint, Handle,
11};
12use crate::{Arena, UniqueArena};
13
14use alloc::string::ToString;
15
16impl super::Validator {
17    /// Validates that all handles within `module` are:
18    ///
19    /// * Valid, in the sense that they contain indices within each arena structure inside the
20    ///   [`crate::Module`] type.
21    /// * No arena contents contain any items that have forward dependencies; that is, the value
22    ///   associated with a handle only may contain references to handles in the same arena that
23    ///   were constructed before it.
24    ///
25    /// By validating the above conditions, we free up subsequent logic to assume that handle
26    /// accesses are infallible.
27    ///
28    /// # Errors
29    ///
30    /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
31    /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
32    /// validation pass.
33    pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
34        let &crate::Module {
35            ref constants,
36            ref overrides,
37            ref entry_points,
38            ref functions,
39            ref global_variables,
40            ref types,
41            ref special_types,
42            ref global_expressions,
43            ref diagnostic_filters,
44            ref diagnostic_filter_leaf,
45            ref doc_comments,
46        } = module;
47
48        // Because types can refer to global expressions and vice versa, to
49        // ensure the overall structure is free of cycles, we must traverse them
50        // both in tandem.
51        //
52        // Try to visit all types and global expressions in an order such that
53        // each item refers only to previously visited items. If we succeed,
54        // that shows that there cannot be any cycles, since walking any edge
55        // advances you towards the beginning of the visiting order.
56        //
57        // Validate all the handles in types and expressions as we traverse the
58        // arenas.
59        let mut global_exprs_iter = global_expressions.iter().peekable();
60        for (th, t) in types.iter() {
61            // Imagine the `for` loop and `global_exprs_iter` as two fingers
62            // walking the type and global expression arenas. They don't visit
63            // elements at the same rate: sometimes one processes a bunch of
64            // elements while the other one stays still. But at each point, they
65            // check that the two ranges of elements they've visited only refer
66            // to other elements in those ranges.
67            //
68            // For brevity, we'll say 'handles behind `global_exprs_iter`' to
69            // mean handles that have already been produced by
70            // `global_exprs_iter`. Once `global_exprs_iter` returns `None`, all
71            // global expression handles are 'behind' it.
72            //
73            // At this point:
74            //
75            // - All types visited by prior iterations (that is, before
76            //   `th`/`t`) refer only to expressions behind `global_exprs_iter`.
77            //
78            //   On the first iteration, this is obviously true: there are no
79            //   prior iterations, and `global_exprs_iter` hasn't produced
80            //   anything yet. At the bottom of the loop, we'll claim that it's
81            //   true for `th`/`t` as well, so the condition remains true when
82            //   we advance to the next type.
83            //
84            // - All expressions behind `global_exprs_iter` refer only to
85            //   previously visited types.
86            //
87            //   Again, trivially true at the start, and we'll show it's true
88            //   about each expression that `global_exprs_iter` produces.
89            //
90            // Once we also check that arena elements only refer to prior
91            // elements in that arena, we can see that `th`/`t` does not
92            // participate in a cycle: it only refers to previously visited
93            // types and expressions behind `global_exprs_iter`, and none of
94            // those refer to `th`/`t`, because they passed the same checks
95            // before we reached `th`/`t`.
96            if let Some(max_expr) = Self::validate_type_handles((th, t), overrides)? {
97                max_expr.check_valid_for(global_expressions)?;
98                // Since `t` refers to `max_expr`, if we want our invariants to
99                // remain true, we must advance `global_exprs_iter` beyond
100                // `max_expr`.
101                while let Some((eh, e)) = global_exprs_iter.next_if(|&(eh, _)| eh <= max_expr) {
102                    if let Some(max_type) =
103                        Self::validate_const_expression_handles((eh, e), constants, overrides)?
104                    {
105                        // Show that `eh` refers only to previously visited types.
106                        th.check_dep(max_type)?;
107                    }
108                    // We've advanced `global_exprs_iter` past `eh` already. But
109                    // since we now know that `eh` refers only to previously
110                    // visited types, it is again true that all expressions
111                    // behind `global_exprs_iter` refer only to previously
112                    // visited types. So we can continue to the next expression.
113                }
114            }
115
116            // Here we know that if `th` refers to any expressions at all,
117            // `max_expr` is the latest one. And we know that `max_expr` is
118            // behind `global_exprs_iter`. So `th` refers only to expressions
119            // behind `global_exprs_iter`, and the invariants will still be
120            // true on the next iteration.
121        }
122
123        // Since we also enforced the usual intra-arena rules that expressions
124        // refer only to prior expressions, expressions can only form cycles if
125        // they include types. But we've shown that all types are acyclic, so
126        // all expressions must be acyclic as well.
127        //
128        // Validate the remaining expressions normally.
129        for handle_and_expr in global_exprs_iter {
130            Self::validate_const_expression_handles(handle_and_expr, constants, overrides)?;
131        }
132
133        let validate_type = |handle| Self::validate_type_handle(handle, types);
134        let validate_const_expr =
135            |handle| Self::validate_expression_handle(handle, global_expressions);
136
137        for (_handle, constant) in constants.iter() {
138            let &crate::Constant { name: _, ty, init } = constant;
139            validate_type(ty)?;
140            validate_const_expr(init)?;
141        }
142
143        for (_handle, r#override) in overrides.iter() {
144            let &crate::Override {
145                name: _,
146                id: _,
147                ty,
148                init,
149            } = r#override;
150            validate_type(ty)?;
151            if let Some(init_expr) = init {
152                validate_const_expr(init_expr)?;
153            }
154        }
155
156        for (_handle, global_variable) in global_variables.iter() {
157            let &crate::GlobalVariable {
158                name: _,
159                space: _,
160                binding: _,
161                ty,
162                init,
163            } = global_variable;
164            validate_type(ty)?;
165            if let Some(init_expr) = init {
166                validate_const_expr(init_expr)?;
167            }
168        }
169
170        let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
171            let &crate::Function {
172                name: _,
173                ref arguments,
174                ref result,
175                ref local_variables,
176                ref expressions,
177                ref named_expressions,
178                ref body,
179                ref diagnostic_filter_leaf,
180            } = function;
181
182            for arg in arguments.iter() {
183                let &crate::FunctionArgument {
184                    name: _,
185                    ty,
186                    binding: _,
187                } = arg;
188                validate_type(ty)?;
189            }
190
191            if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
192                validate_type(ty)?;
193            }
194
195            for (_handle, local_variable) in local_variables.iter() {
196                let &crate::LocalVariable { name: _, ty, init } = local_variable;
197                validate_type(ty)?;
198                if let Some(init) = init {
199                    Self::validate_expression_handle(init, expressions)?;
200                }
201            }
202
203            for handle in named_expressions.keys().copied() {
204                Self::validate_expression_handle(handle, expressions)?;
205            }
206
207            for handle_and_expr in expressions.iter() {
208                Self::validate_expression_handles(
209                    handle_and_expr,
210                    constants,
211                    overrides,
212                    types,
213                    local_variables,
214                    global_variables,
215                    functions,
216                    function_handle,
217                )?;
218            }
219
220            Self::validate_block_handles(body, expressions, functions)?;
221
222            if let Some(handle) = *diagnostic_filter_leaf {
223                handle.check_valid_for(diagnostic_filters)?;
224            }
225
226            Ok(())
227        };
228
229        for entry_point in entry_points.iter() {
230            validate_function(None, &entry_point.function)?;
231            if let Some(sizes) = entry_point.workgroup_size_overrides {
232                for size in sizes.iter().filter_map(|x| *x) {
233                    validate_const_expr(size)?;
234                }
235            }
236            if let Some(task_payload) = entry_point.task_payload {
237                Self::validate_global_variable_handle(task_payload, global_variables)?;
238            }
239            if let Some(ref mesh_info) = entry_point.mesh_info {
240                validate_type(mesh_info.vertex_output_type)?;
241                validate_type(mesh_info.primitive_output_type)?;
242                for ov in mesh_info
243                    .max_vertices_override
244                    .iter()
245                    .chain(mesh_info.max_primitives_override.iter())
246                {
247                    validate_const_expr(*ov)?;
248                }
249            }
250        }
251
252        for (function_handle, function) in functions.iter() {
253            validate_function(Some(function_handle), function)?;
254        }
255
256        if let Some(ty) = special_types.ray_desc {
257            validate_type(ty)?;
258        }
259        if let Some(ty) = special_types.ray_intersection {
260            validate_type(ty)?;
261        }
262        if let Some(ty) = special_types.ray_vertex_return {
263            validate_type(ty)?;
264        }
265
266        for (handle, _node) in diagnostic_filters.iter() {
267            let DiagnosticFilterNode { inner: _, parent } = diagnostic_filters[handle];
268            handle.check_dep_opt(parent)?;
269        }
270        if let Some(handle) = *diagnostic_filter_leaf {
271            handle.check_valid_for(diagnostic_filters)?;
272        }
273
274        if let Some(doc_comments) = doc_comments.as_ref() {
275            let crate::DocComments {
276                module: _,
277                types: ref doc_comments_for_types,
278                struct_members: ref doc_comments_for_struct_members,
279                entry_points: ref doc_comments_for_entry_points,
280                functions: ref doc_comments_for_functions,
281                constants: ref doc_comments_for_constants,
282                global_variables: ref doc_comments_for_global_variables,
283            } = **doc_comments;
284
285            for (&ty, _) in doc_comments_for_types.iter() {
286                validate_type(ty)?;
287            }
288
289            for (&(ty, struct_member_index), _) in doc_comments_for_struct_members.iter() {
290                validate_type(ty)?;
291                let struct_type = types.get_handle(ty).unwrap();
292                match struct_type.inner {
293                    crate::TypeInner::Struct {
294                        ref members,
295                        span: ref _span,
296                    } => {
297                        (0..members.len())
298                            .contains(&struct_member_index)
299                            .then_some(())
300                            // TODO: what errors should this be?
301                            .ok_or_else(|| ValidationError::Type {
302                                handle: ty,
303                                name: struct_type.name.as_ref().map_or_else(
304                                    || "members length incorrect".to_string(),
305                                    |name| name.to_string(),
306                                ),
307                                source: TypeError::InvalidData(ty),
308                            })?;
309                    }
310                    _ => {
311                        // TODO: internal error ? We should never get here.
312                        // If entering there, it's probably that we forgot to adjust a handle in the compact phase.
313                        return Err(ValidationError::Type {
314                            handle: ty,
315                            name: struct_type
316                                .name
317                                .as_ref()
318                                .map_or_else(|| "Unknown".to_string(), |name| name.to_string()),
319                            source: TypeError::InvalidData(ty),
320                        });
321                    }
322                }
323                for (&function, _) in doc_comments_for_functions.iter() {
324                    Self::validate_function_handle(function, functions)?;
325                }
326                for (&entry_point_index, _) in doc_comments_for_entry_points.iter() {
327                    Self::validate_entry_point_index(entry_point_index, entry_points)?;
328                }
329                for (&constant, _) in doc_comments_for_constants.iter() {
330                    Self::validate_constant_handle(constant, constants)?;
331                }
332                for (&global_variable, _) in doc_comments_for_global_variables.iter() {
333                    Self::validate_global_variable_handle(global_variable, global_variables)?;
334                }
335            }
336        }
337
338        Ok(())
339    }
340
341    fn validate_type_handle(
342        handle: Handle<crate::Type>,
343        types: &UniqueArena<crate::Type>,
344    ) -> Result<(), InvalidHandleError> {
345        handle.check_valid_for_uniq(types).map(|_| ())
346    }
347
348    fn validate_constant_handle(
349        handle: Handle<crate::Constant>,
350        constants: &Arena<crate::Constant>,
351    ) -> Result<(), InvalidHandleError> {
352        handle.check_valid_for(constants).map(|_| ())
353    }
354
355    fn validate_global_variable_handle(
356        handle: Handle<crate::GlobalVariable>,
357        global_variables: &Arena<crate::GlobalVariable>,
358    ) -> Result<(), InvalidHandleError> {
359        handle.check_valid_for(global_variables).map(|_| ())
360    }
361
362    fn validate_override_handle(
363        handle: Handle<crate::Override>,
364        overrides: &Arena<crate::Override>,
365    ) -> Result<(), InvalidHandleError> {
366        handle.check_valid_for(overrides).map(|_| ())
367    }
368
369    fn validate_expression_handle(
370        handle: Handle<crate::Expression>,
371        expressions: &Arena<crate::Expression>,
372    ) -> Result<(), InvalidHandleError> {
373        handle.check_valid_for(expressions).map(|_| ())
374    }
375
376    fn validate_function_handle(
377        handle: Handle<crate::Function>,
378        functions: &Arena<crate::Function>,
379    ) -> Result<(), InvalidHandleError> {
380        handle.check_valid_for(functions).map(|_| ())
381    }
382
383    /// Validate all handles that occur in `ty`, whose handle is `handle`.
384    ///
385    /// If `ty` refers to any expressions, return the highest-indexed expression
386    /// handle that it uses. This is used for detecting cycles between the
387    /// expression and type arenas.
388    fn validate_type_handles(
389        (handle, ty): (Handle<crate::Type>, &crate::Type),
390        overrides: &Arena<crate::Override>,
391    ) -> Result<Option<Handle<crate::Expression>>, InvalidHandleError> {
392        let max_expr = match ty.inner {
393            crate::TypeInner::Scalar { .. }
394            | crate::TypeInner::Vector { .. }
395            | crate::TypeInner::Matrix { .. }
396            | crate::TypeInner::ValuePointer { .. }
397            | crate::TypeInner::Atomic { .. }
398            | crate::TypeInner::Image { .. }
399            | crate::TypeInner::Sampler { .. }
400            | crate::TypeInner::AccelerationStructure { .. }
401            | crate::TypeInner::RayQuery { .. } => None,
402            crate::TypeInner::Pointer { base, space: _ } => {
403                handle.check_dep(base)?;
404                None
405            }
406            crate::TypeInner::Array { base, size, .. }
407            | crate::TypeInner::BindingArray { base, size, .. } => {
408                handle.check_dep(base)?;
409                match size {
410                    crate::ArraySize::Pending(h) => {
411                        Self::validate_override_handle(h, overrides)?;
412                        let r#override = &overrides[h];
413                        handle.check_dep(r#override.ty)?;
414                        r#override.init
415                    }
416                    crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
417                }
418            }
419            crate::TypeInner::Struct {
420                ref members,
421                span: _,
422            } => {
423                handle.check_dep_iter(members.iter().map(|m| m.ty))?;
424                None
425            }
426        };
427
428        Ok(max_expr)
429    }
430
431    fn validate_entry_point_index(
432        entry_point_index: usize,
433        entry_points: &[EntryPoint],
434    ) -> Result<(), InvalidHandleError> {
435        (0..entry_points.len())
436            .contains(&entry_point_index)
437            .then_some(())
438            .ok_or_else(|| {
439                BadHandle {
440                    kind: "EntryPoint",
441                    index: entry_point_index,
442                }
443                .into()
444            })
445    }
446
447    /// Validate all handles that occur in `expression`, whose handle is `handle`.
448    ///
449    /// If `expression` refers to any `Type`s, return the highest-indexed type
450    /// handle that it uses. This is used for detecting cycles between the
451    /// expression and type arenas.
452    fn validate_const_expression_handles(
453        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
454        constants: &Arena<crate::Constant>,
455        overrides: &Arena<crate::Override>,
456    ) -> Result<Option<Handle<crate::Type>>, InvalidHandleError> {
457        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
458        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
459
460        let max_type = match *expression {
461            crate::Expression::Literal(_) => None,
462            crate::Expression::Constant(constant) => {
463                validate_constant(constant)?;
464                handle.check_dep(constants[constant].init)?;
465                None
466            }
467            crate::Expression::Override(r#override) => {
468                validate_override(r#override)?;
469                if let Some(init) = overrides[r#override].init {
470                    handle.check_dep(init)?;
471                }
472                None
473            }
474            crate::Expression::ZeroValue(ty) => Some(ty),
475            crate::Expression::Compose { ty, ref components } => {
476                handle.check_dep_iter(components.iter().copied())?;
477                Some(ty)
478            }
479            _ => None,
480        };
481        Ok(max_type)
482    }
483
484    #[allow(clippy::too_many_arguments)]
485    fn validate_expression_handles(
486        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
487        constants: &Arena<crate::Constant>,
488        overrides: &Arena<crate::Override>,
489        types: &UniqueArena<crate::Type>,
490        local_variables: &Arena<crate::LocalVariable>,
491        global_variables: &Arena<crate::GlobalVariable>,
492        functions: &Arena<crate::Function>,
493        // The handle of the current function or `None` if it's an entry point
494        current_function: Option<Handle<crate::Function>>,
495    ) -> Result<(), InvalidHandleError> {
496        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
497        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
498        let validate_type = |handle| Self::validate_type_handle(handle, types);
499
500        match *expression {
501            crate::Expression::Access { base, index } => {
502                handle.check_dep(base)?.check_dep(index)?;
503            }
504            crate::Expression::AccessIndex { base, .. } => {
505                handle.check_dep(base)?;
506            }
507            crate::Expression::Splat { value, .. } => {
508                handle.check_dep(value)?;
509            }
510            crate::Expression::Swizzle { vector, .. } => {
511                handle.check_dep(vector)?;
512            }
513            crate::Expression::Literal(_) => {}
514            crate::Expression::Constant(constant) => {
515                validate_constant(constant)?;
516            }
517            crate::Expression::Override(r#override) => {
518                validate_override(r#override)?;
519            }
520            crate::Expression::ZeroValue(ty) => {
521                validate_type(ty)?;
522            }
523            crate::Expression::Compose { ty, ref components } => {
524                validate_type(ty)?;
525                handle.check_dep_iter(components.iter().copied())?;
526            }
527            crate::Expression::FunctionArgument(_arg_idx) => (),
528            crate::Expression::GlobalVariable(global_variable) => {
529                global_variable.check_valid_for(global_variables)?;
530            }
531            crate::Expression::LocalVariable(local_variable) => {
532                local_variable.check_valid_for(local_variables)?;
533            }
534            crate::Expression::Load { pointer } => {
535                handle.check_dep(pointer)?;
536            }
537            crate::Expression::ImageSample {
538                image,
539                sampler,
540                gather: _,
541                coordinate,
542                array_index,
543                offset,
544                level,
545                depth_ref,
546                clamp_to_edge: _,
547            } => {
548                handle
549                    .check_dep(image)?
550                    .check_dep(sampler)?
551                    .check_dep(coordinate)?
552                    .check_dep_opt(array_index)?
553                    .check_dep_opt(offset)?;
554
555                match level {
556                    crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
557                    crate::SampleLevel::Exact(expr) => {
558                        handle.check_dep(expr)?;
559                    }
560                    crate::SampleLevel::Bias(expr) => {
561                        handle.check_dep(expr)?;
562                    }
563                    crate::SampleLevel::Gradient { x, y } => {
564                        handle.check_dep(x)?.check_dep(y)?;
565                    }
566                };
567
568                handle.check_dep_opt(depth_ref)?;
569            }
570            crate::Expression::ImageLoad {
571                image,
572                coordinate,
573                array_index,
574                sample,
575                level,
576            } => {
577                handle
578                    .check_dep(image)?
579                    .check_dep(coordinate)?
580                    .check_dep_opt(array_index)?
581                    .check_dep_opt(sample)?
582                    .check_dep_opt(level)?;
583            }
584            crate::Expression::ImageQuery { image, query } => {
585                handle.check_dep(image)?;
586                match query {
587                    crate::ImageQuery::Size { level } => {
588                        handle.check_dep_opt(level)?;
589                    }
590                    crate::ImageQuery::NumLevels
591                    | crate::ImageQuery::NumLayers
592                    | crate::ImageQuery::NumSamples => (),
593                };
594            }
595            crate::Expression::Unary {
596                op: _,
597                expr: operand,
598            } => {
599                handle.check_dep(operand)?;
600            }
601            crate::Expression::Binary { op: _, left, right } => {
602                handle.check_dep(left)?.check_dep(right)?;
603            }
604            crate::Expression::Select {
605                condition,
606                accept,
607                reject,
608            } => {
609                handle
610                    .check_dep(condition)?
611                    .check_dep(accept)?
612                    .check_dep(reject)?;
613            }
614            crate::Expression::Derivative { expr: argument, .. } => {
615                handle.check_dep(argument)?;
616            }
617            crate::Expression::Relational { fun: _, argument } => {
618                handle.check_dep(argument)?;
619            }
620            crate::Expression::Math {
621                fun: _,
622                arg,
623                arg1,
624                arg2,
625                arg3,
626            } => {
627                handle
628                    .check_dep(arg)?
629                    .check_dep_opt(arg1)?
630                    .check_dep_opt(arg2)?
631                    .check_dep_opt(arg3)?;
632            }
633            crate::Expression::As {
634                expr: input,
635                kind: _,
636                convert: _,
637            } => {
638                handle.check_dep(input)?;
639            }
640            crate::Expression::CallResult(function) => {
641                Self::validate_function_handle(function, functions)?;
642                if let Some(handle) = current_function {
643                    handle.check_dep(function)?;
644                }
645            }
646            crate::Expression::AtomicResult { .. }
647            | crate::Expression::RayQueryProceedResult
648            | crate::Expression::SubgroupBallotResult
649            | crate::Expression::SubgroupOperationResult { .. }
650            | crate::Expression::WorkGroupUniformLoadResult { .. } => (),
651            crate::Expression::ArrayLength(array) => {
652                handle.check_dep(array)?;
653            }
654            crate::Expression::RayQueryGetIntersection {
655                query,
656                committed: _,
657            }
658            | crate::Expression::RayQueryVertexPositions {
659                query,
660                committed: _,
661            } => {
662                handle.check_dep(query)?;
663            }
664        }
665        Ok(())
666    }
667
668    fn validate_block_handles(
669        block: &crate::Block,
670        expressions: &Arena<crate::Expression>,
671        functions: &Arena<crate::Function>,
672    ) -> Result<(), InvalidHandleError> {
673        let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
674        let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
675        let validate_expr_opt = |handle_opt| {
676            if let Some(handle) = handle_opt {
677                validate_expr(handle)?;
678            }
679            Ok(())
680        };
681
682        block.iter().try_for_each(|stmt| match *stmt {
683            crate::Statement::Emit(ref expr_range) => {
684                expr_range.check_valid_for(expressions)?;
685                Ok(())
686            }
687            crate::Statement::Block(ref block) => {
688                validate_block(block)?;
689                Ok(())
690            }
691            crate::Statement::If {
692                condition,
693                ref accept,
694                ref reject,
695            } => {
696                validate_expr(condition)?;
697                validate_block(accept)?;
698                validate_block(reject)?;
699                Ok(())
700            }
701            crate::Statement::Switch {
702                selector,
703                ref cases,
704            } => {
705                validate_expr(selector)?;
706                for &crate::SwitchCase {
707                    value: _,
708                    ref body,
709                    fall_through: _,
710                } in cases
711                {
712                    validate_block(body)?;
713                }
714                Ok(())
715            }
716            crate::Statement::Loop {
717                ref body,
718                ref continuing,
719                break_if,
720            } => {
721                validate_block(body)?;
722                validate_block(continuing)?;
723                validate_expr_opt(break_if)?;
724                Ok(())
725            }
726            crate::Statement::Return { value } => validate_expr_opt(value),
727            crate::Statement::Store { pointer, value } => {
728                validate_expr(pointer)?;
729                validate_expr(value)?;
730                Ok(())
731            }
732            crate::Statement::ImageStore {
733                image,
734                coordinate,
735                array_index,
736                value,
737            } => {
738                validate_expr(image)?;
739                validate_expr(coordinate)?;
740                validate_expr_opt(array_index)?;
741                validate_expr(value)?;
742                Ok(())
743            }
744            crate::Statement::Atomic {
745                pointer,
746                fun,
747                value,
748                result,
749            } => {
750                validate_expr(pointer)?;
751                match fun {
752                    crate::AtomicFunction::Add
753                    | crate::AtomicFunction::Subtract
754                    | crate::AtomicFunction::And
755                    | crate::AtomicFunction::ExclusiveOr
756                    | crate::AtomicFunction::InclusiveOr
757                    | crate::AtomicFunction::Min
758                    | crate::AtomicFunction::Max => (),
759                    crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
760                };
761                validate_expr(value)?;
762                if let Some(result) = result {
763                    validate_expr(result)?;
764                }
765                Ok(())
766            }
767            crate::Statement::ImageAtomic {
768                image,
769                coordinate,
770                array_index,
771                fun: _,
772                value,
773            } => {
774                validate_expr(image)?;
775                validate_expr(coordinate)?;
776                validate_expr_opt(array_index)?;
777                validate_expr(value)?;
778                Ok(())
779            }
780            crate::Statement::WorkGroupUniformLoad { pointer, result } => {
781                validate_expr(pointer)?;
782                validate_expr(result)?;
783                Ok(())
784            }
785            crate::Statement::Call {
786                function,
787                ref arguments,
788                result,
789            } => {
790                Self::validate_function_handle(function, functions)?;
791                for arg in arguments.iter().copied() {
792                    validate_expr(arg)?;
793                }
794                validate_expr_opt(result)?;
795                Ok(())
796            }
797            crate::Statement::RayQuery { query, ref fun } => {
798                validate_expr(query)?;
799                match *fun {
800                    crate::RayQueryFunction::Initialize {
801                        acceleration_structure,
802                        descriptor,
803                    } => {
804                        validate_expr(acceleration_structure)?;
805                        validate_expr(descriptor)?;
806                    }
807                    crate::RayQueryFunction::Proceed { result } => {
808                        validate_expr(result)?;
809                    }
810                    crate::RayQueryFunction::GenerateIntersection { hit_t } => {
811                        validate_expr(hit_t)?;
812                    }
813                    crate::RayQueryFunction::ConfirmIntersection => {}
814                    crate::RayQueryFunction::Terminate => {}
815                }
816                Ok(())
817            }
818            crate::Statement::MeshFunction(func) => match func {
819                crate::MeshFunction::SetMeshOutputs {
820                    vertex_count,
821                    primitive_count,
822                } => {
823                    validate_expr(vertex_count)?;
824                    validate_expr(primitive_count)?;
825                    Ok(())
826                }
827                crate::MeshFunction::SetVertex { index, value }
828                | crate::MeshFunction::SetPrimitive { index, value } => {
829                    validate_expr(index)?;
830                    validate_expr(value)?;
831                    Ok(())
832                }
833            },
834            crate::Statement::SubgroupBallot { result, predicate } => {
835                validate_expr_opt(predicate)?;
836                validate_expr(result)?;
837                Ok(())
838            }
839            crate::Statement::SubgroupCollectiveOperation {
840                op: _,
841                collective_op: _,
842                argument,
843                result,
844            } => {
845                validate_expr(argument)?;
846                validate_expr(result)?;
847                Ok(())
848            }
849            crate::Statement::SubgroupGather {
850                mode,
851                argument,
852                result,
853            } => {
854                validate_expr(argument)?;
855                match mode {
856                    crate::GatherMode::BroadcastFirst => {}
857                    crate::GatherMode::Broadcast(index)
858                    | crate::GatherMode::Shuffle(index)
859                    | crate::GatherMode::ShuffleDown(index)
860                    | crate::GatherMode::ShuffleUp(index)
861                    | crate::GatherMode::ShuffleXor(index)
862                    | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?,
863                    crate::GatherMode::QuadSwap(_) => {}
864                }
865                validate_expr(result)?;
866                Ok(())
867            }
868            crate::Statement::Break
869            | crate::Statement::Continue
870            | crate::Statement::Kill
871            | crate::Statement::ControlBarrier(_)
872            | crate::Statement::MemoryBarrier(_) => Ok(()),
873        })
874    }
875}
876
877impl From<BadHandle> for ValidationError {
878    fn from(source: BadHandle) -> Self {
879        Self::InvalidHandle(source.into())
880    }
881}
882
883impl From<FwdDepError> for ValidationError {
884    fn from(source: FwdDepError) -> Self {
885        Self::InvalidHandle(source.into())
886    }
887}
888
889impl From<BadRangeError> for ValidationError {
890    fn from(source: BadRangeError) -> Self {
891        Self::InvalidHandle(source.into())
892    }
893}
894
/// A handle-validation failure reported by `Validator::validate_module_handles`.
///
/// Each variant transparently wraps the more specific error that was detected, so its
/// `Display` output is that of the wrapped error.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
    /// A handle was rejected by an arena's `check_contains_handle`.
    #[error(transparent)]
    BadHandle(#[from] BadHandle),
    /// A value depends on a handle in the same arena that was not constructed before it.
    #[error(transparent)]
    ForwardDependency(#[from] FwdDepError),
    /// A handle range was rejected by `Arena::check_contains_range`.
    #[error(transparent)]
    BadRange(#[from] BadRangeError),
}
905
906#[derive(Clone, Debug, thiserror::Error)]
907#[cfg_attr(test, derive(PartialEq))]
908#[error(
909    "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
910    which has not been processed yet"
911)]
912pub struct FwdDepError {
913    // This error is used for many `Handle` types, but there's no point in making this generic, so
914    // we just flatten them all to `Handle<()>` here.
915    subject: Handle<()>,
916    subject_kind: &'static str,
917    depends_on: Handle<()>,
918    depends_on_kind: &'static str,
919}
920
921impl<T> Handle<T> {
922    /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
923    pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
924        arena.check_contains_handle(self)?;
925        Ok(())
926    }
927
928    /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
929    pub(self) fn check_valid_for_uniq(
930        self,
931        arena: &UniqueArena<T>,
932    ) -> Result<(), InvalidHandleError>
933    where
934        T: Eq + Hash,
935    {
936        arena.check_contains_handle(self)?;
937        Ok(())
938    }
939
940    /// Check that `depends_on` was constructed before `self` by comparing handle indices.
941    ///
942    /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
943    /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
944    /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
945    /// recursive definitions of arena-based values in linear time.
946    ///
947    /// # Errors
948    ///
949    /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
950    /// than `self`'s, this function returns an error.
951    pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
952        if depends_on < self {
953            Ok(self)
954        } else {
955            let erase_handle_type = |handle: Handle<_>| {
956                Handle::new(NonMaxU32::new((handle.index()).try_into().unwrap()).unwrap())
957            };
958            Err(FwdDepError {
959                subject: erase_handle_type(self),
960                subject_kind: core::any::type_name::<T>(),
961                depends_on: erase_handle_type(depends_on),
962                depends_on_kind: core::any::type_name::<T>(),
963            })
964        }
965    }
966
967    /// Like [`Self::check_dep`], except for [`Option`]al handle values.
968    pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
969        self.check_dep_iter(depends_on.into_iter())
970    }
971
972    /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
973    pub(self) fn check_dep_iter(
974        self,
975        depends_on: impl Iterator<Item = Self>,
976    ) -> Result<Self, FwdDepError> {
977        for handle in depends_on {
978            self.check_dep(handle)?;
979        }
980        Ok(self)
981    }
982}
983
984impl<T> crate::arena::Range<T> {
985    pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
986        arena.check_contains_range(self)
987    }
988}
989
#[test]
fn constant_deps() {
    use crate::{Constant, Expression, Literal, Span, Type, TypeInner};

    let span = Span::default();

    let mut types = UniqueArena::new();
    let mut const_exprs = Arena::new();
    let mut fun_exprs = Arena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();

    let ty_i32 = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        span,
    );

    // Give the constant an initializer that is, wrongly, a handle into `fun_exprs` rather
    // than `const_exprs`. Both arenas start empty, so that handle has the same index as the
    // `Expression::Constant` appended to `const_exprs` below. The constant is therefore
    // effectively self-referential.
    let misused_init = fun_exprs.append(Expression::Literal(Literal::I32(42)), span);
    let constant = constants.append(
        Constant {
            name: None,
            ty: ty_i32,
            init: misused_init,
        },
        span,
    );
    let _referring_expr = const_exprs.append(Expression::Constant(constant), span);

    // Every constant expression must fail handle validation.
    for (handle, expr) in const_exprs.iter() {
        let result = super::Validator::validate_const_expression_handles(
            (handle, expr),
            &constants,
            &overrides,
        );
        assert!(result.is_err());
    }
}
1033
#[test]
fn array_size_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );
    let zero = module
        .global_expressions
        .append(Expression::ZeroValue(u32_ty), span);
    let size_override = module.overrides.append(
        Override {
            name: None,
            id: None,
            ty: u32_ty,
            init: Some(zero),
        },
        span,
    );
    let array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        span,
    );

    // As built, the array's size override is initialized by a scalar zero: no cycle.
    assert!(Validator::validate_module_handles(&module).is_ok());

    // Retype `zero` as the array type. The array is sized by the override, and the override
    // is initialized by `zero`, so this closes a cycle that validation must reject.
    module.global_expressions[zero] = Expression::ZeroValue(array_ty);
    assert!(Validator::validate_module_handles(&module).is_err());
}
1082
#[test]
fn array_size_override() {
    use super::Validator;
    use crate::{ArraySize, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );

    // Nothing is ever appended to `module.overrides`, so index 1000 names no real override.
    let dangling_override = Handle::<Override>::new(NonMaxU32::new(1000).unwrap());
    let _array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(dangling_override),
                stride: 4,
            },
        },
        span,
    );

    // An array sized by a nonexistent override must be rejected.
    assert!(Validator::validate_module_handles(&module).is_err());
}
1115
#[test]
fn override_init_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );
    let zero = module
        .global_expressions
        .append(Expression::ZeroValue(u32_ty), span);
    let size_override = module.overrides.append(
        Override {
            name: Some("bad_override".into()),
            id: None,
            ty: u32_ty,
            init: Some(zero),
        },
        span,
    );
    let array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        span,
    );
    let array_zero = module
        .global_expressions
        .append(Expression::ZeroValue(array_ty), span);

    // As built, the override is initialized by a scalar zero: no cycle.
    assert!(Validator::validate_module_handles(&module).is_ok());

    // Re-initialize the override with `array_zero`. Its type is the array sized by that same
    // override, so this closes a cycle that validation must reject.
    module.overrides[size_override].init = Some(array_zero);
    assert!(Validator::validate_module_handles(&module).is_err());
}