naga/back/spv/
index.rs

1/*!
2Bounds-checking for SPIR-V output.
3*/
4
5use super::{
6    helpers::{global_needs_wrapper, map_storage_class},
7    selection::Selection,
8    Block, BlockContext, Error, IdGenerator, Instruction, Word,
9};
10use crate::{
11    arena::Handle,
12    proc::{index::GuardedIndex, BoundsCheckPolicy},
13};
14
15/// The results of performing a bounds check.
16///
17/// On success, [`write_bounds_check`](BlockContext::write_bounds_check)
18/// returns a value of this type. The caller can assume that the right
19/// policy has been applied, and simply do what the variant says.
20#[derive(Debug)]
21pub(super) enum BoundsCheckResult {
22    /// The index is statically known and in bounds, with the given value.
23    KnownInBounds(u32),
24
25    /// The given instruction computes the index to be used.
26    ///
27    /// When [`BoundsCheckPolicy::Restrict`] is in force, this is a
28    /// clamped version of the index the user supplied.
29    ///
30    /// When [`BoundsCheckPolicy::Unchecked`] is in force, this is
31    /// simply the index the user supplied. This variant indicates
32    /// that we couldn't prove statically that the index was in
33    /// bounds; otherwise we would have returned [`KnownInBounds`].
34    ///
35    /// [`KnownInBounds`]: BoundsCheckResult::KnownInBounds
36    Computed(Word),
37
38    /// The given instruction computes a boolean condition which is true
39    /// if the index is in bounds.
40    ///
41    /// This is returned when [`BoundsCheckPolicy::ReadZeroSkipWrite`]
42    /// is in force.
43    Conditional {
44        /// The access should only be permitted if this value is true.
45        condition_id: Word,
46
47        /// The access should use this index value.
48        index_id: Word,
49    },
50}
51
52/// A value that we either know at translation time, or need to compute at runtime.
53#[derive(Copy, Clone)]
54pub(super) enum MaybeKnown<T> {
55    /// The value is known at shader translation time.
56    Known(T),
57
58    /// The value is computed by the instruction with the given id.
59    Computed(Word),
60}
61
62impl BlockContext<'_> {
63    /// Emit code to compute the length of a run-time array.
64    ///
65    /// Given `array`, an expression referring a runtime-sized array, return the
66    /// instruction id for the array's length.
67    ///
68    /// Runtime-sized arrays may only appear in the values of global
69    /// variables, which must have one of the following Naga types:
70    ///
71    /// 1. A runtime-sized array.
72    /// 2. A struct whose last member is a runtime-sized array.
73    /// 3. A binding array of 2.
74    ///
75    /// Thus, the expression `array` has the form of:
76    ///
77    /// - An optional [`AccessIndex`], for case 2, applied to...
78    /// - An optional [`Access`] or [`AccessIndex`], for case 3, applied to...
79    /// - A [`GlobalVariable`].
80    ///
81    /// The generated SPIR-V takes into account wrapped globals; see
82    /// [`back::spv::GlobalVariable`] for details.
83    ///
84    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
85    /// [`AccessIndex`]: crate::Expression::AccessIndex
86    /// [`Access`]: crate::Expression::Access
87    /// [`base`]: crate::Expression::Access::base
88    /// [`back::spv::GlobalVariable`]: super::GlobalVariable
89    pub(super) fn write_runtime_array_length(
90        &mut self,
91        array: Handle<crate::Expression>,
92        block: &mut Block,
93    ) -> Result<Word, Error> {
94        // The index into the binding array, if any.
95        let binding_array_index_id: Option<Word>;
96
97        // The handle to the Naga IR global we're referring to.
98        let global_handle: Handle<crate::GlobalVariable>;
99
100        // At the Naga type level, if the runtime-sized array is the final member of a
101        // struct, this is that member's index.
102        //
103        // This does not cover wrappers: if this backend wrapped the Naga global's
104        // type in a synthetic SPIR-V struct (see `global_needs_wrapper`), this is
105        // `None`.
106        let opt_last_member_index: Option<u32>;
107
108        // Inspect `array` and decide whether we have a binding array and/or an
109        // enclosing struct.
110        match self.ir_function.expressions[array] {
111            crate::Expression::AccessIndex { base, index } => {
112                match self.ir_function.expressions[base] {
113                    crate::Expression::AccessIndex {
114                        base: base_outer,
115                        index: index_outer,
116                    } => match self.ir_function.expressions[base_outer] {
117                        // An `AccessIndex` of an `AccessIndex` must be a
118                        // binding array holding structs whose last members are
119                        // runtime-sized arrays.
120                        crate::Expression::GlobalVariable(handle) => {
121                            let index_id = self.get_index_constant(index_outer);
122                            binding_array_index_id = Some(index_id);
123                            global_handle = handle;
124                            opt_last_member_index = Some(index);
125                        }
126                        _ => {
127                            return Err(Error::Validation(
128                                "array length expression: AccessIndex(AccessIndex(Global))",
129                            ))
130                        }
131                    },
132                    crate::Expression::Access {
133                        base: base_outer,
134                        index: index_outer,
135                    } => match self.ir_function.expressions[base_outer] {
136                        // Similarly, an `AccessIndex` of an `Access` must be a
137                        // binding array holding structs whose last members are
138                        // runtime-sized arrays.
139                        crate::Expression::GlobalVariable(handle) => {
140                            let index_id = self.cached[index_outer];
141                            binding_array_index_id = Some(index_id);
142                            global_handle = handle;
143                            opt_last_member_index = Some(index);
144                        }
145                        _ => {
146                            return Err(Error::Validation(
147                                "array length expression: AccessIndex(Access(Global))",
148                            ))
149                        }
150                    },
151                    crate::Expression::GlobalVariable(handle) => {
152                        // An outer `AccessIndex` applied directly to a
153                        // `GlobalVariable`. Since binding arrays can only contain
154                        // structs, this must be referring to the last member of a
155                        // struct that is a runtime-sized array.
156                        binding_array_index_id = None;
157                        global_handle = handle;
158                        opt_last_member_index = Some(index);
159                    }
160                    _ => {
161                        return Err(Error::Validation(
162                            "array length expression: AccessIndex(<unexpected>)",
163                        ))
164                    }
165                }
166            }
167            crate::Expression::GlobalVariable(handle) => {
168                // A direct reference to a global variable. This must hold the
169                // runtime-sized array directly.
170                binding_array_index_id = None;
171                global_handle = handle;
172                opt_last_member_index = None;
173            }
174            _ => return Err(Error::Validation("array length expression case-4")),
175        };
176
177        // The verifier should have checked this, but make sure the inspection above
178        // agrees with the type about whether a binding array is involved.
179        //
180        // Eventually we do want to support `binding_array<array<T>>`. This check
181        // ensures that whoever relaxes the validator will get an error message from
182        // us, not just bogus SPIR-V.
183        let global = &self.ir_module.global_variables[global_handle];
184        match (
185            &self.ir_module.types[global.ty].inner,
186            binding_array_index_id,
187        ) {
188            (&crate::TypeInner::BindingArray { .. }, Some(_)) => {}
189            (_, None) => {}
190            _ => {
191                return Err(Error::Validation(
192                    "array length expression: bad binding array inference",
193                ))
194            }
195        }
196
197        // SPIR-V allows runtime-sized arrays to appear only as the last member of a
198        // struct. Determine this member's index.
199        let gvar = self.writer.global_variables[global_handle].clone();
200        let global = &self.ir_module.global_variables[global_handle];
201        let needs_wrapper = global_needs_wrapper(self.ir_module, global);
202        let (last_member_index, gvar_id) = match (opt_last_member_index, needs_wrapper) {
203            (Some(index), false) => {
204                // At the Naga type level, the runtime-sized array appears as the
205                // final member of a struct, whose index is `index`. We didn't need to
206                // wrap this, since the Naga type meets SPIR-V's requirements already.
207                (index, gvar.access_id)
208            }
209            (None, true) => {
210                // At the Naga type level, the runtime-sized array does not appear
211                // within a struct. We wrapped this in an OpTypeStruct with nothing
212                // else in it, so the index is zero. OpArrayLength wants the pointer
213                // to the wrapper struct, so use `gvar.var_id`.
214                (0, gvar.var_id)
215            }
216            _ => {
217                return Err(Error::Validation(
218                    "array length expression: bad SPIR-V wrapper struct inference",
219                ));
220            }
221        };
222
223        let structure_id = match binding_array_index_id {
224            // We are indexing inside a binding array, generate the access op.
225            Some(index_id) => {
226                let element_type_id = match self.ir_module.types[global.ty].inner {
227                    crate::TypeInner::BindingArray { base, size: _ } => {
228                        let base_id = self.get_handle_type_id(base);
229                        let class = map_storage_class(global.space);
230                        self.get_pointer_type_id(base_id, class)
231                    }
232                    _ => return Err(Error::Validation("array length expression case-5")),
233                };
234                let structure_id = self.gen_id();
235                block.body.push(Instruction::access_chain(
236                    element_type_id,
237                    structure_id,
238                    gvar_id,
239                    &[index_id],
240                ));
241                structure_id
242            }
243            None => gvar_id,
244        };
245        let length_id = self.gen_id();
246        block.body.push(Instruction::array_length(
247            self.writer.get_u32_type_id(),
248            length_id,
249            structure_id,
250            last_member_index,
251        ));
252
253        Ok(length_id)
254    }
255
256    /// Compute the length of a subscriptable value.
257    ///
258    /// Given `sequence`, an expression referring to some indexable type, return
259    /// its length. The result may either be computed by SPIR-V instructions, or
260    /// known at shader translation time.
261    ///
262    /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
263    /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
264    /// sized, or use a specializable constant as its length.
265    fn write_sequence_length(
266        &mut self,
267        sequence: Handle<crate::Expression>,
268        block: &mut Block,
269    ) -> Result<MaybeKnown<u32>, Error> {
270        let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
271        match sequence_ty.indexable_length_resolved(self.ir_module) {
272            Ok(crate::proc::IndexableLength::Known(known_length)) => {
273                Ok(MaybeKnown::Known(known_length))
274            }
275            Ok(crate::proc::IndexableLength::Dynamic) => {
276                let length_id = self.write_runtime_array_length(sequence, block)?;
277                Ok(MaybeKnown::Computed(length_id))
278            }
279            Err(err) => {
280                log::error!("Sequence length for {sequence:?} failed: {err}");
281                Err(Error::Validation("indexable length"))
282            }
283        }
284    }
285
286    /// Compute the maximum valid index of a subscriptable value.
287    ///
288    /// Given `sequence`, an expression referring to some indexable type, return
289    /// its maximum valid index - one less than its length. The result may
290    /// either be computed, or known at shader translation time.
291    ///
292    /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
293    /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
294    /// sized, or use a specializable constant as its length.
295    fn write_sequence_max_index(
296        &mut self,
297        sequence: Handle<crate::Expression>,
298        block: &mut Block,
299    ) -> Result<MaybeKnown<u32>, Error> {
300        match self.write_sequence_length(sequence, block)? {
301            MaybeKnown::Known(known_length) => {
302                // We should have thrown out all attempts to subscript zero-length
303                // sequences during validation, so the following subtraction should never
304                // underflow.
305                assert!(known_length > 0);
306                // Compute the max index from the length now.
307                Ok(MaybeKnown::Known(known_length - 1))
308            }
309            MaybeKnown::Computed(length_id) => {
310                // Emit code to compute the max index from the length.
311                let const_one_id = self.get_index_constant(1);
312                let max_index_id = self.gen_id();
313                block.body.push(Instruction::binary(
314                    spirv::Op::ISub,
315                    self.writer.get_u32_type_id(),
316                    max_index_id,
317                    length_id,
318                    const_one_id,
319                ));
320                Ok(MaybeKnown::Computed(max_index_id))
321            }
322        }
323    }
324
325    /// Restrict an index to be in range for a vector, matrix, or array.
326    ///
327    /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
328    /// index is left unchanged. An out-of-bounds index is replaced with some
329    /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
330    /// example, negative indices might be changed to refer to the last element
331    /// of the sequence, not the first, as clamping would do.
332    ///
333    /// Either return the restricted index value, if known, or add instructions
334    /// to `block` to compute it, and return the id of the result. See the
335    /// documentation for `BoundsCheckResult` for details.
336    ///
337    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
338    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
339    /// fixed-size, dynamically sized, or use a specializable constant as its
340    /// length.
341    pub(super) fn write_restricted_index(
342        &mut self,
343        sequence: Handle<crate::Expression>,
344        index: GuardedIndex,
345        block: &mut Block,
346    ) -> Result<BoundsCheckResult, Error> {
347        let max_index = self.write_sequence_max_index(sequence, block)?;
348
349        // If both are known, we can compute the index to be used
350        // right now.
351        if let (GuardedIndex::Known(index), MaybeKnown::Known(max_index)) = (index, max_index) {
352            let restricted = core::cmp::min(index, max_index);
353            return Ok(BoundsCheckResult::KnownInBounds(restricted));
354        }
355
356        let index_id = match index {
357            GuardedIndex::Known(value) => self.get_index_constant(value),
358            GuardedIndex::Expression(expr) => self.cached[expr],
359        };
360
361        let max_index_id = match max_index {
362            MaybeKnown::Known(value) => self.get_index_constant(value),
363            MaybeKnown::Computed(id) => id,
364        };
365
366        // One or the other of the index or length is dynamic, so emit code for
367        // BoundsCheckPolicy::Restrict.
368        let restricted_index_id = self.gen_id();
369        block.body.push(Instruction::ext_inst(
370            self.writer.gl450_ext_inst_id,
371            spirv::GLOp::UMin,
372            self.writer.get_u32_type_id(),
373            restricted_index_id,
374            &[index_id, max_index_id],
375        ));
376        Ok(BoundsCheckResult::Computed(restricted_index_id))
377    }
378
379    /// Write an index bounds comparison to `block`, if needed.
380    ///
381    /// This is used to implement [`BoundsCheckPolicy::ReadZeroSkipWrite`].
382    ///
383    /// If we're able to determine statically that `index` is in bounds for
384    /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
385    /// value of the index. (In principle, one could know that the index is in
386    /// bounds without knowing its specific value, but in our simple-minded
387    /// situation, we always know it.)
388    ///
389    /// If instead we must generate code to perform the comparison at run time,
390    /// return `Conditional(comparison_id)`, where `comparison_id` is an
391    /// instruction producing a boolean value that is true if `index` is in
392    /// bounds for `sequence`.
393    ///
394    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
395    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
396    /// fixed-size, dynamically sized, or use a specializable constant as its
397    /// length.
398    fn write_index_comparison(
399        &mut self,
400        sequence: Handle<crate::Expression>,
401        index: GuardedIndex,
402        block: &mut Block,
403    ) -> Result<BoundsCheckResult, Error> {
404        let length = self.write_sequence_length(sequence, block)?;
405
406        // If both are known, we can decide whether the index is in
407        // bounds right now.
408        if let (GuardedIndex::Known(index), MaybeKnown::Known(length)) = (index, length) {
409            if index < length {
410                return Ok(BoundsCheckResult::KnownInBounds(index));
411            }
412
413            // In theory, when `index` is bad, we could return a new
414            // `KnownOutOfBounds` variant here. But it's simpler just to fall
415            // through and let the bounds check take place. The shader is broken
416            // anyway, so it doesn't make sense to invest in emitting the ideal
417            // code for it.
418        }
419
420        let index_id = match index {
421            GuardedIndex::Known(value) => self.get_index_constant(value),
422            GuardedIndex::Expression(expr) => self.cached[expr],
423        };
424
425        let length_id = match length {
426            MaybeKnown::Known(value) => self.get_index_constant(value),
427            MaybeKnown::Computed(id) => id,
428        };
429
430        // Compare the index against the length.
431        let condition_id = self.gen_id();
432        block.body.push(Instruction::binary(
433            spirv::Op::ULessThan,
434            self.writer.get_bool_type_id(),
435            condition_id,
436            index_id,
437            length_id,
438        ));
439
440        // Indicate that we did generate the check.
441        Ok(BoundsCheckResult::Conditional {
442            condition_id,
443            index_id,
444        })
445    }
446
447    /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
448    ///
449    /// Generate code to load a value of `result_type` if `condition` is true,
450    /// and generate a null value of that type if it is false. Call `emit_load`
451    /// to emit the instructions to perform the load. Return the id of the
452    /// merged value of the two branches.
453    pub(super) fn write_conditional_indexed_load<F>(
454        &mut self,
455        result_type: Word,
456        condition: Word,
457        block: &mut Block,
458        emit_load: F,
459    ) -> Word
460    where
461        F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
462    {
463        // For the out-of-bounds case, we produce a zero value.
464        let null_id = self.writer.get_constant_null(result_type);
465
466        let mut selection = Selection::start(block, result_type);
467
468        // As it turns out, we don't actually need a full 'if-then-else'
469        // structure for this: SPIR-V constants are declared up front, so the
470        // 'else' block would have no instructions. Instead we emit something
471        // like this:
472        //
473        //     result = zero;
474        //     if in_bounds {
475        //         result = do the load;
476        //     }
477        //     use result;
478
479        // Continue only if the index was in bounds. Otherwise, branch to the
480        // merge block.
481        selection.if_true(self, condition, null_id);
482
483        // The in-bounds path. Perform the access and the load.
484        let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
485
486        selection.finish(self, loaded_value)
487    }
488
489    /// Emit code for bounds checks for an array, vector, or matrix access.
490    ///
491    /// This tries to handle all the critical steps for bounds checks:
492    ///
493    /// - First, select the appropriate bounds check policy for `base`,
494    ///   depending on its address space.
495    ///
496    /// - Next, analyze `index` to see if its value is known at
497    ///   compile time, in which case we can decide statically whether
498    ///   the index is in bounds.
499    ///
500    /// - If the index's value is not known at compile time, emit code to:
501    ///
502    ///     - restrict its value (for [`BoundsCheckPolicy::Restrict`]), or
503    ///
504    ///     - check whether it's in bounds (for
505    ///       [`BoundsCheckPolicy::ReadZeroSkipWrite`]).
506    ///
507    /// Return a [`BoundsCheckResult`] indicating how the index should be
508    /// consumed. See that type's documentation for details.
509    pub(super) fn write_bounds_check(
510        &mut self,
511        base: Handle<crate::Expression>,
512        mut index: GuardedIndex,
513        block: &mut Block,
514    ) -> Result<BoundsCheckResult, Error> {
515        // If the value of `index` is known at compile time, find it now.
516        index.try_resolve_to_constant(&self.ir_function.expressions, self.ir_module);
517
518        let policy = self.writer.bounds_check_policies.choose_policy(
519            base,
520            &self.ir_module.types,
521            self.fun_info,
522        );
523
524        Ok(match policy {
525            BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
526            BoundsCheckPolicy::ReadZeroSkipWrite => {
527                self.write_index_comparison(base, index, block)?
528            }
529            BoundsCheckPolicy::Unchecked => match index {
530                GuardedIndex::Known(value) => BoundsCheckResult::KnownInBounds(value),
531                GuardedIndex::Expression(expr) => BoundsCheckResult::Computed(self.cached[expr]),
532            },
533        })
534    }
535
536    /// Emit code to subscript a vector by value with a computed index.
537    ///
538    /// Return the id of the element value.
539    pub(super) fn write_vector_access(
540        &mut self,
541        expr_handle: Handle<crate::Expression>,
542        base: Handle<crate::Expression>,
543        index: Handle<crate::Expression>,
544        block: &mut Block,
545    ) -> Result<Word, Error> {
546        let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
547
548        let base_id = self.cached[base];
549        let index = GuardedIndex::Expression(index);
550
551        let result_id = match self.write_bounds_check(base, index, block)? {
552            BoundsCheckResult::KnownInBounds(known_index) => {
553                let result_id = self.gen_id();
554                block.body.push(Instruction::composite_extract(
555                    result_type_id,
556                    result_id,
557                    base_id,
558                    &[known_index],
559                ));
560                result_id
561            }
562            BoundsCheckResult::Computed(computed_index_id) => {
563                let result_id = self.gen_id();
564                block.body.push(Instruction::vector_extract_dynamic(
565                    result_type_id,
566                    result_id,
567                    base_id,
568                    computed_index_id,
569                ));
570                result_id
571            }
572            BoundsCheckResult::Conditional {
573                condition_id,
574                index_id,
575            } => {
576                // Run-time bounds checks were required. Emit
577                // conditional load.
578                self.write_conditional_indexed_load(
579                    result_type_id,
580                    condition_id,
581                    block,
582                    |id_gen, block| {
583                        // The in-bounds path. Generate the access.
584                        let element_id = id_gen.next();
585                        block.body.push(Instruction::vector_extract_dynamic(
586                            result_type_id,
587                            element_id,
588                            base_id,
589                            index_id,
590                        ));
591                        element_id
592                    },
593                )
594            }
595        };
596
597        Ok(result_id)
598    }
599}