naga/valid/
handles.rs

1//! Implementation of `Validator::validate_module_handles`.
2
3use core::{convert::TryInto, hash::Hash};
4
5use super::{TypeError, ValidationError};
6use crate::non_max_u32::NonMaxU32;
7use crate::{
8    arena::{BadHandle, BadRangeError},
9    diagnostic_filter::DiagnosticFilterNode,
10    EntryPoint, Handle,
11};
12use crate::{Arena, UniqueArena};
13
14use alloc::string::ToString;
15
16impl super::Validator {
17    /// Validates that all handles within `module` are:
18    ///
19    /// * Valid, in the sense that they contain indices within each arena structure inside the
20    ///   [`crate::Module`] type.
21    /// * No arena contents contain any items that have forward dependencies; that is, the value
22    ///   associated with a handle only may contain references to handles in the same arena that
23    ///   were constructed before it.
24    ///
25    /// By validating the above conditions, we free up subsequent logic to assume that handle
26    /// accesses are infallible.
27    ///
28    /// # Errors
29    ///
30    /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
31    /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
32    /// validation pass.
33    pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
34        let &crate::Module {
35            ref constants,
36            ref overrides,
37            ref entry_points,
38            ref functions,
39            ref global_variables,
40            ref types,
41            ref special_types,
42            ref global_expressions,
43            ref diagnostic_filters,
44            ref diagnostic_filter_leaf,
45            ref doc_comments,
46        } = module;
47
48        // Because types can refer to global expressions and vice versa, to
49        // ensure the overall structure is free of cycles, we must traverse them
50        // both in tandem.
51        //
52        // Try to visit all types and global expressions in an order such that
53        // each item refers only to previously visited items. If we succeed,
54        // that shows that there cannot be any cycles, since walking any edge
55        // advances you towards the beginning of the visiting order.
56        //
57        // Validate all the handles in types and expressions as we traverse the
58        // arenas.
59        let mut global_exprs_iter = global_expressions.iter().peekable();
60        for (th, t) in types.iter() {
61            // Imagine the `for` loop and `global_exprs_iter` as two fingers
62            // walking the type and global expression arenas. They don't visit
63            // elements at the same rate: sometimes one processes a bunch of
64            // elements while the other one stays still. But at each point, they
65            // check that the two ranges of elements they've visited only refer
66            // to other elements in those ranges.
67            //
68            // For brevity, we'll say 'handles behind `global_exprs_iter`' to
69            // mean handles that have already been produced by
70            // `global_exprs_iter`. Once `global_exprs_iter` returns `None`, all
71            // global expression handles are 'behind' it.
72            //
73            // At this point:
74            //
75            // - All types visited by prior iterations (that is, before
76            //   `th`/`t`) refer only to expressions behind `global_exprs_iter`.
77            //
78            //   On the first iteration, this is obviously true: there are no
79            //   prior iterations, and `global_exprs_iter` hasn't produced
80            //   anything yet. At the bottom of the loop, we'll claim that it's
81            //   true for `th`/`t` as well, so the condition remains true when
82            //   we advance to the next type.
83            //
84            // - All expressions behind `global_exprs_iter` refer only to
85            //   previously visited types.
86            //
87            //   Again, trivially true at the start, and we'll show it's true
88            //   about each expression that `global_exprs_iter` produces.
89            //
90            // Once we also check that arena elements only refer to prior
91            // elements in that arena, we can see that `th`/`t` does not
92            // participate in a cycle: it only refers to previously visited
93            // types and expressions behind `global_exprs_iter`, and none of
94            // those refer to `th`/`t`, because they passed the same checks
95            // before we reached `th`/`t`.
96            if let Some(max_expr) = Self::validate_type_handles((th, t), overrides)? {
97                max_expr.check_valid_for(global_expressions)?;
98                // Since `t` refers to `max_expr`, if we want our invariants to
99                // remain true, we must advance `global_exprs_iter` beyond
100                // `max_expr`.
101                while let Some((eh, e)) = global_exprs_iter.next_if(|&(eh, _)| eh <= max_expr) {
102                    if let Some(max_type) =
103                        Self::validate_const_expression_handles((eh, e), constants, overrides)?
104                    {
105                        // Show that `eh` refers only to previously visited types.
106                        th.check_dep(max_type)?;
107                    }
108                    // We've advanced `global_exprs_iter` past `eh` already. But
109                    // since we now know that `eh` refers only to previously
110                    // visited types, it is again true that all expressions
111                    // behind `global_exprs_iter` refer only to previously
112                    // visited types. So we can continue to the next expression.
113                }
114            }
115
116            // Here we know that if `th` refers to any expressions at all,
117            // `max_expr` is the latest one. And we know that `max_expr` is
118            // behind `global_exprs_iter`. So `th` refers only to expressions
119            // behind `global_exprs_iter`, and the invariants will still be
120            // true on the next iteration.
121        }
122
123        // Since we also enforced the usual intra-arena rules that expressions
124        // refer only to prior expressions, expressions can only form cycles if
125        // they include types. But we've shown that all types are acyclic, so
126        // all expressions must be acyclic as well.
127        //
128        // Validate the remaining expressions normally.
129        for handle_and_expr in global_exprs_iter {
130            Self::validate_const_expression_handles(handle_and_expr, constants, overrides)?;
131        }
132
133        let validate_type = |handle| Self::validate_type_handle(handle, types);
134        let validate_const_expr =
135            |handle| Self::validate_expression_handle(handle, global_expressions);
136
137        for (_handle, constant) in constants.iter() {
138            let &crate::Constant { name: _, ty, init } = constant;
139            validate_type(ty)?;
140            validate_const_expr(init)?;
141        }
142
143        for (_handle, r#override) in overrides.iter() {
144            let &crate::Override {
145                name: _,
146                id: _,
147                ty,
148                init,
149            } = r#override;
150            validate_type(ty)?;
151            if let Some(init_expr) = init {
152                validate_const_expr(init_expr)?;
153            }
154        }
155
156        for (_handle, global_variable) in global_variables.iter() {
157            let &crate::GlobalVariable {
158                name: _,
159                space: _,
160                binding: _,
161                ty,
162                init,
163            } = global_variable;
164            validate_type(ty)?;
165            if let Some(init_expr) = init {
166                validate_const_expr(init_expr)?;
167            }
168        }
169
170        let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
171            let &crate::Function {
172                name: _,
173                ref arguments,
174                ref result,
175                ref local_variables,
176                ref expressions,
177                ref named_expressions,
178                ref body,
179                ref diagnostic_filter_leaf,
180            } = function;
181
182            for arg in arguments.iter() {
183                let &crate::FunctionArgument {
184                    name: _,
185                    ty,
186                    binding: _,
187                } = arg;
188                validate_type(ty)?;
189            }
190
191            if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
192                validate_type(ty)?;
193            }
194
195            for (_handle, local_variable) in local_variables.iter() {
196                let &crate::LocalVariable { name: _, ty, init } = local_variable;
197                validate_type(ty)?;
198                if let Some(init) = init {
199                    Self::validate_expression_handle(init, expressions)?;
200                }
201            }
202
203            for handle in named_expressions.keys().copied() {
204                Self::validate_expression_handle(handle, expressions)?;
205            }
206
207            for handle_and_expr in expressions.iter() {
208                Self::validate_expression_handles(
209                    handle_and_expr,
210                    constants,
211                    overrides,
212                    types,
213                    local_variables,
214                    global_variables,
215                    functions,
216                    function_handle,
217                )?;
218            }
219
220            Self::validate_block_handles(body, expressions, functions)?;
221
222            if let Some(handle) = *diagnostic_filter_leaf {
223                handle.check_valid_for(diagnostic_filters)?;
224            }
225
226            Ok(())
227        };
228
229        for entry_point in entry_points.iter() {
230            validate_function(None, &entry_point.function)?;
231            if let Some(sizes) = entry_point.workgroup_size_overrides {
232                for size in sizes.iter().filter_map(|x| *x) {
233                    validate_const_expr(size)?;
234                }
235            }
236            if let Some(task_payload) = entry_point.task_payload {
237                Self::validate_global_variable_handle(task_payload, global_variables)?;
238            }
239            if let Some(ref mesh_info) = entry_point.mesh_info {
240                Self::validate_global_variable_handle(mesh_info.output_variable, global_variables)?;
241                validate_type(mesh_info.vertex_output_type)?;
242                validate_type(mesh_info.primitive_output_type)?;
243                for ov in mesh_info
244                    .max_vertices_override
245                    .iter()
246                    .chain(mesh_info.max_primitives_override.iter())
247                {
248                    validate_const_expr(*ov)?;
249                }
250            }
251        }
252
253        for (function_handle, function) in functions.iter() {
254            validate_function(Some(function_handle), function)?;
255        }
256
257        if let Some(ty) = special_types.ray_desc {
258            validate_type(ty)?;
259        }
260        if let Some(ty) = special_types.ray_intersection {
261            validate_type(ty)?;
262        }
263        if let Some(ty) = special_types.ray_vertex_return {
264            validate_type(ty)?;
265        }
266
267        for (handle, _node) in diagnostic_filters.iter() {
268            let DiagnosticFilterNode { inner: _, parent } = diagnostic_filters[handle];
269            handle.check_dep_opt(parent)?;
270        }
271        if let Some(handle) = *diagnostic_filter_leaf {
272            handle.check_valid_for(diagnostic_filters)?;
273        }
274
275        if let Some(doc_comments) = doc_comments.as_ref() {
276            let crate::DocComments {
277                module: _,
278                types: ref doc_comments_for_types,
279                struct_members: ref doc_comments_for_struct_members,
280                entry_points: ref doc_comments_for_entry_points,
281                functions: ref doc_comments_for_functions,
282                constants: ref doc_comments_for_constants,
283                global_variables: ref doc_comments_for_global_variables,
284            } = **doc_comments;
285
286            for (&ty, _) in doc_comments_for_types.iter() {
287                validate_type(ty)?;
288            }
289
290            for (&(ty, struct_member_index), _) in doc_comments_for_struct_members.iter() {
291                validate_type(ty)?;
292                let struct_type = types.get_handle(ty).unwrap();
293                match struct_type.inner {
294                    crate::TypeInner::Struct {
295                        ref members,
296                        span: ref _span,
297                    } => {
298                        (0..members.len())
299                            .contains(&struct_member_index)
300                            .then_some(())
301                            // TODO: what errors should this be?
302                            .ok_or_else(|| ValidationError::Type {
303                                handle: ty,
304                                name: struct_type.name.as_ref().map_or_else(
305                                    || "members length incorrect".to_string(),
306                                    |name| name.to_string(),
307                                ),
308                                source: TypeError::InvalidData(ty),
309                            })?;
310                    }
311                    _ => {
312                        // TODO: internal error ? We should never get here.
313                        // If entering there, it's probably that we forgot to adjust a handle in the compact phase.
314                        return Err(ValidationError::Type {
315                            handle: ty,
316                            name: struct_type
317                                .name
318                                .as_ref()
319                                .map_or_else(|| "Unknown".to_string(), |name| name.to_string()),
320                            source: TypeError::InvalidData(ty),
321                        });
322                    }
323                }
324                for (&function, _) in doc_comments_for_functions.iter() {
325                    Self::validate_function_handle(function, functions)?;
326                }
327                for (&entry_point_index, _) in doc_comments_for_entry_points.iter() {
328                    Self::validate_entry_point_index(entry_point_index, entry_points)?;
329                }
330                for (&constant, _) in doc_comments_for_constants.iter() {
331                    Self::validate_constant_handle(constant, constants)?;
332                }
333                for (&global_variable, _) in doc_comments_for_global_variables.iter() {
334                    Self::validate_global_variable_handle(global_variable, global_variables)?;
335                }
336            }
337        }
338
339        Ok(())
340    }
341
342    fn validate_type_handle(
343        handle: Handle<crate::Type>,
344        types: &UniqueArena<crate::Type>,
345    ) -> Result<(), InvalidHandleError> {
346        handle.check_valid_for_uniq(types).map(|_| ())
347    }
348
349    fn validate_constant_handle(
350        handle: Handle<crate::Constant>,
351        constants: &Arena<crate::Constant>,
352    ) -> Result<(), InvalidHandleError> {
353        handle.check_valid_for(constants).map(|_| ())
354    }
355
356    fn validate_global_variable_handle(
357        handle: Handle<crate::GlobalVariable>,
358        global_variables: &Arena<crate::GlobalVariable>,
359    ) -> Result<(), InvalidHandleError> {
360        handle.check_valid_for(global_variables).map(|_| ())
361    }
362
363    fn validate_override_handle(
364        handle: Handle<crate::Override>,
365        overrides: &Arena<crate::Override>,
366    ) -> Result<(), InvalidHandleError> {
367        handle.check_valid_for(overrides).map(|_| ())
368    }
369
370    fn validate_expression_handle(
371        handle: Handle<crate::Expression>,
372        expressions: &Arena<crate::Expression>,
373    ) -> Result<(), InvalidHandleError> {
374        handle.check_valid_for(expressions).map(|_| ())
375    }
376
377    fn validate_function_handle(
378        handle: Handle<crate::Function>,
379        functions: &Arena<crate::Function>,
380    ) -> Result<(), InvalidHandleError> {
381        handle.check_valid_for(functions).map(|_| ())
382    }
383
384    /// Validate all handles that occur in `ty`, whose handle is `handle`.
385    ///
386    /// If `ty` refers to any expressions, return the highest-indexed expression
387    /// handle that it uses. This is used for detecting cycles between the
388    /// expression and type arenas.
389    fn validate_type_handles(
390        (handle, ty): (Handle<crate::Type>, &crate::Type),
391        overrides: &Arena<crate::Override>,
392    ) -> Result<Option<Handle<crate::Expression>>, InvalidHandleError> {
393        let max_expr = match ty.inner {
394            crate::TypeInner::Scalar { .. }
395            | crate::TypeInner::Vector { .. }
396            | crate::TypeInner::Matrix { .. }
397            | crate::TypeInner::CooperativeMatrix { .. }
398            | crate::TypeInner::ValuePointer { .. }
399            | crate::TypeInner::Atomic { .. }
400            | crate::TypeInner::Image { .. }
401            | crate::TypeInner::Sampler { .. }
402            | crate::TypeInner::AccelerationStructure { .. }
403            | crate::TypeInner::RayQuery { .. } => None,
404            crate::TypeInner::Pointer { base, space: _ } => {
405                handle.check_dep(base)?;
406                None
407            }
408            crate::TypeInner::Array { base, size, .. }
409            | crate::TypeInner::BindingArray { base, size, .. } => {
410                handle.check_dep(base)?;
411                match size {
412                    crate::ArraySize::Pending(h) => {
413                        Self::validate_override_handle(h, overrides)?;
414                        let r#override = &overrides[h];
415                        handle.check_dep(r#override.ty)?;
416                        r#override.init
417                    }
418                    crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
419                }
420            }
421            crate::TypeInner::Struct {
422                ref members,
423                span: _,
424            } => {
425                handle.check_dep_iter(members.iter().map(|m| m.ty))?;
426                None
427            }
428        };
429
430        Ok(max_expr)
431    }
432
433    fn validate_entry_point_index(
434        entry_point_index: usize,
435        entry_points: &[EntryPoint],
436    ) -> Result<(), InvalidHandleError> {
437        (0..entry_points.len())
438            .contains(&entry_point_index)
439            .then_some(())
440            .ok_or_else(|| {
441                BadHandle {
442                    kind: "EntryPoint",
443                    index: entry_point_index,
444                }
445                .into()
446            })
447    }
448
449    /// Validate all handles that occur in `expression`, whose handle is `handle`.
450    ///
451    /// If `expression` refers to any `Type`s, return the highest-indexed type
452    /// handle that it uses. This is used for detecting cycles between the
453    /// expression and type arenas.
454    fn validate_const_expression_handles(
455        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
456        constants: &Arena<crate::Constant>,
457        overrides: &Arena<crate::Override>,
458    ) -> Result<Option<Handle<crate::Type>>, InvalidHandleError> {
459        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
460        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
461
462        let max_type = match *expression {
463            crate::Expression::Literal(_) => None,
464            crate::Expression::Constant(constant) => {
465                validate_constant(constant)?;
466                handle.check_dep(constants[constant].init)?;
467                None
468            }
469            crate::Expression::Override(r#override) => {
470                validate_override(r#override)?;
471                if let Some(init) = overrides[r#override].init {
472                    handle.check_dep(init)?;
473                }
474                None
475            }
476            crate::Expression::ZeroValue(ty) => Some(ty),
477            crate::Expression::Compose { ty, ref components } => {
478                handle.check_dep_iter(components.iter().copied())?;
479                Some(ty)
480            }
481            _ => None,
482        };
483        Ok(max_type)
484    }
485
486    #[allow(clippy::too_many_arguments)]
487    fn validate_expression_handles(
488        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
489        constants: &Arena<crate::Constant>,
490        overrides: &Arena<crate::Override>,
491        types: &UniqueArena<crate::Type>,
492        local_variables: &Arena<crate::LocalVariable>,
493        global_variables: &Arena<crate::GlobalVariable>,
494        functions: &Arena<crate::Function>,
495        // The handle of the current function or `None` if it's an entry point
496        current_function: Option<Handle<crate::Function>>,
497    ) -> Result<(), InvalidHandleError> {
498        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
499        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
500        let validate_type = |handle| Self::validate_type_handle(handle, types);
501
502        match *expression {
503            crate::Expression::Access { base, index } => {
504                handle.check_dep(base)?.check_dep(index)?;
505            }
506            crate::Expression::AccessIndex { base, .. } => {
507                handle.check_dep(base)?;
508            }
509            crate::Expression::Splat { value, .. } => {
510                handle.check_dep(value)?;
511            }
512            crate::Expression::Swizzle { vector, .. } => {
513                handle.check_dep(vector)?;
514            }
515            crate::Expression::Literal(_) => {}
516            crate::Expression::Constant(constant) => {
517                validate_constant(constant)?;
518            }
519            crate::Expression::Override(r#override) => {
520                validate_override(r#override)?;
521            }
522            crate::Expression::ZeroValue(ty) => {
523                validate_type(ty)?;
524            }
525            crate::Expression::Compose { ty, ref components } => {
526                validate_type(ty)?;
527                handle.check_dep_iter(components.iter().copied())?;
528            }
529            crate::Expression::FunctionArgument(_arg_idx) => (),
530            crate::Expression::GlobalVariable(global_variable) => {
531                global_variable.check_valid_for(global_variables)?;
532            }
533            crate::Expression::LocalVariable(local_variable) => {
534                local_variable.check_valid_for(local_variables)?;
535            }
536            crate::Expression::Load { pointer } => {
537                handle.check_dep(pointer)?;
538            }
539            crate::Expression::ImageSample {
540                image,
541                sampler,
542                gather: _,
543                coordinate,
544                array_index,
545                offset,
546                level,
547                depth_ref,
548                clamp_to_edge: _,
549            } => {
550                handle
551                    .check_dep(image)?
552                    .check_dep(sampler)?
553                    .check_dep(coordinate)?
554                    .check_dep_opt(array_index)?
555                    .check_dep_opt(offset)?;
556
557                match level {
558                    crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
559                    crate::SampleLevel::Exact(expr) => {
560                        handle.check_dep(expr)?;
561                    }
562                    crate::SampleLevel::Bias(expr) => {
563                        handle.check_dep(expr)?;
564                    }
565                    crate::SampleLevel::Gradient { x, y } => {
566                        handle.check_dep(x)?.check_dep(y)?;
567                    }
568                };
569
570                handle.check_dep_opt(depth_ref)?;
571            }
572            crate::Expression::ImageLoad {
573                image,
574                coordinate,
575                array_index,
576                sample,
577                level,
578            } => {
579                handle
580                    .check_dep(image)?
581                    .check_dep(coordinate)?
582                    .check_dep_opt(array_index)?
583                    .check_dep_opt(sample)?
584                    .check_dep_opt(level)?;
585            }
586            crate::Expression::ImageQuery { image, query } => {
587                handle.check_dep(image)?;
588                match query {
589                    crate::ImageQuery::Size { level } => {
590                        handle.check_dep_opt(level)?;
591                    }
592                    crate::ImageQuery::NumLevels
593                    | crate::ImageQuery::NumLayers
594                    | crate::ImageQuery::NumSamples => (),
595                };
596            }
597            crate::Expression::Unary {
598                op: _,
599                expr: operand,
600            } => {
601                handle.check_dep(operand)?;
602            }
603            crate::Expression::Binary { op: _, left, right } => {
604                handle.check_dep(left)?.check_dep(right)?;
605            }
606            crate::Expression::Select {
607                condition,
608                accept,
609                reject,
610            } => {
611                handle
612                    .check_dep(condition)?
613                    .check_dep(accept)?
614                    .check_dep(reject)?;
615            }
616            crate::Expression::Derivative { expr: argument, .. } => {
617                handle.check_dep(argument)?;
618            }
619            crate::Expression::Relational { fun: _, argument } => {
620                handle.check_dep(argument)?;
621            }
622            crate::Expression::Math {
623                fun: _,
624                arg,
625                arg1,
626                arg2,
627                arg3,
628            } => {
629                handle
630                    .check_dep(arg)?
631                    .check_dep_opt(arg1)?
632                    .check_dep_opt(arg2)?
633                    .check_dep_opt(arg3)?;
634            }
635            crate::Expression::As {
636                expr: input,
637                kind: _,
638                convert: _,
639            } => {
640                handle.check_dep(input)?;
641            }
642            crate::Expression::CallResult(function) => {
643                Self::validate_function_handle(function, functions)?;
644                if let Some(handle) = current_function {
645                    handle.check_dep(function)?;
646                }
647            }
648            crate::Expression::AtomicResult { .. }
649            | crate::Expression::RayQueryProceedResult
650            | crate::Expression::SubgroupBallotResult
651            | crate::Expression::SubgroupOperationResult { .. }
652            | crate::Expression::WorkGroupUniformLoadResult { .. } => (),
653            crate::Expression::ArrayLength(array) => {
654                handle.check_dep(array)?;
655            }
656            crate::Expression::RayQueryGetIntersection {
657                query,
658                committed: _,
659            }
660            | crate::Expression::RayQueryVertexPositions {
661                query,
662                committed: _,
663            } => {
664                handle.check_dep(query)?;
665            }
666            crate::Expression::CooperativeLoad { ref data, .. } => {
667                handle.check_dep(data.pointer)?.check_dep(data.stride)?;
668            }
669            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
670                handle.check_dep(a)?.check_dep(b)?.check_dep(c)?;
671            }
672        }
673        Ok(())
674    }
675
676    fn validate_block_handles(
677        block: &crate::Block,
678        expressions: &Arena<crate::Expression>,
679        functions: &Arena<crate::Function>,
680    ) -> Result<(), InvalidHandleError> {
681        let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
682        let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
683        let validate_expr_opt = |handle_opt| {
684            if let Some(handle) = handle_opt {
685                validate_expr(handle)?;
686            }
687            Ok(())
688        };
689
690        block.iter().try_for_each(|stmt| match *stmt {
691            crate::Statement::Emit(ref expr_range) => {
692                expr_range.check_valid_for(expressions)?;
693                Ok(())
694            }
695            crate::Statement::Block(ref block) => {
696                validate_block(block)?;
697                Ok(())
698            }
699            crate::Statement::If {
700                condition,
701                ref accept,
702                ref reject,
703            } => {
704                validate_expr(condition)?;
705                validate_block(accept)?;
706                validate_block(reject)?;
707                Ok(())
708            }
709            crate::Statement::Switch {
710                selector,
711                ref cases,
712            } => {
713                validate_expr(selector)?;
714                for &crate::SwitchCase {
715                    value: _,
716                    ref body,
717                    fall_through: _,
718                } in cases
719                {
720                    validate_block(body)?;
721                }
722                Ok(())
723            }
724            crate::Statement::Loop {
725                ref body,
726                ref continuing,
727                break_if,
728            } => {
729                validate_block(body)?;
730                validate_block(continuing)?;
731                validate_expr_opt(break_if)?;
732                Ok(())
733            }
734            crate::Statement::Return { value } => validate_expr_opt(value),
735            crate::Statement::Store { pointer, value } => {
736                validate_expr(pointer)?;
737                validate_expr(value)?;
738                Ok(())
739            }
740            crate::Statement::ImageStore {
741                image,
742                coordinate,
743                array_index,
744                value,
745            } => {
746                validate_expr(image)?;
747                validate_expr(coordinate)?;
748                validate_expr_opt(array_index)?;
749                validate_expr(value)?;
750                Ok(())
751            }
752            crate::Statement::Atomic {
753                pointer,
754                fun,
755                value,
756                result,
757            } => {
758                validate_expr(pointer)?;
759                match fun {
760                    crate::AtomicFunction::Add
761                    | crate::AtomicFunction::Subtract
762                    | crate::AtomicFunction::And
763                    | crate::AtomicFunction::ExclusiveOr
764                    | crate::AtomicFunction::InclusiveOr
765                    | crate::AtomicFunction::Min
766                    | crate::AtomicFunction::Max => (),
767                    crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
768                };
769                validate_expr(value)?;
770                if let Some(result) = result {
771                    validate_expr(result)?;
772                }
773                Ok(())
774            }
775            crate::Statement::ImageAtomic {
776                image,
777                coordinate,
778                array_index,
779                fun: _,
780                value,
781            } => {
782                validate_expr(image)?;
783                validate_expr(coordinate)?;
784                validate_expr_opt(array_index)?;
785                validate_expr(value)?;
786                Ok(())
787            }
788            crate::Statement::WorkGroupUniformLoad { pointer, result } => {
789                validate_expr(pointer)?;
790                validate_expr(result)?;
791                Ok(())
792            }
793            crate::Statement::Call {
794                function,
795                ref arguments,
796                result,
797            } => {
798                Self::validate_function_handle(function, functions)?;
799                for arg in arguments.iter().copied() {
800                    validate_expr(arg)?;
801                }
802                validate_expr_opt(result)?;
803                Ok(())
804            }
805            crate::Statement::RayQuery { query, ref fun } => {
806                validate_expr(query)?;
807                match *fun {
808                    crate::RayQueryFunction::Initialize {
809                        acceleration_structure,
810                        descriptor,
811                    } => {
812                        validate_expr(acceleration_structure)?;
813                        validate_expr(descriptor)?;
814                    }
815                    crate::RayQueryFunction::Proceed { result } => {
816                        validate_expr(result)?;
817                    }
818                    crate::RayQueryFunction::GenerateIntersection { hit_t } => {
819                        validate_expr(hit_t)?;
820                    }
821                    crate::RayQueryFunction::ConfirmIntersection => {}
822                    crate::RayQueryFunction::Terminate => {}
823                }
824                Ok(())
825            }
826            crate::Statement::SubgroupBallot { result, predicate } => {
827                validate_expr_opt(predicate)?;
828                validate_expr(result)?;
829                Ok(())
830            }
831            crate::Statement::SubgroupCollectiveOperation {
832                op: _,
833                collective_op: _,
834                argument,
835                result,
836            } => {
837                validate_expr(argument)?;
838                validate_expr(result)?;
839                Ok(())
840            }
841            crate::Statement::SubgroupGather {
842                mode,
843                argument,
844                result,
845            } => {
846                validate_expr(argument)?;
847                match mode {
848                    crate::GatherMode::BroadcastFirst => {}
849                    crate::GatherMode::Broadcast(index)
850                    | crate::GatherMode::Shuffle(index)
851                    | crate::GatherMode::ShuffleDown(index)
852                    | crate::GatherMode::ShuffleUp(index)
853                    | crate::GatherMode::ShuffleXor(index)
854                    | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?,
855                    crate::GatherMode::QuadSwap(_) => {}
856                }
857                validate_expr(result)?;
858                Ok(())
859            }
860            crate::Statement::CooperativeStore { target, ref data } => {
861                validate_expr(target)?;
862                validate_expr(data.pointer)?;
863                validate_expr(data.stride)?;
864                Ok(())
865            }
866            crate::Statement::Break
867            | crate::Statement::Continue
868            | crate::Statement::Kill
869            | crate::Statement::ControlBarrier(_)
870            | crate::Statement::MemoryBarrier(_) => Ok(()),
871        })
872    }
873}
874
875impl From<BadHandle> for ValidationError {
876    fn from(source: BadHandle) -> Self {
877        Self::InvalidHandle(source.into())
878    }
879}
880
881impl From<FwdDepError> for ValidationError {
882    fn from(source: FwdDepError) -> Self {
883        Self::InvalidHandle(source.into())
884    }
885}
886
887impl From<BadRangeError> for ValidationError {
888    fn from(source: BadRangeError) -> Self {
889        Self::InvalidHandle(source.into())
890    }
891}
892
/// A handle-validity failure found while checking a module's handles.
///
/// Every variant is `#[error(transparent)]`, so its `Display` and `source`
/// are forwarded unchanged from the wrapped error. Each wrapped error type
/// converts into this enum via `#[from]`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
    /// A handle rejected by its arena's `check_contains_handle`
    /// (see `Handle::check_valid_for` / `Handle::check_valid_for_uniq`).
    #[error(transparent)]
    BadHandle(#[from] BadHandle),
    /// A value that depends on a handle in the same arena which was not
    /// constructed before it (see `Handle::check_dep`).
    #[error(transparent)]
    ForwardDependency(#[from] FwdDepError),
    /// A handle range rejected by `Arena::check_contains_range`
    /// (see `Range::check_valid_for`).
    #[error(transparent)]
    BadRange(#[from] BadRangeError),
}
903
/// A value refers to another handle in its own arena that was not constructed
/// before it, i.e. a forward dependency.
///
/// Produced by `Handle::check_dep` in this module.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
// NOTE(review): `subject_kind` is formatted with `{:?}` (quoted) while
// `depends_on_kind` uses `{}` (unquoted). The message is kept as-is because
// tests or snapshots elsewhere may depend on the exact text.
#[error(
    "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
    which has not been processed yet"
)]
pub struct FwdDepError {
    // This error is used for many `Handle` types, but there's no point in making this generic, so
    // we just flatten them all to `Handle<()>` here.
    /// The handle whose value has the forward dependency (type-erased).
    subject: Handle<()>,
    /// Name of `subject`'s element type; `Handle::check_dep` fills this with
    /// `core::any::type_name` of that type.
    subject_kind: &'static str,
    /// The handle that `subject` depends on, which was not constructed
    /// before it (type-erased).
    depends_on: Handle<()>,
    /// Name of `depends_on`'s element type; `Handle::check_dep` fills this
    /// with `core::any::type_name` of that type.
    depends_on_kind: &'static str,
}
918
919impl<T> Handle<T> {
920    /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
921    pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
922        arena.check_contains_handle(self)?;
923        Ok(())
924    }
925
926    /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
927    pub(self) fn check_valid_for_uniq(
928        self,
929        arena: &UniqueArena<T>,
930    ) -> Result<(), InvalidHandleError>
931    where
932        T: Eq + Hash,
933    {
934        arena.check_contains_handle(self)?;
935        Ok(())
936    }
937
938    /// Check that `depends_on` was constructed before `self` by comparing handle indices.
939    ///
940    /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
941    /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
942    /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
943    /// recursive definitions of arena-based values in linear time.
944    ///
945    /// # Errors
946    ///
947    /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
948    /// than `self`'s, this function returns an error.
949    pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
950        if depends_on < self {
951            Ok(self)
952        } else {
953            let erase_handle_type = |handle: Handle<_>| {
954                Handle::new(NonMaxU32::new((handle.index()).try_into().unwrap()).unwrap())
955            };
956            Err(FwdDepError {
957                subject: erase_handle_type(self),
958                subject_kind: core::any::type_name::<T>(),
959                depends_on: erase_handle_type(depends_on),
960                depends_on_kind: core::any::type_name::<T>(),
961            })
962        }
963    }
964
965    /// Like [`Self::check_dep`], except for [`Option`]al handle values.
966    pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
967        self.check_dep_iter(depends_on.into_iter())
968    }
969
970    /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
971    pub(self) fn check_dep_iter(
972        self,
973        depends_on: impl Iterator<Item = Self>,
974    ) -> Result<Self, FwdDepError> {
975        for handle in depends_on {
976            self.check_dep(handle)?;
977        }
978        Ok(self)
979    }
980}
981
982impl<T> crate::arena::Range<T> {
983    pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
984        arena.check_contains_range(self)
985    }
986}
987
#[test]
fn constant_deps() {
    use crate::{Constant, Expression, Literal, Span, Type, TypeInner};

    let span = Span::default();

    let mut type_arena = UniqueArena::new();
    let mut global_exprs = Arena::new();
    let mut local_exprs = Arena::new();
    let mut constant_arena = Arena::new();
    let override_arena = Arena::new();

    let int_ty = type_arena.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        span,
    );

    // Build a self-referential constant by misusing a handle into the
    // function-expression arena as the constant's initializer. Both arenas
    // start empty, so `misused_init` has index 0. Read as a handle into
    // `global_exprs`, index 0 names the `Expression::Constant` appended below,
    // which refers back to the constant.
    let misused_init = local_exprs.append(Expression::Literal(Literal::I32(42)), span);
    let cyclic_constant = constant_arena.append(
        Constant {
            name: None,
            ty: int_ty,
            init: misused_init,
        },
        span,
    );
    let _ = global_exprs.append(Expression::Constant(cyclic_constant), span);

    for entry in global_exprs.iter() {
        let outcome = super::Validator::validate_const_expression_handles(
            entry,
            &constant_arena,
            &override_arena,
        );
        assert!(outcome.is_err());
    }
}
1031
#[test]
fn array_size_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let nowhere = Span::default();

    let mut m = crate::Module::default();

    let ty_u32 = m.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        nowhere,
    );
    let ex_zero = m
        .global_expressions
        .append(Expression::ZeroValue(ty_u32), nowhere);
    // The override that sizes the array type below; its initializer is `ex_zero`.
    let size_override = m.overrides.append(
        Override {
            name: None,
            id: None,
            ty: ty_u32,
            init: Some(ex_zero),
        },
        nowhere,
    );
    let ty_arr = m.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: ty_u32,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        nowhere,
    );

    // Everything should be okay now.
    assert!(
        Validator::validate_module_handles(&m).is_ok(),
        "acyclic module should pass handle validation"
    );

    // Mutate `ex_zero`'s type to `ty_arr`, introducing the cycle
    // `ex_zero` -> `ty_arr` -> `size_override` -> `ex_zero`.
    // Validation should catch the cycle.
    m.global_expressions[ex_zero] = Expression::ZeroValue(ty_arr);
    assert!(
        Validator::validate_module_handles(&m).is_err(),
        "type/override/expression cycle should fail handle validation"
    );
}
1080
#[test]
fn array_size_override() {
    use super::Validator;
    use crate::{ArraySize, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );

    // `module.overrides` is never populated, so this handle points nowhere.
    let dangling_override = Handle::<Override>::new(NonMaxU32::new(1000).unwrap());
    let _ = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(dangling_override),
                stride: 4,
            },
        },
        span,
    );

    let verdict = Validator::validate_module_handles(&module);
    assert!(verdict.is_err());
}
1113
#[test]
fn override_init_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: Some("u32".to_string()),
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );
    let zero = module
        .global_expressions
        .append(Expression::ZeroValue(u32_ty), span);
    let size_override = module.overrides.append(
        Override {
            name: Some("bad_override".into()),
            id: None,
            ty: u32_ty,
            init: Some(zero),
        },
        span,
    );
    let array_ty = module.types.insert(
        Type {
            name: Some("bad_array".to_string()),
            inner: TypeInner::Array {
                base: u32_ty,
                size: ArraySize::Pending(size_override),
                stride: 4,
            },
        },
        span,
    );
    let zero_array = module
        .global_expressions
        .append(Expression::ZeroValue(array_ty), span);

    // No cycle yet: the override's initializer is the scalar `zero`.
    assert!(Validator::validate_module_handles(&module).is_ok());

    // Re-point the override's initializer at `zero_array`, closing the cycle
    // `size_override` -> `zero_array` -> `array_ty` -> `size_override`.
    // Validation should catch the cycle.
    module.overrides[size_override].init = Some(zero_array);
    assert!(Validator::validate_module_handles(&module).is_err());
}