naga/back/spv/
index.rs

1/*!
2Bounds-checking for SPIR-V output.
3*/
4
5use super::{
6    helpers::{global_needs_wrapper, map_storage_class},
7    selection::Selection,
8    Block, BlockContext, Error, IdGenerator, Instruction, Word,
9};
10use crate::{
11    arena::Handle,
12    proc::{index::GuardedIndex, BoundsCheckPolicy},
13};
14
/// The results of performing a bounds check.
///
/// On success, [`write_bounds_check`](BlockContext::write_bounds_check)
/// returns a value of this type. The caller can assume that the right
/// policy has been applied, and simply do what the variant says.
#[derive(Debug)]
pub(super) enum BoundsCheckResult {
    /// The index is statically known, with the given value.
    ///
    /// Under [`BoundsCheckPolicy::Restrict`], the value has already been
    /// restricted to the valid range. Under
    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`], it has been verified to be
    /// less than the sequence's length. Under [`BoundsCheckPolicy::Unchecked`],
    /// it is simply the constant index the user supplied: no bounds check was
    /// performed, so despite the variant name it is not guaranteed to be in
    /// bounds.
    KnownInBounds(u32),

    /// The given instruction computes the index to be used.
    ///
    /// When [`BoundsCheckPolicy::Restrict`] is in force, this is a
    /// restricted (`UMin`-clamped) version of the index the user supplied.
    /// It is produced when the index value, the sequence length, or both
    /// are only known at run time.
    ///
    /// When [`BoundsCheckPolicy::Unchecked`] is in force, this is
    /// simply the index the user supplied. This variant indicates
    /// that the index's value was not known at translation time;
    /// otherwise we would have returned [`KnownInBounds`].
    ///
    /// [`KnownInBounds`]: BoundsCheckResult::KnownInBounds
    Computed(Word),

    /// The given instruction computes a boolean condition which is true
    /// if the index is in bounds.
    ///
    /// This is returned when [`BoundsCheckPolicy::ReadZeroSkipWrite`]
    /// is in force and the check could not be decided at translation time.
    Conditional {
        /// The access should only be permitted if this value is true.
        condition_id: Word,

        /// The access should use this index value. It is the index exactly
        /// as the user supplied it (not restricted); only `condition_id`
        /// guards its use.
        index_id: Word,
    },
}
51
52/// A value that we either know at translation time, or need to compute at runtime.
53#[derive(Copy, Clone)]
54pub(super) enum MaybeKnown<T> {
55    /// The value is known at shader translation time.
56    Known(T),
57
58    /// The value is computed by the instruction with the given id.
59    Computed(Word),
60}
61
62impl BlockContext<'_> {
63    /// Emit code to compute the length of a run-time array.
64    ///
65    /// Given `array`, an expression referring a runtime-sized array, return the
66    /// instruction id for the array's length.
67    ///
68    /// Runtime-sized arrays may only appear in the values of global
69    /// variables, which must have one of the following Naga types:
70    ///
71    /// 1. A runtime-sized array.
72    /// 2. A struct whose last member is a runtime-sized array.
73    /// 3. A binding array of 2.
74    ///
75    /// Thus, the expression `array` has the form of:
76    ///
77    /// - An optional [`AccessIndex`], for case 2, applied to...
78    /// - An optional [`Access`] or [`AccessIndex`], for case 3, applied to...
79    /// - A [`GlobalVariable`].
80    ///
81    /// The generated SPIR-V takes into account wrapped globals; see
82    /// [`back::spv::GlobalVariable`] for details.
83    ///
84    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
85    /// [`AccessIndex`]: crate::Expression::AccessIndex
86    /// [`Access`]: crate::Expression::Access
87    /// [`base`]: crate::Expression::Access::base
88    /// [`back::spv::GlobalVariable`]: super::GlobalVariable
89    pub(super) fn write_runtime_array_length(
90        &mut self,
91        array: Handle<crate::Expression>,
92        block: &mut Block,
93    ) -> Result<Word, Error> {
94        // The index into the binding array, if any.
95        let binding_array_index_id: Option<Word>;
96
97        // The handle to the Naga IR global we're referring to.
98        let global_handle: Handle<crate::GlobalVariable>;
99
100        // At the Naga type level, if the runtime-sized array is the final member of a
101        // struct, this is that member's index.
102        //
103        // This does not cover wrappers: if this backend wrapped the Naga global's
104        // type in a synthetic SPIR-V struct (see `global_needs_wrapper`), this is
105        // `None`.
106        let opt_last_member_index: Option<u32>;
107
108        // Inspect `array` and decide whether we have a binding array and/or an
109        // enclosing struct.
110        match self.ir_function.expressions[array] {
111            crate::Expression::AccessIndex { base, index } => {
112                match self.ir_function.expressions[base] {
113                    crate::Expression::AccessIndex {
114                        base: base_outer,
115                        index: index_outer,
116                    } => match self.ir_function.expressions[base_outer] {
117                        // An `AccessIndex` of an `AccessIndex` must be a
118                        // binding array holding structs whose last members are
119                        // runtime-sized arrays.
120                        crate::Expression::GlobalVariable(handle) => {
121                            let index_id = self.get_index_constant(index_outer);
122                            binding_array_index_id = Some(index_id);
123                            global_handle = handle;
124                            opt_last_member_index = Some(index);
125                        }
126                        _ => {
127                            return Err(Error::Validation(
128                                "array length expression: AccessIndex(AccessIndex(Global))",
129                            ))
130                        }
131                    },
132                    crate::Expression::Access {
133                        base: base_outer,
134                        index: index_outer,
135                    } => match self.ir_function.expressions[base_outer] {
136                        // Similarly, an `AccessIndex` of an `Access` must be a
137                        // binding array holding structs whose last members are
138                        // runtime-sized arrays.
139                        crate::Expression::GlobalVariable(handle) => {
140                            let index_id = self.cached[index_outer];
141                            binding_array_index_id = Some(index_id);
142                            global_handle = handle;
143                            opt_last_member_index = Some(index);
144                        }
145                        _ => {
146                            return Err(Error::Validation(
147                                "array length expression: AccessIndex(Access(Global))",
148                            ))
149                        }
150                    },
151                    crate::Expression::GlobalVariable(handle) => {
152                        // An outer `AccessIndex` applied directly to a
153                        // `GlobalVariable`. Since binding arrays can only contain
154                        // structs, this must be referring to the last member of a
155                        // struct that is a runtime-sized array.
156                        binding_array_index_id = None;
157                        global_handle = handle;
158                        opt_last_member_index = Some(index);
159                    }
160                    _ => {
161                        return Err(Error::Validation(
162                            "array length expression: AccessIndex(<unexpected>)",
163                        ))
164                    }
165                }
166            }
167            crate::Expression::GlobalVariable(handle) => {
168                // A direct reference to a global variable. This must hold the
169                // runtime-sized array directly.
170                binding_array_index_id = None;
171                global_handle = handle;
172                opt_last_member_index = None;
173            }
174            _ => return Err(Error::Validation("array length expression case-4")),
175        };
176
177        // The verifier should have checked this, but make sure the inspection above
178        // agrees with the type about whether a binding array is involved.
179        //
180        // Eventually we do want to support `binding_array<array<T>>`. This check
181        // ensures that whoever relaxes the validator will get an error message from
182        // us, not just bogus SPIR-V.
183        let global = &self.ir_module.global_variables[global_handle];
184        match (
185            &self.ir_module.types[global.ty].inner,
186            binding_array_index_id,
187        ) {
188            (&crate::TypeInner::BindingArray { .. }, Some(_)) => {}
189            (_, None) => {}
190            _ => {
191                return Err(Error::Validation(
192                    "array length expression: bad binding array inference",
193                ))
194            }
195        }
196
197        // SPIR-V allows runtime-sized arrays to appear only as the last member of a
198        // struct. Determine this member's index.
199        let gvar = self.writer.global_variables[global_handle].clone();
200        let global = &self.ir_module.global_variables[global_handle];
201        let needs_wrapper = global_needs_wrapper(self.ir_module, global);
202        let (last_member_index, gvar_id) = match (opt_last_member_index, needs_wrapper) {
203            (Some(index), false) => {
204                // At the Naga type level, the runtime-sized array appears as the
205                // final member of a struct, whose index is `index`. We didn't need to
206                // wrap this, since the Naga type meets SPIR-V's requirements already.
207                (index, gvar.access_id)
208            }
209            (None, true) => {
210                // At the Naga type level, the runtime-sized array does not appear
211                // within a struct. We wrapped this in an OpTypeStruct with nothing
212                // else in it, so the index is zero. OpArrayLength wants the pointer
213                // to the wrapper struct, so use `gvar.var_id`.
214                (0, gvar.var_id)
215            }
216            _ => {
217                return Err(Error::Validation(
218                    "array length expression: bad SPIR-V wrapper struct inference",
219                ));
220            }
221        };
222
223        let structure_id = match binding_array_index_id {
224            // We are indexing inside a binding array, generate the access op.
225            Some(index_id) => {
226                let element_type_id = match self.ir_module.types[global.ty].inner {
227                    crate::TypeInner::BindingArray { base, size: _ } => {
228                        let base_id = self.get_handle_type_id(base);
229                        let class = map_storage_class(global.space);
230                        self.get_pointer_type_id(base_id, class)
231                    }
232                    _ => return Err(Error::Validation("array length expression case-5")),
233                };
234                let structure_id = self.gen_id();
235                block.body.push(Instruction::access_chain(
236                    element_type_id,
237                    structure_id,
238                    gvar_id,
239                    &[index_id],
240                ));
241                structure_id
242            }
243            None => gvar_id,
244        };
245        let length_id = self.gen_id();
246        block.body.push(Instruction::array_length(
247            self.writer.get_u32_type_id(),
248            length_id,
249            structure_id,
250            last_member_index,
251        ));
252
253        Ok(length_id)
254    }
255
256    /// If `sequence` refers to an unbounded binding array global, return its
257    /// layout size from the SPIR-V binding map.
258    fn binding_array_layout_size(&self, sequence: Handle<crate::Expression>) -> Option<u32> {
259        let global_handle = match self.ir_function.expressions[sequence] {
260            crate::Expression::GlobalVariable(handle) => handle,
261            _ => self.ir_function.originating_global(sequence)?,
262        };
263        let global = &self.ir_module.global_variables[global_handle];
264        let crate::TypeInner::BindingArray { .. } = self.ir_module.types[global.ty].inner else {
265            return None;
266        };
267        let binding = global.binding?;
268        let bind_target = self.writer.resolve_resource_binding(&binding).ok()?;
269        bind_target.binding_array_size
270    }
271
272    /// Compute the length of a subscriptable value.
273    ///
274    /// Given `sequence`, an expression referring to some indexable type, return
275    /// its length. The result may either be computed by SPIR-V instructions, or
276    /// known at shader translation time.
277    ///
278    /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
279    /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
280    /// sized, or use a specializable constant as its length.
281    fn write_sequence_length(
282        &mut self,
283        sequence: Handle<crate::Expression>,
284        block: &mut Block,
285    ) -> Result<MaybeKnown<u32>, Error> {
286        let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
287        match sequence_ty.indexable_length_resolved(self.ir_module) {
288            Ok(crate::proc::IndexableLength::Known(known_length)) => {
289                Ok(MaybeKnown::Known(known_length))
290            }
291            Ok(crate::proc::IndexableLength::Dynamic) => {
292                if let Some(size) = self.binding_array_layout_size(sequence) {
293                    return Ok(MaybeKnown::Known(size));
294                }
295                let length_id = self.write_runtime_array_length(sequence, block)?;
296                Ok(MaybeKnown::Computed(length_id))
297            }
298            Err(err) => {
299                log::error!("Sequence length for {sequence:?} failed: {err}");
300                Err(Error::Validation("indexable length"))
301            }
302        }
303    }
304
305    /// Compute the maximum valid index of a subscriptable value.
306    ///
307    /// Given `sequence`, an expression referring to some indexable type, return
308    /// its maximum valid index - one less than its length. The result may
309    /// either be computed, or known at shader translation time.
310    ///
311    /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
312    /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
313    /// sized, or use a specializable constant as its length.
314    fn write_sequence_max_index(
315        &mut self,
316        sequence: Handle<crate::Expression>,
317        block: &mut Block,
318    ) -> Result<MaybeKnown<u32>, Error> {
319        match self.write_sequence_length(sequence, block)? {
320            MaybeKnown::Known(known_length) => {
321                // We should have thrown out all attempts to subscript zero-length
322                // sequences during validation, so the following subtraction should never
323                // underflow.
324                assert!(known_length > 0);
325                // Compute the max index from the length now.
326                Ok(MaybeKnown::Known(known_length - 1))
327            }
328            MaybeKnown::Computed(length_id) => {
329                // Emit code to compute the max index from the length.
330                let const_one_id = self.get_index_constant(1);
331                let max_index_id = self.gen_id();
332                block.body.push(Instruction::binary(
333                    spirv::Op::ISub,
334                    self.writer.get_u32_type_id(),
335                    max_index_id,
336                    length_id,
337                    const_one_id,
338                ));
339                Ok(MaybeKnown::Computed(max_index_id))
340            }
341        }
342    }
343
344    /// Restrict an index to be in range for a vector, matrix, or array.
345    ///
346    /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
347    /// index is left unchanged. An out-of-bounds index is replaced with some
348    /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
349    /// example, negative indices might be changed to refer to the last element
350    /// of the sequence, not the first, as clamping would do.
351    ///
352    /// Either return the restricted index value, if known, or add instructions
353    /// to `block` to compute it, and return the id of the result. See the
354    /// documentation for `BoundsCheckResult` for details.
355    ///
356    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
357    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
358    /// fixed-size, dynamically sized, or use a specializable constant as its
359    /// length.
360    pub(super) fn write_restricted_index(
361        &mut self,
362        sequence: Handle<crate::Expression>,
363        index: GuardedIndex,
364        block: &mut Block,
365    ) -> Result<BoundsCheckResult, Error> {
366        let max_index = self.write_sequence_max_index(sequence, block)?;
367
368        // If both are known, we can compute the index to be used
369        // right now.
370        if let (GuardedIndex::Known(index), MaybeKnown::Known(max_index)) = (index, max_index) {
371            let restricted = core::cmp::min(index, max_index);
372            return Ok(BoundsCheckResult::KnownInBounds(restricted));
373        }
374
375        let index_id = match index {
376            GuardedIndex::Known(value) => self.get_index_constant(value),
377            GuardedIndex::Expression(expr) => self.cached[expr],
378        };
379
380        let max_index_id = match max_index {
381            MaybeKnown::Known(value) => self.get_index_constant(value),
382            MaybeKnown::Computed(id) => id,
383        };
384
385        // One or the other of the index or length is dynamic, so emit code for
386        // BoundsCheckPolicy::Restrict.
387        let restricted_index_id = self.gen_id();
388        block.body.push(Instruction::ext_inst_gl_op(
389            self.writer.gl450_ext_inst_id,
390            spirv::GlslStd450Op::UMin,
391            self.writer.get_u32_type_id(),
392            restricted_index_id,
393            &[index_id, max_index_id],
394        ));
395        Ok(BoundsCheckResult::Computed(restricted_index_id))
396    }
397
398    /// Write an index bounds comparison to `block`, if needed.
399    ///
400    /// This is used to implement [`BoundsCheckPolicy::ReadZeroSkipWrite`].
401    ///
402    /// If we're able to determine statically that `index` is in bounds for
403    /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
404    /// value of the index. (In principle, one could know that the index is in
405    /// bounds without knowing its specific value, but in our simple-minded
406    /// situation, we always know it.)
407    ///
408    /// If instead we must generate code to perform the comparison at run time,
409    /// return `Conditional(comparison_id)`, where `comparison_id` is an
410    /// instruction producing a boolean value that is true if `index` is in
411    /// bounds for `sequence`.
412    ///
413    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
414    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
415    /// fixed-size, dynamically sized, or use a specializable constant as its
416    /// length.
417    fn write_index_comparison(
418        &mut self,
419        sequence: Handle<crate::Expression>,
420        index: GuardedIndex,
421        block: &mut Block,
422    ) -> Result<BoundsCheckResult, Error> {
423        let length = self.write_sequence_length(sequence, block)?;
424
425        // If both are known, we can decide whether the index is in
426        // bounds right now.
427        if let (GuardedIndex::Known(index), MaybeKnown::Known(length)) = (index, length) {
428            if index < length {
429                return Ok(BoundsCheckResult::KnownInBounds(index));
430            }
431
432            // In theory, when `index` is bad, we could return a new
433            // `KnownOutOfBounds` variant here. But it's simpler just to fall
434            // through and let the bounds check take place. The shader is broken
435            // anyway, so it doesn't make sense to invest in emitting the ideal
436            // code for it.
437        }
438
439        let index_id = match index {
440            GuardedIndex::Known(value) => self.get_index_constant(value),
441            GuardedIndex::Expression(expr) => self.cached[expr],
442        };
443
444        let length_id = match length {
445            MaybeKnown::Known(value) => self.get_index_constant(value),
446            MaybeKnown::Computed(id) => id,
447        };
448
449        // Compare the index against the length.
450        let condition_id = self.gen_id();
451        block.body.push(Instruction::binary(
452            spirv::Op::ULessThan,
453            self.writer.get_bool_type_id(),
454            condition_id,
455            index_id,
456            length_id,
457        ));
458
459        // Indicate that we did generate the check.
460        Ok(BoundsCheckResult::Conditional {
461            condition_id,
462            index_id,
463        })
464    }
465
466    /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
467    ///
468    /// Generate code to load a value of `result_type` if `condition` is true,
469    /// and generate a null value of that type if it is false. Call `emit_load`
470    /// to emit the instructions to perform the load. Return the id of the
471    /// merged value of the two branches.
472    pub(super) fn write_conditional_indexed_load<F>(
473        &mut self,
474        result_type: Word,
475        condition: Word,
476        block: &mut Block,
477        emit_load: F,
478    ) -> Word
479    where
480        F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
481    {
482        // For the out-of-bounds case, we produce a zero value.
483        let null_id = self.writer.get_constant_null(result_type);
484
485        let mut selection = Selection::start(block, result_type);
486
487        // As it turns out, we don't actually need a full 'if-then-else'
488        // structure for this: SPIR-V constants are declared up front, so the
489        // 'else' block would have no instructions. Instead we emit something
490        // like this:
491        //
492        //     result = zero;
493        //     if in_bounds {
494        //         result = do the load;
495        //     }
496        //     use result;
497
498        // Continue only if the index was in bounds. Otherwise, branch to the
499        // merge block.
500        selection.if_true(self, condition, null_id);
501
502        // The in-bounds path. Perform the access and the load.
503        let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
504
505        selection.finish(self, loaded_value)
506    }
507
508    /// Emit code for bounds checks for an array, vector, or matrix access.
509    ///
510    /// This tries to handle all the critical steps for bounds checks:
511    ///
512    /// - First, select the appropriate bounds check policy for `base`,
513    ///   depending on its address space.
514    ///
515    /// - Next, analyze `index` to see if its value is known at
516    ///   compile time, in which case we can decide statically whether
517    ///   the index is in bounds.
518    ///
519    /// - If the index's value is not known at compile time, emit code to:
520    ///
521    ///     - restrict its value (for [`BoundsCheckPolicy::Restrict`]), or
522    ///
523    ///     - check whether it's in bounds (for
524    ///       [`BoundsCheckPolicy::ReadZeroSkipWrite`]).
525    ///
526    /// Return a [`BoundsCheckResult`] indicating how the index should be
527    /// consumed. See that type's documentation for details.
528    pub(super) fn write_bounds_check(
529        &mut self,
530        base: Handle<crate::Expression>,
531        mut index: GuardedIndex,
532        block: &mut Block,
533    ) -> Result<BoundsCheckResult, Error> {
534        // If the value of `index` is known at compile time, find it now.
535        index.try_resolve_to_constant(&self.ir_function.expressions, self.ir_module);
536
537        let policy = self.writer.bounds_check_policies.choose_policy(
538            base,
539            &self.ir_module.types,
540            self.fun_info,
541        );
542
543        Ok(match policy {
544            BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
545            BoundsCheckPolicy::ReadZeroSkipWrite => {
546                self.write_index_comparison(base, index, block)?
547            }
548            BoundsCheckPolicy::Unchecked => match index {
549                GuardedIndex::Known(value) => BoundsCheckResult::KnownInBounds(value),
550                GuardedIndex::Expression(expr) => BoundsCheckResult::Computed(self.cached[expr]),
551            },
552        })
553    }
554
555    /// Emit code to subscript a vector by value with a computed index.
556    ///
557    /// Return the id of the element value.
558    ///
559    /// If `base_id_override` is provided, it is used as the vector expression
560    /// to be subscripted into, rather than the cached value of `base`.
561    pub(super) fn write_vector_access(
562        &mut self,
563        result_type_id: Word,
564        base: Handle<crate::Expression>,
565        base_id_override: Option<Word>,
566        index: GuardedIndex,
567        block: &mut Block,
568    ) -> Result<Word, Error> {
569        let base_id = base_id_override.unwrap_or_else(|| self.cached[base]);
570
571        let result_id = match self.write_bounds_check(base, index, block)? {
572            BoundsCheckResult::KnownInBounds(known_index) => {
573                let result_id = self.gen_id();
574                block.body.push(Instruction::composite_extract(
575                    result_type_id,
576                    result_id,
577                    base_id,
578                    &[known_index],
579                ));
580                result_id
581            }
582            BoundsCheckResult::Computed(computed_index_id) => {
583                let result_id = self.gen_id();
584                block.body.push(Instruction::vector_extract_dynamic(
585                    result_type_id,
586                    result_id,
587                    base_id,
588                    computed_index_id,
589                ));
590                result_id
591            }
592            BoundsCheckResult::Conditional {
593                condition_id,
594                index_id,
595            } => {
596                // Run-time bounds checks were required. Emit
597                // conditional load.
598                self.write_conditional_indexed_load(
599                    result_type_id,
600                    condition_id,
601                    block,
602                    |id_gen, block| {
603                        // The in-bounds path. Generate the access.
604                        let element_id = id_gen.next();
605                        block.body.push(Instruction::vector_extract_dynamic(
606                            result_type_id,
607                            element_id,
608                            base_id,
609                            index_id,
610                        ));
611                        element_id
612                    },
613                )
614            }
615        };
616
617        Ok(result_id)
618    }
619}