naga/valid/
handles.rs

1//! Implementation of `Validator::validate_module_handles`.
2
3use core::{convert::TryInto, hash::Hash};
4
5use super::{TypeError, ValidationError};
6use crate::non_max_u32::NonMaxU32;
7use crate::{
8    arena::{BadHandle, BadRangeError},
9    diagnostic_filter::DiagnosticFilterNode,
10    EntryPoint, Handle,
11};
12use crate::{Arena, UniqueArena};
13
14use alloc::string::ToString;
15
16impl super::Validator {
17    /// Validates that all handles within `module` are:
18    ///
19    /// * Valid, in the sense that they contain indices within each arena structure inside the
20    ///   [`crate::Module`] type.
21    /// * No arena contents contain any items that have forward dependencies; that is, the value
22    ///   associated with a handle only may contain references to handles in the same arena that
23    ///   were constructed before it.
24    ///
25    /// By validating the above conditions, we free up subsequent logic to assume that handle
26    /// accesses are infallible.
27    ///
28    /// # Errors
29    ///
30    /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
31    /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
32    /// validation pass.
33    pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
34        let &crate::Module {
35            ref constants,
36            ref overrides,
37            ref entry_points,
38            ref functions,
39            ref global_variables,
40            ref types,
41            ref special_types,
42            ref global_expressions,
43            ref diagnostic_filters,
44            ref diagnostic_filter_leaf,
45            ref doc_comments,
46        } = module;
47
48        // Because types can refer to global expressions and vice versa, to
49        // ensure the overall structure is free of cycles, we must traverse them
50        // both in tandem.
51        //
52        // Try to visit all types and global expressions in an order such that
53        // each item refers only to previously visited items. If we succeed,
54        // that shows that there cannot be any cycles, since walking any edge
55        // advances you towards the beginning of the visiting order.
56        //
57        // Validate all the handles in types and expressions as we traverse the
58        // arenas.
59        let mut global_exprs_iter = global_expressions.iter().peekable();
60        for (th, t) in types.iter() {
61            // Imagine the `for` loop and `global_exprs_iter` as two fingers
62            // walking the type and global expression arenas. They don't visit
63            // elements at the same rate: sometimes one processes a bunch of
64            // elements while the other one stays still. But at each point, they
65            // check that the two ranges of elements they've visited only refer
66            // to other elements in those ranges.
67            //
68            // For brevity, we'll say 'handles behind `global_exprs_iter`' to
69            // mean handles that have already been produced by
70            // `global_exprs_iter`. Once `global_exprs_iter` returns `None`, all
71            // global expression handles are 'behind' it.
72            //
73            // At this point:
74            //
75            // - All types visited by prior iterations (that is, before
76            //   `th`/`t`) refer only to expressions behind `global_exprs_iter`.
77            //
78            //   On the first iteration, this is obviously true: there are no
79            //   prior iterations, and `global_exprs_iter` hasn't produced
80            //   anything yet. At the bottom of the loop, we'll claim that it's
81            //   true for `th`/`t` as well, so the condition remains true when
82            //   we advance to the next type.
83            //
84            // - All expressions behind `global_exprs_iter` refer only to
85            //   previously visited types.
86            //
87            //   Again, trivially true at the start, and we'll show it's true
88            //   about each expression that `global_exprs_iter` produces.
89            //
90            // Once we also check that arena elements only refer to prior
91            // elements in that arena, we can see that `th`/`t` does not
92            // participate in a cycle: it only refers to previously visited
93            // types and expressions behind `global_exprs_iter`, and none of
94            // those refer to `th`/`t`, because they passed the same checks
95            // before we reached `th`/`t`.
96            if let Some(max_expr) = Self::validate_type_handles((th, t), overrides)? {
97                max_expr.check_valid_for(global_expressions)?;
98                // Since `t` refers to `max_expr`, if we want our invariants to
99                // remain true, we must advance `global_exprs_iter` beyond
100                // `max_expr`.
101                while let Some((eh, e)) = global_exprs_iter.next_if(|&(eh, _)| eh <= max_expr) {
102                    if let Some(max_type) =
103                        Self::validate_const_expression_handles((eh, e), constants, overrides)?
104                    {
105                        // Show that `eh` refers only to previously visited types.
106                        th.check_dep(max_type)?;
107                    }
108                    // We've advanced `global_exprs_iter` past `eh` already. But
109                    // since we now know that `eh` refers only to previously
110                    // visited types, it is again true that all expressions
111                    // behind `global_exprs_iter` refer only to previously
112                    // visited types. So we can continue to the next expression.
113                }
114            }
115
116            // Here we know that if `th` refers to any expressions at all,
117            // `max_expr` is the latest one. And we know that `max_expr` is
118            // behind `global_exprs_iter`. So `th` refers only to expressions
119            // behind `global_exprs_iter`, and the invariants will still be
120            // true on the next iteration.
121        }
122
123        // Since we also enforced the usual intra-arena rules that expressions
124        // refer only to prior expressions, expressions can only form cycles if
125        // they include types. But we've shown that all types are acyclic, so
126        // all expressions must be acyclic as well.
127        //
128        // Validate the remaining expressions normally.
129        for handle_and_expr in global_exprs_iter {
130            Self::validate_const_expression_handles(handle_and_expr, constants, overrides)?;
131        }
132
133        let validate_type = |handle| Self::validate_type_handle(handle, types);
134        let validate_const_expr =
135            |handle| Self::validate_expression_handle(handle, global_expressions);
136
137        for (_handle, constant) in constants.iter() {
138            let &crate::Constant { name: _, ty, init } = constant;
139            validate_type(ty)?;
140            validate_const_expr(init)?;
141        }
142
143        for (_handle, r#override) in overrides.iter() {
144            let &crate::Override {
145                name: _,
146                id: _,
147                ty,
148                init,
149            } = r#override;
150            validate_type(ty)?;
151            if let Some(init_expr) = init {
152                validate_const_expr(init_expr)?;
153            }
154        }
155
156        for (_handle, global_variable) in global_variables.iter() {
157            let &crate::GlobalVariable {
158                name: _,
159                space: _,
160                binding: _,
161                ty,
162                init,
163                memory_decorations: _,
164            } = global_variable;
165            validate_type(ty)?;
166            if let Some(init_expr) = init {
167                validate_const_expr(init_expr)?;
168            }
169        }
170
171        let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
172            let &crate::Function {
173                name: _,
174                ref arguments,
175                ref result,
176                ref local_variables,
177                ref expressions,
178                ref named_expressions,
179                ref body,
180                ref diagnostic_filter_leaf,
181            } = function;
182
183            for arg in arguments.iter() {
184                let &crate::FunctionArgument {
185                    name: _,
186                    ty,
187                    binding: _,
188                } = arg;
189                validate_type(ty)?;
190            }
191
192            if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
193                validate_type(ty)?;
194            }
195
196            for (_handle, local_variable) in local_variables.iter() {
197                let &crate::LocalVariable { name: _, ty, init } = local_variable;
198                validate_type(ty)?;
199                if let Some(init) = init {
200                    Self::validate_expression_handle(init, expressions)?;
201                }
202            }
203
204            for handle in named_expressions.keys().copied() {
205                Self::validate_expression_handle(handle, expressions)?;
206            }
207
208            for handle_and_expr in expressions.iter() {
209                Self::validate_expression_handles(
210                    handle_and_expr,
211                    constants,
212                    overrides,
213                    types,
214                    local_variables,
215                    global_variables,
216                    functions,
217                    function_handle,
218                )?;
219            }
220
221            Self::validate_block_handles(body, expressions, functions)?;
222
223            if let Some(handle) = *diagnostic_filter_leaf {
224                handle.check_valid_for(diagnostic_filters)?;
225            }
226
227            Ok(())
228        };
229
230        for entry_point in entry_points.iter() {
231            validate_function(None, &entry_point.function)?;
232            if let Some(sizes) = entry_point.workgroup_size_overrides {
233                for size in sizes.iter().filter_map(|x| *x) {
234                    validate_const_expr(size)?;
235                }
236            }
237            if let Some(task_payload) = entry_point.task_payload {
238                Self::validate_global_variable_handle(task_payload, global_variables)?;
239            }
240            if let Some(ref mesh_info) = entry_point.mesh_info {
241                Self::validate_global_variable_handle(mesh_info.output_variable, global_variables)?;
242                validate_type(mesh_info.vertex_output_type)?;
243                validate_type(mesh_info.primitive_output_type)?;
244                for ov in mesh_info
245                    .max_vertices_override
246                    .iter()
247                    .chain(mesh_info.max_primitives_override.iter())
248                {
249                    validate_const_expr(*ov)?;
250                }
251            }
252        }
253
254        for (function_handle, function) in functions.iter() {
255            validate_function(Some(function_handle), function)?;
256        }
257
258        if let Some(ty) = special_types.ray_desc {
259            validate_type(ty)?;
260        }
261        if let Some(ty) = special_types.ray_intersection {
262            validate_type(ty)?;
263        }
264        if let Some(ty) = special_types.ray_vertex_return {
265            validate_type(ty)?;
266        }
267
268        for (handle, _node) in diagnostic_filters.iter() {
269            let DiagnosticFilterNode { inner: _, parent } = diagnostic_filters[handle];
270            handle.check_dep_opt(parent)?;
271        }
272        if let Some(handle) = *diagnostic_filter_leaf {
273            handle.check_valid_for(diagnostic_filters)?;
274        }
275
276        if let Some(doc_comments) = doc_comments.as_ref() {
277            let crate::DocComments {
278                module: _,
279                types: ref doc_comments_for_types,
280                struct_members: ref doc_comments_for_struct_members,
281                entry_points: ref doc_comments_for_entry_points,
282                functions: ref doc_comments_for_functions,
283                constants: ref doc_comments_for_constants,
284                global_variables: ref doc_comments_for_global_variables,
285            } = **doc_comments;
286
287            for (&ty, _) in doc_comments_for_types.iter() {
288                validate_type(ty)?;
289            }
290
291            for (&(ty, struct_member_index), _) in doc_comments_for_struct_members.iter() {
292                validate_type(ty)?;
293                let struct_type = types.get_handle(ty).unwrap();
294                match struct_type.inner {
295                    crate::TypeInner::Struct {
296                        ref members,
297                        span: ref _span,
298                    } => {
299                        (0..members.len())
300                            .contains(&struct_member_index)
301                            .then_some(())
302                            // TODO: what errors should this be?
303                            .ok_or_else(|| ValidationError::Type {
304                                handle: ty,
305                                name: struct_type.name.as_ref().map_or_else(
306                                    || "members length incorrect".to_string(),
307                                    |name| name.to_string(),
308                                ),
309                                source: TypeError::InvalidData(ty),
310                            })?;
311                    }
312                    _ => {
313                        // TODO: internal error ? We should never get here.
314                        // If entering there, it's probably that we forgot to adjust a handle in the compact phase.
315                        return Err(ValidationError::Type {
316                            handle: ty,
317                            name: struct_type
318                                .name
319                                .as_ref()
320                                .map_or_else(|| "Unknown".to_string(), |name| name.to_string()),
321                            source: TypeError::InvalidData(ty),
322                        });
323                    }
324                }
325                for (&function, _) in doc_comments_for_functions.iter() {
326                    Self::validate_function_handle(function, functions)?;
327                }
328                for (&entry_point_index, _) in doc_comments_for_entry_points.iter() {
329                    Self::validate_entry_point_index(entry_point_index, entry_points)?;
330                }
331                for (&constant, _) in doc_comments_for_constants.iter() {
332                    Self::validate_constant_handle(constant, constants)?;
333                }
334                for (&global_variable, _) in doc_comments_for_global_variables.iter() {
335                    Self::validate_global_variable_handle(global_variable, global_variables)?;
336                }
337            }
338        }
339
340        Ok(())
341    }
342
343    fn validate_type_handle(
344        handle: Handle<crate::Type>,
345        types: &UniqueArena<crate::Type>,
346    ) -> Result<(), InvalidHandleError> {
347        handle.check_valid_for_uniq(types).map(|_| ())
348    }
349
350    fn validate_constant_handle(
351        handle: Handle<crate::Constant>,
352        constants: &Arena<crate::Constant>,
353    ) -> Result<(), InvalidHandleError> {
354        handle.check_valid_for(constants).map(|_| ())
355    }
356
357    fn validate_global_variable_handle(
358        handle: Handle<crate::GlobalVariable>,
359        global_variables: &Arena<crate::GlobalVariable>,
360    ) -> Result<(), InvalidHandleError> {
361        handle.check_valid_for(global_variables).map(|_| ())
362    }
363
364    fn validate_override_handle(
365        handle: Handle<crate::Override>,
366        overrides: &Arena<crate::Override>,
367    ) -> Result<(), InvalidHandleError> {
368        handle.check_valid_for(overrides).map(|_| ())
369    }
370
371    fn validate_expression_handle(
372        handle: Handle<crate::Expression>,
373        expressions: &Arena<crate::Expression>,
374    ) -> Result<(), InvalidHandleError> {
375        handle.check_valid_for(expressions).map(|_| ())
376    }
377
378    fn validate_function_handle(
379        handle: Handle<crate::Function>,
380        functions: &Arena<crate::Function>,
381    ) -> Result<(), InvalidHandleError> {
382        handle.check_valid_for(functions).map(|_| ())
383    }
384
385    /// Validate all handles that occur in `ty`, whose handle is `handle`.
386    ///
387    /// If `ty` refers to any expressions, return the highest-indexed expression
388    /// handle that it uses. This is used for detecting cycles between the
389    /// expression and type arenas.
390    fn validate_type_handles(
391        (handle, ty): (Handle<crate::Type>, &crate::Type),
392        overrides: &Arena<crate::Override>,
393    ) -> Result<Option<Handle<crate::Expression>>, InvalidHandleError> {
394        let max_expr = match ty.inner {
395            crate::TypeInner::Scalar { .. }
396            | crate::TypeInner::Vector { .. }
397            | crate::TypeInner::Matrix { .. }
398            | crate::TypeInner::CooperativeMatrix { .. }
399            | crate::TypeInner::ValuePointer { .. }
400            | crate::TypeInner::Atomic { .. }
401            | crate::TypeInner::Image { .. }
402            | crate::TypeInner::Sampler { .. }
403            | crate::TypeInner::AccelerationStructure { .. }
404            | crate::TypeInner::RayQuery { .. } => None,
405            crate::TypeInner::Pointer { base, space: _ } => {
406                handle.check_dep(base)?;
407                None
408            }
409            crate::TypeInner::Array { base, size, .. }
410            | crate::TypeInner::BindingArray { base, size, .. } => {
411                handle.check_dep(base)?;
412                match size {
413                    crate::ArraySize::Pending(h) => {
414                        Self::validate_override_handle(h, overrides)?;
415                        let r#override = &overrides[h];
416                        handle.check_dep(r#override.ty)?;
417                        r#override.init
418                    }
419                    crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
420                }
421            }
422            crate::TypeInner::Struct {
423                ref members,
424                span: _,
425            } => {
426                handle.check_dep_iter(members.iter().map(|m| m.ty))?;
427                None
428            }
429        };
430
431        Ok(max_expr)
432    }
433
434    fn validate_entry_point_index(
435        entry_point_index: usize,
436        entry_points: &[EntryPoint],
437    ) -> Result<(), InvalidHandleError> {
438        (0..entry_points.len())
439            .contains(&entry_point_index)
440            .then_some(())
441            .ok_or_else(|| {
442                BadHandle {
443                    kind: "EntryPoint",
444                    index: entry_point_index,
445                }
446                .into()
447            })
448    }
449
450    /// Validate all handles that occur in `expression`, whose handle is `handle`.
451    ///
452    /// If `expression` refers to any `Type`s, return the highest-indexed type
453    /// handle that it uses. This is used for detecting cycles between the
454    /// expression and type arenas.
455    fn validate_const_expression_handles(
456        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
457        constants: &Arena<crate::Constant>,
458        overrides: &Arena<crate::Override>,
459    ) -> Result<Option<Handle<crate::Type>>, InvalidHandleError> {
460        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
461        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
462
463        let max_type = match *expression {
464            crate::Expression::Literal(_) => None,
465            crate::Expression::Constant(constant) => {
466                validate_constant(constant)?;
467                handle.check_dep(constants[constant].init)?;
468                None
469            }
470            crate::Expression::Override(r#override) => {
471                validate_override(r#override)?;
472                if let Some(init) = overrides[r#override].init {
473                    handle.check_dep(init)?;
474                }
475                None
476            }
477            crate::Expression::ZeroValue(ty) => Some(ty),
478            crate::Expression::Compose { ty, ref components } => {
479                handle.check_dep_iter(components.iter().copied())?;
480                Some(ty)
481            }
482            _ => None,
483        };
484        Ok(max_type)
485    }
486
487    #[allow(clippy::too_many_arguments)]
488    fn validate_expression_handles(
489        (handle, expression): (Handle<crate::Expression>, &crate::Expression),
490        constants: &Arena<crate::Constant>,
491        overrides: &Arena<crate::Override>,
492        types: &UniqueArena<crate::Type>,
493        local_variables: &Arena<crate::LocalVariable>,
494        global_variables: &Arena<crate::GlobalVariable>,
495        functions: &Arena<crate::Function>,
496        // The handle of the current function or `None` if it's an entry point
497        current_function: Option<Handle<crate::Function>>,
498    ) -> Result<(), InvalidHandleError> {
499        let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
500        let validate_override = |handle| Self::validate_override_handle(handle, overrides);
501        let validate_type = |handle| Self::validate_type_handle(handle, types);
502
503        match *expression {
504            crate::Expression::Access { base, index } => {
505                handle.check_dep(base)?.check_dep(index)?;
506            }
507            crate::Expression::AccessIndex { base, .. } => {
508                handle.check_dep(base)?;
509            }
510            crate::Expression::Splat { value, .. } => {
511                handle.check_dep(value)?;
512            }
513            crate::Expression::Swizzle { vector, .. } => {
514                handle.check_dep(vector)?;
515            }
516            crate::Expression::Literal(_) => {}
517            crate::Expression::Constant(constant) => {
518                validate_constant(constant)?;
519            }
520            crate::Expression::Override(r#override) => {
521                validate_override(r#override)?;
522            }
523            crate::Expression::ZeroValue(ty) => {
524                validate_type(ty)?;
525            }
526            crate::Expression::Compose { ty, ref components } => {
527                validate_type(ty)?;
528                handle.check_dep_iter(components.iter().copied())?;
529            }
530            crate::Expression::FunctionArgument(_arg_idx) => (),
531            crate::Expression::GlobalVariable(global_variable) => {
532                global_variable.check_valid_for(global_variables)?;
533            }
534            crate::Expression::LocalVariable(local_variable) => {
535                local_variable.check_valid_for(local_variables)?;
536            }
537            crate::Expression::Load { pointer } => {
538                handle.check_dep(pointer)?;
539            }
540            crate::Expression::ImageSample {
541                image,
542                sampler,
543                gather: _,
544                coordinate,
545                array_index,
546                offset,
547                level,
548                depth_ref,
549                clamp_to_edge: _,
550            } => {
551                handle
552                    .check_dep(image)?
553                    .check_dep(sampler)?
554                    .check_dep(coordinate)?
555                    .check_dep_opt(array_index)?
556                    .check_dep_opt(offset)?;
557
558                match level {
559                    crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
560                    crate::SampleLevel::Exact(expr) => {
561                        handle.check_dep(expr)?;
562                    }
563                    crate::SampleLevel::Bias(expr) => {
564                        handle.check_dep(expr)?;
565                    }
566                    crate::SampleLevel::Gradient { x, y } => {
567                        handle.check_dep(x)?.check_dep(y)?;
568                    }
569                };
570
571                handle.check_dep_opt(depth_ref)?;
572            }
573            crate::Expression::ImageLoad {
574                image,
575                coordinate,
576                array_index,
577                sample,
578                level,
579            } => {
580                handle
581                    .check_dep(image)?
582                    .check_dep(coordinate)?
583                    .check_dep_opt(array_index)?
584                    .check_dep_opt(sample)?
585                    .check_dep_opt(level)?;
586            }
587            crate::Expression::ImageQuery { image, query } => {
588                handle.check_dep(image)?;
589                match query {
590                    crate::ImageQuery::Size { level } => {
591                        handle.check_dep_opt(level)?;
592                    }
593                    crate::ImageQuery::NumLevels
594                    | crate::ImageQuery::NumLayers
595                    | crate::ImageQuery::NumSamples => (),
596                };
597            }
598            crate::Expression::Unary {
599                op: _,
600                expr: operand,
601            } => {
602                handle.check_dep(operand)?;
603            }
604            crate::Expression::Binary { op: _, left, right } => {
605                handle.check_dep(left)?.check_dep(right)?;
606            }
607            crate::Expression::Select {
608                condition,
609                accept,
610                reject,
611            } => {
612                handle
613                    .check_dep(condition)?
614                    .check_dep(accept)?
615                    .check_dep(reject)?;
616            }
617            crate::Expression::Derivative { expr: argument, .. } => {
618                handle.check_dep(argument)?;
619            }
620            crate::Expression::Relational { fun: _, argument } => {
621                handle.check_dep(argument)?;
622            }
623            crate::Expression::Math {
624                fun: _,
625                arg,
626                arg1,
627                arg2,
628                arg3,
629            } => {
630                handle
631                    .check_dep(arg)?
632                    .check_dep_opt(arg1)?
633                    .check_dep_opt(arg2)?
634                    .check_dep_opt(arg3)?;
635            }
636            crate::Expression::As {
637                expr: input,
638                kind: _,
639                convert: _,
640            } => {
641                handle.check_dep(input)?;
642            }
643            crate::Expression::CallResult(function) => {
644                Self::validate_function_handle(function, functions)?;
645                if let Some(handle) = current_function {
646                    handle.check_dep(function)?;
647                }
648            }
649            crate::Expression::AtomicResult { .. }
650            | crate::Expression::RayQueryProceedResult
651            | crate::Expression::SubgroupBallotResult
652            | crate::Expression::SubgroupOperationResult { .. }
653            | crate::Expression::WorkGroupUniformLoadResult { .. } => (),
654            crate::Expression::ArrayLength(array) => {
655                handle.check_dep(array)?;
656            }
657            crate::Expression::RayQueryGetIntersection {
658                query,
659                committed: _,
660            }
661            | crate::Expression::RayQueryVertexPositions {
662                query,
663                committed: _,
664            } => {
665                handle.check_dep(query)?;
666            }
667            crate::Expression::CooperativeLoad { ref data, .. } => {
668                handle.check_dep(data.pointer)?.check_dep(data.stride)?;
669            }
670            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
671                handle.check_dep(a)?.check_dep(b)?.check_dep(c)?;
672            }
673        }
674        Ok(())
675    }
676
677    fn validate_block_handles(
678        block: &crate::Block,
679        expressions: &Arena<crate::Expression>,
680        functions: &Arena<crate::Function>,
681    ) -> Result<(), InvalidHandleError> {
682        let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
683        let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
684        let validate_expr_opt = |handle_opt| {
685            if let Some(handle) = handle_opt {
686                validate_expr(handle)?;
687            }
688            Ok(())
689        };
690
691        block.iter().try_for_each(|stmt| match *stmt {
692            crate::Statement::Emit(ref expr_range) => {
693                expr_range.check_valid_for(expressions)?;
694                Ok(())
695            }
696            crate::Statement::Block(ref block) => {
697                validate_block(block)?;
698                Ok(())
699            }
700            crate::Statement::If {
701                condition,
702                ref accept,
703                ref reject,
704            } => {
705                validate_expr(condition)?;
706                validate_block(accept)?;
707                validate_block(reject)?;
708                Ok(())
709            }
710            crate::Statement::Switch {
711                selector,
712                ref cases,
713            } => {
714                validate_expr(selector)?;
715                for &crate::SwitchCase {
716                    value: _,
717                    ref body,
718                    fall_through: _,
719                } in cases
720                {
721                    validate_block(body)?;
722                }
723                Ok(())
724            }
725            crate::Statement::Loop {
726                ref body,
727                ref continuing,
728                break_if,
729            } => {
730                validate_block(body)?;
731                validate_block(continuing)?;
732                validate_expr_opt(break_if)?;
733                Ok(())
734            }
735            crate::Statement::Return { value } => validate_expr_opt(value),
736            crate::Statement::Store { pointer, value } => {
737                validate_expr(pointer)?;
738                validate_expr(value)?;
739                Ok(())
740            }
741            crate::Statement::ImageStore {
742                image,
743                coordinate,
744                array_index,
745                value,
746            } => {
747                validate_expr(image)?;
748                validate_expr(coordinate)?;
749                validate_expr_opt(array_index)?;
750                validate_expr(value)?;
751                Ok(())
752            }
753            crate::Statement::Atomic {
754                pointer,
755                fun,
756                value,
757                result,
758            } => {
759                validate_expr(pointer)?;
760                match fun {
761                    crate::AtomicFunction::Add
762                    | crate::AtomicFunction::Subtract
763                    | crate::AtomicFunction::And
764                    | crate::AtomicFunction::ExclusiveOr
765                    | crate::AtomicFunction::InclusiveOr
766                    | crate::AtomicFunction::Min
767                    | crate::AtomicFunction::Max => (),
768                    crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
769                };
770                validate_expr(value)?;
771                if let Some(result) = result {
772                    validate_expr(result)?;
773                }
774                Ok(())
775            }
776            crate::Statement::ImageAtomic {
777                image,
778                coordinate,
779                array_index,
780                fun: _,
781                value,
782            } => {
783                validate_expr(image)?;
784                validate_expr(coordinate)?;
785                validate_expr_opt(array_index)?;
786                validate_expr(value)?;
787                Ok(())
788            }
789            crate::Statement::WorkGroupUniformLoad { pointer, result } => {
790                validate_expr(pointer)?;
791                validate_expr(result)?;
792                Ok(())
793            }
794            crate::Statement::Call {
795                function,
796                ref arguments,
797                result,
798            } => {
799                Self::validate_function_handle(function, functions)?;
800                for arg in arguments.iter().copied() {
801                    validate_expr(arg)?;
802                }
803                validate_expr_opt(result)?;
804                Ok(())
805            }
806            crate::Statement::RayQuery { query, ref fun } => {
807                validate_expr(query)?;
808                match *fun {
809                    crate::RayQueryFunction::Initialize {
810                        acceleration_structure,
811                        descriptor,
812                    } => {
813                        validate_expr(acceleration_structure)?;
814                        validate_expr(descriptor)?;
815                    }
816                    crate::RayQueryFunction::Proceed { result } => {
817                        validate_expr(result)?;
818                    }
819                    crate::RayQueryFunction::GenerateIntersection { hit_t } => {
820                        validate_expr(hit_t)?;
821                    }
822                    crate::RayQueryFunction::ConfirmIntersection => {}
823                    crate::RayQueryFunction::Terminate => {}
824                }
825                Ok(())
826            }
827            crate::Statement::SubgroupBallot { result, predicate } => {
828                validate_expr_opt(predicate)?;
829                validate_expr(result)?;
830                Ok(())
831            }
832            crate::Statement::SubgroupCollectiveOperation {
833                op: _,
834                collective_op: _,
835                argument,
836                result,
837            } => {
838                validate_expr(argument)?;
839                validate_expr(result)?;
840                Ok(())
841            }
842            crate::Statement::SubgroupGather {
843                mode,
844                argument,
845                result,
846            } => {
847                validate_expr(argument)?;
848                match mode {
849                    crate::GatherMode::BroadcastFirst => {}
850                    crate::GatherMode::Broadcast(index)
851                    | crate::GatherMode::Shuffle(index)
852                    | crate::GatherMode::ShuffleDown(index)
853                    | crate::GatherMode::ShuffleUp(index)
854                    | crate::GatherMode::ShuffleXor(index)
855                    | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?,
856                    crate::GatherMode::QuadSwap(_) => {}
857                }
858                validate_expr(result)?;
859                Ok(())
860            }
861            crate::Statement::CooperativeStore { target, ref data } => {
862                validate_expr(target)?;
863                validate_expr(data.pointer)?;
864                validate_expr(data.stride)?;
865                Ok(())
866            }
867            crate::Statement::RayPipelineFunction(fun) => match fun {
868                crate::RayPipelineFunction::TraceRay {
869                    acceleration_structure,
870                    descriptor,
871                    payload,
872                } => {
873                    validate_expr(acceleration_structure)?;
874                    validate_expr(descriptor)?;
875                    validate_expr(payload)?;
876                    Ok(())
877                }
878            },
879            crate::Statement::Break
880            | crate::Statement::Continue
881            | crate::Statement::Kill
882            | crate::Statement::ControlBarrier(_)
883            | crate::Statement::MemoryBarrier(_) => Ok(()),
884        })
885    }
886}
887
888impl From<BadHandle> for ValidationError {
889    fn from(source: BadHandle) -> Self {
890        Self::InvalidHandle(source.into())
891    }
892}
893
894impl From<FwdDepError> for ValidationError {
895    fn from(source: FwdDepError) -> Self {
896        Self::InvalidHandle(source.into())
897    }
898}
899
900impl From<BadRangeError> for ValidationError {
901    fn from(source: BadRangeError) -> Self {
902        Self::InvalidHandle(source.into())
903    }
904}
905
/// The ways a handle, or a range of handles, can fail handle validation.
///
/// Every variant transparently forwards the message of the error it wraps, and
/// each wrapped error type converts into this enum through `From`, so the
/// helpers in this module can propagate them with `?`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
    /// An arena's containment check rejected a handle.
    #[error(transparent)]
    BadHandle(#[from] BadHandle),
    /// An arena entry depends on an entry of the same arena that was not
    /// constructed before it.
    #[error(transparent)]
    ForwardDependency(#[from] FwdDepError),
    /// An arena's containment check rejected a range of handles.
    #[error(transparent)]
    BadRange(#[from] BadRangeError),
}
916
917#[derive(Clone, Debug, thiserror::Error)]
918#[cfg_attr(test, derive(PartialEq))]
919#[error(
920    "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
921    which has not been processed yet"
922)]
923pub struct FwdDepError {
924    // This error is used for many `Handle` types, but there's no point in making this generic, so
925    // we just flatten them all to `Handle<()>` here.
926    subject: Handle<()>,
927    subject_kind: &'static str,
928    depends_on: Handle<()>,
929    depends_on_kind: &'static str,
930}
931
932impl<T> Handle<T> {
933    /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
934    pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
935        arena.check_contains_handle(self)?;
936        Ok(())
937    }
938
939    /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
940    pub(self) fn check_valid_for_uniq(
941        self,
942        arena: &UniqueArena<T>,
943    ) -> Result<(), InvalidHandleError>
944    where
945        T: Eq + Hash,
946    {
947        arena.check_contains_handle(self)?;
948        Ok(())
949    }
950
951    /// Check that `depends_on` was constructed before `self` by comparing handle indices.
952    ///
953    /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
954    /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
955    /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
956    /// recursive definitions of arena-based values in linear time.
957    ///
958    /// # Errors
959    ///
960    /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
961    /// than `self`'s, this function returns an error.
962    pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
963        if depends_on < self {
964            Ok(self)
965        } else {
966            let erase_handle_type = |handle: Handle<_>| {
967                Handle::new(NonMaxU32::new((handle.index()).try_into().unwrap()).unwrap())
968            };
969            Err(FwdDepError {
970                subject: erase_handle_type(self),
971                subject_kind: core::any::type_name::<T>(),
972                depends_on: erase_handle_type(depends_on),
973                depends_on_kind: core::any::type_name::<T>(),
974            })
975        }
976    }
977
978    /// Like [`Self::check_dep`], except for [`Option`]al handle values.
979    pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
980        self.check_dep_iter(depends_on.into_iter())
981    }
982
983    /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
984    pub(self) fn check_dep_iter(
985        self,
986        depends_on: impl Iterator<Item = Self>,
987    ) -> Result<Self, FwdDepError> {
988        for handle in depends_on {
989            self.check_dep(handle)?;
990        }
991        Ok(self)
992    }
993}
994
995impl<T> crate::arena::Range<T> {
996    pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
997        arena.check_contains_range(self)
998    }
999}
1000
/// A constant expression that refers to a constant whose initializer does not
/// precede it must be rejected by `validate_const_expression_handles`.
#[test]
fn constant_deps() {
    use crate::{Constant, Expression, Literal, Span, Type, TypeInner};

    let span = Span::default();

    let mut types = UniqueArena::new();
    let mut const_exprs = Arena::new();
    let mut function_exprs = Arena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();

    let i32_type = Type {
        name: None,
        inner: TypeInner::Scalar(crate::Scalar::I32),
    };
    let i32_handle = types.insert(i32_type, span);

    // Build a self-referential constant: its initializer is a handle taken from
    // the *function* expression arena, misused as a constant-expression handle.
    let misused_init = function_exprs.append(Expression::Literal(Literal::I32(42)), span);
    let cyclic_constant = Constant {
        name: None,
        ty: i32_handle,
        init: misused_init,
    };
    let cyclic_constant_handle = constants.append(cyclic_constant, span);
    let _cyclic_expr = const_exprs.append(Expression::Constant(cyclic_constant_handle), span);

    // Every constant expression in the arena has to fail validation.
    let every_entry_rejected = const_exprs.iter().all(|entry| {
        super::Validator::validate_const_expression_handles(entry, &constants, &overrides).is_err()
    });
    assert!(every_entry_rejected);
}
1044
/// An array type sized by an override is fine until the override's initializer
/// is made to refer back to that array type, which forms a cycle.
#[test]
fn array_size_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_type = Type {
        name: Some("u32".to_string()),
        inner: TypeInner::Scalar(Scalar::U32),
    };
    let ty_u32 = module.types.insert(u32_type, span);

    let zero_expr = module
        .global_expressions
        .append(Expression::ZeroValue(ty_u32), span);

    // The override that supplies the array's (pending) size.
    let size_override = Override {
        name: None,
        id: None,
        ty: ty_u32,
        init: Some(zero_expr),
    };
    let size_override_handle = module.overrides.append(size_override, span);

    let array_type = Type {
        name: Some("bad_array".to_string()),
        inner: TypeInner::Array {
            base: ty_u32,
            size: ArraySize::Pending(size_override_handle),
            stride: 4,
        },
    };
    let ty_array = module.types.insert(array_type, span);

    // No cycle yet, so handle validation must pass.
    let acyclic_outcome = Validator::validate_module_handles(&module);
    assert!(acyclic_outcome.is_ok());

    // Point the zero-value expression at the array type, introducing a cycle.
    // Validation should catch the cycle.
    module.global_expressions[zero_expr] = Expression::ZeroValue(ty_array);
    let cyclic_outcome = Validator::validate_module_handles(&module);
    assert!(cyclic_outcome.is_err());
}
1093
/// An array type whose pending size names an override handle that the module's
/// (empty) override arena does not contain must fail handle validation.
#[test]
fn array_size_override() {
    use super::Validator;
    use crate::{ArraySize, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_type = Type {
        name: Some("u32".to_string()),
        inner: TypeInner::Scalar(Scalar::U32),
    };
    let ty_u32 = module.types.insert(u32_type, span);

    // No override is ever appended to the module, so this handle is dangling.
    let dangling_index = NonMaxU32::new(1000).unwrap();
    let bad_override: Handle<Override> = Handle::new(dangling_index);

    let array_type = Type {
        name: Some("bad_array".to_string()),
        inner: TypeInner::Array {
            base: ty_u32,
            size: ArraySize::Pending(bad_override),
            stride: 4,
        },
    };
    let _ty_array = module.types.insert(array_type, span);

    let outcome = Validator::validate_module_handles(&module);
    assert!(outcome.is_err());
}
1126
/// An override whose initializer is re-pointed at an expression of an array
/// type sized by that same override forms a cycle that validation must reject.
#[test]
fn override_init_deps() {
    use super::Validator;
    use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_type = Type {
        name: Some("u32".to_string()),
        inner: TypeInner::Scalar(Scalar::U32),
    };
    let ty_u32 = module.types.insert(u32_type, span);

    let zero_expr = module
        .global_expressions
        .append(Expression::ZeroValue(ty_u32), span);

    let size_override = Override {
        name: Some("bad_override".into()),
        id: None,
        ty: ty_u32,
        init: Some(zero_expr),
    };
    let size_override_handle = module.overrides.append(size_override, span);

    let array_type = Type {
        name: Some("bad_array".to_string()),
        inner: TypeInner::Array {
            base: ty_u32,
            size: ArraySize::Pending(size_override_handle),
            stride: 4,
        },
    };
    let ty_array = module.types.insert(array_type, span);

    let array_expr = module
        .global_expressions
        .append(Expression::ZeroValue(ty_array), span);

    // Nothing refers forward yet, so handle validation must pass.
    let acyclic_outcome = Validator::validate_module_handles(&module);
    assert!(acyclic_outcome.is_ok());

    // Re-point the override's initializer at the array-typed expression,
    // introducing a cycle. Validation should catch the cycle.
    module.overrides[size_override_handle].init = Some(array_expr);
    let cyclic_outcome = Validator::validate_module_handles(&module);
    assert!(cyclic_outcome.is_err());
}