naga/back/spv/index.rs
1/*!
2Bounds-checking for SPIR-V output.
3*/
4
5use super::{
6 helpers::{global_needs_wrapper, map_storage_class},
7 selection::Selection,
8 Block, BlockContext, Error, IdGenerator, Instruction, Word,
9};
10use crate::{
11 arena::Handle,
12 proc::{index::GuardedIndex, BoundsCheckPolicy},
13};
14
15/// The results of performing a bounds check.
16///
17/// On success, [`write_bounds_check`](BlockContext::write_bounds_check)
18/// returns a value of this type. The caller can assume that the right
19/// policy has been applied, and simply do what the variant says.
20#[derive(Debug)]
21pub(super) enum BoundsCheckResult {
22 /// The index is statically known and in bounds, with the given value.
23 KnownInBounds(u32),
24
25 /// The given instruction computes the index to be used.
26 ///
27 /// When [`BoundsCheckPolicy::Restrict`] is in force, this is a
28 /// clamped version of the index the user supplied.
29 ///
30 /// When [`BoundsCheckPolicy::Unchecked`] is in force, this is
31 /// simply the index the user supplied. This variant indicates
32 /// that we couldn't prove statically that the index was in
33 /// bounds; otherwise we would have returned [`KnownInBounds`].
34 ///
35 /// [`KnownInBounds`]: BoundsCheckResult::KnownInBounds
36 Computed(Word),
37
38 /// The given instruction computes a boolean condition which is true
39 /// if the index is in bounds.
40 ///
41 /// This is returned when [`BoundsCheckPolicy::ReadZeroSkipWrite`]
42 /// is in force.
43 Conditional {
44 /// The access should only be permitted if this value is true.
45 condition_id: Word,
46
47 /// The access should use this index value.
48 index_id: Word,
49 },
50}
51
52/// A value that we either know at translation time, or need to compute at runtime.
53#[derive(Copy, Clone)]
54pub(super) enum MaybeKnown<T> {
55 /// The value is known at shader translation time.
56 Known(T),
57
58 /// The value is computed by the instruction with the given id.
59 Computed(Word),
60}
61
62impl BlockContext<'_> {
63 /// Emit code to compute the length of a run-time array.
64 ///
65 /// Given `array`, an expression referring a runtime-sized array, return the
66 /// instruction id for the array's length.
67 ///
68 /// Runtime-sized arrays may only appear in the values of global
69 /// variables, which must have one of the following Naga types:
70 ///
71 /// 1. A runtime-sized array.
72 /// 2. A struct whose last member is a runtime-sized array.
73 /// 3. A binding array of 2.
74 ///
75 /// Thus, the expression `array` has the form of:
76 ///
77 /// - An optional [`AccessIndex`], for case 2, applied to...
78 /// - An optional [`Access`] or [`AccessIndex`], for case 3, applied to...
79 /// - A [`GlobalVariable`].
80 ///
81 /// The generated SPIR-V takes into account wrapped globals; see
82 /// [`back::spv::GlobalVariable`] for details.
83 ///
84 /// [`GlobalVariable`]: crate::Expression::GlobalVariable
85 /// [`AccessIndex`]: crate::Expression::AccessIndex
86 /// [`Access`]: crate::Expression::Access
87 /// [`base`]: crate::Expression::Access::base
88 /// [`back::spv::GlobalVariable`]: super::GlobalVariable
89 pub(super) fn write_runtime_array_length(
90 &mut self,
91 array: Handle<crate::Expression>,
92 block: &mut Block,
93 ) -> Result<Word, Error> {
94 // The index into the binding array, if any.
95 let binding_array_index_id: Option<Word>;
96
97 // The handle to the Naga IR global we're referring to.
98 let global_handle: Handle<crate::GlobalVariable>;
99
100 // At the Naga type level, if the runtime-sized array is the final member of a
101 // struct, this is that member's index.
102 //
103 // This does not cover wrappers: if this backend wrapped the Naga global's
104 // type in a synthetic SPIR-V struct (see `global_needs_wrapper`), this is
105 // `None`.
106 let opt_last_member_index: Option<u32>;
107
108 // Inspect `array` and decide whether we have a binding array and/or an
109 // enclosing struct.
110 match self.ir_function.expressions[array] {
111 crate::Expression::AccessIndex { base, index } => {
112 match self.ir_function.expressions[base] {
113 crate::Expression::AccessIndex {
114 base: base_outer,
115 index: index_outer,
116 } => match self.ir_function.expressions[base_outer] {
117 // An `AccessIndex` of an `AccessIndex` must be a
118 // binding array holding structs whose last members are
119 // runtime-sized arrays.
120 crate::Expression::GlobalVariable(handle) => {
121 let index_id = self.get_index_constant(index_outer);
122 binding_array_index_id = Some(index_id);
123 global_handle = handle;
124 opt_last_member_index = Some(index);
125 }
126 _ => {
127 return Err(Error::Validation(
128 "array length expression: AccessIndex(AccessIndex(Global))",
129 ))
130 }
131 },
132 crate::Expression::Access {
133 base: base_outer,
134 index: index_outer,
135 } => match self.ir_function.expressions[base_outer] {
136 // Similarly, an `AccessIndex` of an `Access` must be a
137 // binding array holding structs whose last members are
138 // runtime-sized arrays.
139 crate::Expression::GlobalVariable(handle) => {
140 let index_id = self.cached[index_outer];
141 binding_array_index_id = Some(index_id);
142 global_handle = handle;
143 opt_last_member_index = Some(index);
144 }
145 _ => {
146 return Err(Error::Validation(
147 "array length expression: AccessIndex(Access(Global))",
148 ))
149 }
150 },
151 crate::Expression::GlobalVariable(handle) => {
152 // An outer `AccessIndex` applied directly to a
153 // `GlobalVariable`. Since binding arrays can only contain
154 // structs, this must be referring to the last member of a
155 // struct that is a runtime-sized array.
156 binding_array_index_id = None;
157 global_handle = handle;
158 opt_last_member_index = Some(index);
159 }
160 _ => {
161 return Err(Error::Validation(
162 "array length expression: AccessIndex(<unexpected>)",
163 ))
164 }
165 }
166 }
167 crate::Expression::GlobalVariable(handle) => {
168 // A direct reference to a global variable. This must hold the
169 // runtime-sized array directly.
170 binding_array_index_id = None;
171 global_handle = handle;
172 opt_last_member_index = None;
173 }
174 _ => return Err(Error::Validation("array length expression case-4")),
175 };
176
177 // The verifier should have checked this, but make sure the inspection above
178 // agrees with the type about whether a binding array is involved.
179 //
180 // Eventually we do want to support `binding_array<array<T>>`. This check
181 // ensures that whoever relaxes the validator will get an error message from
182 // us, not just bogus SPIR-V.
183 let global = &self.ir_module.global_variables[global_handle];
184 match (
185 &self.ir_module.types[global.ty].inner,
186 binding_array_index_id,
187 ) {
188 (&crate::TypeInner::BindingArray { .. }, Some(_)) => {}
189 (_, None) => {}
190 _ => {
191 return Err(Error::Validation(
192 "array length expression: bad binding array inference",
193 ))
194 }
195 }
196
197 // SPIR-V allows runtime-sized arrays to appear only as the last member of a
198 // struct. Determine this member's index.
199 let gvar = self.writer.global_variables[global_handle].clone();
200 let global = &self.ir_module.global_variables[global_handle];
201 let needs_wrapper = global_needs_wrapper(self.ir_module, global);
202 let (last_member_index, gvar_id) = match (opt_last_member_index, needs_wrapper) {
203 (Some(index), false) => {
204 // At the Naga type level, the runtime-sized array appears as the
205 // final member of a struct, whose index is `index`. We didn't need to
206 // wrap this, since the Naga type meets SPIR-V's requirements already.
207 (index, gvar.access_id)
208 }
209 (None, true) => {
210 // At the Naga type level, the runtime-sized array does not appear
211 // within a struct. We wrapped this in an OpTypeStruct with nothing
212 // else in it, so the index is zero. OpArrayLength wants the pointer
213 // to the wrapper struct, so use `gvar.var_id`.
214 (0, gvar.var_id)
215 }
216 _ => {
217 return Err(Error::Validation(
218 "array length expression: bad SPIR-V wrapper struct inference",
219 ));
220 }
221 };
222
223 let structure_id = match binding_array_index_id {
224 // We are indexing inside a binding array, generate the access op.
225 Some(index_id) => {
226 let element_type_id = match self.ir_module.types[global.ty].inner {
227 crate::TypeInner::BindingArray { base, size: _ } => {
228 let base_id = self.get_handle_type_id(base);
229 let class = map_storage_class(global.space);
230 self.get_pointer_type_id(base_id, class)
231 }
232 _ => return Err(Error::Validation("array length expression case-5")),
233 };
234 let structure_id = self.gen_id();
235 block.body.push(Instruction::access_chain(
236 element_type_id,
237 structure_id,
238 gvar_id,
239 &[index_id],
240 ));
241 structure_id
242 }
243 None => gvar_id,
244 };
245 let length_id = self.gen_id();
246 block.body.push(Instruction::array_length(
247 self.writer.get_u32_type_id(),
248 length_id,
249 structure_id,
250 last_member_index,
251 ));
252
253 Ok(length_id)
254 }
255
256 /// Compute the length of a subscriptable value.
257 ///
258 /// Given `sequence`, an expression referring to some indexable type, return
259 /// its length. The result may either be computed by SPIR-V instructions, or
260 /// known at shader translation time.
261 ///
262 /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
263 /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
264 /// sized, or use a specializable constant as its length.
265 fn write_sequence_length(
266 &mut self,
267 sequence: Handle<crate::Expression>,
268 block: &mut Block,
269 ) -> Result<MaybeKnown<u32>, Error> {
270 let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
271 match sequence_ty.indexable_length_resolved(self.ir_module) {
272 Ok(crate::proc::IndexableLength::Known(known_length)) => {
273 Ok(MaybeKnown::Known(known_length))
274 }
275 Ok(crate::proc::IndexableLength::Dynamic) => {
276 let length_id = self.write_runtime_array_length(sequence, block)?;
277 Ok(MaybeKnown::Computed(length_id))
278 }
279 Err(err) => {
280 log::error!("Sequence length for {sequence:?} failed: {err}");
281 Err(Error::Validation("indexable length"))
282 }
283 }
284 }
285
286 /// Compute the maximum valid index of a subscriptable value.
287 ///
288 /// Given `sequence`, an expression referring to some indexable type, return
289 /// its maximum valid index - one less than its length. The result may
290 /// either be computed, or known at shader translation time.
291 ///
292 /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
293 /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
294 /// sized, or use a specializable constant as its length.
295 fn write_sequence_max_index(
296 &mut self,
297 sequence: Handle<crate::Expression>,
298 block: &mut Block,
299 ) -> Result<MaybeKnown<u32>, Error> {
300 match self.write_sequence_length(sequence, block)? {
301 MaybeKnown::Known(known_length) => {
302 // We should have thrown out all attempts to subscript zero-length
303 // sequences during validation, so the following subtraction should never
304 // underflow.
305 assert!(known_length > 0);
306 // Compute the max index from the length now.
307 Ok(MaybeKnown::Known(known_length - 1))
308 }
309 MaybeKnown::Computed(length_id) => {
310 // Emit code to compute the max index from the length.
311 let const_one_id = self.get_index_constant(1);
312 let max_index_id = self.gen_id();
313 block.body.push(Instruction::binary(
314 spirv::Op::ISub,
315 self.writer.get_u32_type_id(),
316 max_index_id,
317 length_id,
318 const_one_id,
319 ));
320 Ok(MaybeKnown::Computed(max_index_id))
321 }
322 }
323 }
324
325 /// Restrict an index to be in range for a vector, matrix, or array.
326 ///
327 /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
328 /// index is left unchanged. An out-of-bounds index is replaced with some
329 /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
330 /// example, negative indices might be changed to refer to the last element
331 /// of the sequence, not the first, as clamping would do.
332 ///
333 /// Either return the restricted index value, if known, or add instructions
334 /// to `block` to compute it, and return the id of the result. See the
335 /// documentation for `BoundsCheckResult` for details.
336 ///
337 /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
338 /// `Pointer` to any of those, or a `ValuePointer`. An array may be
339 /// fixed-size, dynamically sized, or use a specializable constant as its
340 /// length.
341 pub(super) fn write_restricted_index(
342 &mut self,
343 sequence: Handle<crate::Expression>,
344 index: GuardedIndex,
345 block: &mut Block,
346 ) -> Result<BoundsCheckResult, Error> {
347 let max_index = self.write_sequence_max_index(sequence, block)?;
348
349 // If both are known, we can compute the index to be used
350 // right now.
351 if let (GuardedIndex::Known(index), MaybeKnown::Known(max_index)) = (index, max_index) {
352 let restricted = core::cmp::min(index, max_index);
353 return Ok(BoundsCheckResult::KnownInBounds(restricted));
354 }
355
356 let index_id = match index {
357 GuardedIndex::Known(value) => self.get_index_constant(value),
358 GuardedIndex::Expression(expr) => self.cached[expr],
359 };
360
361 let max_index_id = match max_index {
362 MaybeKnown::Known(value) => self.get_index_constant(value),
363 MaybeKnown::Computed(id) => id,
364 };
365
366 // One or the other of the index or length is dynamic, so emit code for
367 // BoundsCheckPolicy::Restrict.
368 let restricted_index_id = self.gen_id();
369 block.body.push(Instruction::ext_inst(
370 self.writer.gl450_ext_inst_id,
371 spirv::GLOp::UMin,
372 self.writer.get_u32_type_id(),
373 restricted_index_id,
374 &[index_id, max_index_id],
375 ));
376 Ok(BoundsCheckResult::Computed(restricted_index_id))
377 }
378
379 /// Write an index bounds comparison to `block`, if needed.
380 ///
381 /// This is used to implement [`BoundsCheckPolicy::ReadZeroSkipWrite`].
382 ///
383 /// If we're able to determine statically that `index` is in bounds for
384 /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
385 /// value of the index. (In principle, one could know that the index is in
386 /// bounds without knowing its specific value, but in our simple-minded
387 /// situation, we always know it.)
388 ///
389 /// If instead we must generate code to perform the comparison at run time,
390 /// return `Conditional(comparison_id)`, where `comparison_id` is an
391 /// instruction producing a boolean value that is true if `index` is in
392 /// bounds for `sequence`.
393 ///
394 /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
395 /// `Pointer` to any of those, or a `ValuePointer`. An array may be
396 /// fixed-size, dynamically sized, or use a specializable constant as its
397 /// length.
398 fn write_index_comparison(
399 &mut self,
400 sequence: Handle<crate::Expression>,
401 index: GuardedIndex,
402 block: &mut Block,
403 ) -> Result<BoundsCheckResult, Error> {
404 let length = self.write_sequence_length(sequence, block)?;
405
406 // If both are known, we can decide whether the index is in
407 // bounds right now.
408 if let (GuardedIndex::Known(index), MaybeKnown::Known(length)) = (index, length) {
409 if index < length {
410 return Ok(BoundsCheckResult::KnownInBounds(index));
411 }
412
413 // In theory, when `index` is bad, we could return a new
414 // `KnownOutOfBounds` variant here. But it's simpler just to fall
415 // through and let the bounds check take place. The shader is broken
416 // anyway, so it doesn't make sense to invest in emitting the ideal
417 // code for it.
418 }
419
420 let index_id = match index {
421 GuardedIndex::Known(value) => self.get_index_constant(value),
422 GuardedIndex::Expression(expr) => self.cached[expr],
423 };
424
425 let length_id = match length {
426 MaybeKnown::Known(value) => self.get_index_constant(value),
427 MaybeKnown::Computed(id) => id,
428 };
429
430 // Compare the index against the length.
431 let condition_id = self.gen_id();
432 block.body.push(Instruction::binary(
433 spirv::Op::ULessThan,
434 self.writer.get_bool_type_id(),
435 condition_id,
436 index_id,
437 length_id,
438 ));
439
440 // Indicate that we did generate the check.
441 Ok(BoundsCheckResult::Conditional {
442 condition_id,
443 index_id,
444 })
445 }
446
447 /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
448 ///
449 /// Generate code to load a value of `result_type` if `condition` is true,
450 /// and generate a null value of that type if it is false. Call `emit_load`
451 /// to emit the instructions to perform the load. Return the id of the
452 /// merged value of the two branches.
453 pub(super) fn write_conditional_indexed_load<F>(
454 &mut self,
455 result_type: Word,
456 condition: Word,
457 block: &mut Block,
458 emit_load: F,
459 ) -> Word
460 where
461 F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
462 {
463 // For the out-of-bounds case, we produce a zero value.
464 let null_id = self.writer.get_constant_null(result_type);
465
466 let mut selection = Selection::start(block, result_type);
467
468 // As it turns out, we don't actually need a full 'if-then-else'
469 // structure for this: SPIR-V constants are declared up front, so the
470 // 'else' block would have no instructions. Instead we emit something
471 // like this:
472 //
473 // result = zero;
474 // if in_bounds {
475 // result = do the load;
476 // }
477 // use result;
478
479 // Continue only if the index was in bounds. Otherwise, branch to the
480 // merge block.
481 selection.if_true(self, condition, null_id);
482
483 // The in-bounds path. Perform the access and the load.
484 let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
485
486 selection.finish(self, loaded_value)
487 }
488
489 /// Emit code for bounds checks for an array, vector, or matrix access.
490 ///
491 /// This tries to handle all the critical steps for bounds checks:
492 ///
493 /// - First, select the appropriate bounds check policy for `base`,
494 /// depending on its address space.
495 ///
496 /// - Next, analyze `index` to see if its value is known at
497 /// compile time, in which case we can decide statically whether
498 /// the index is in bounds.
499 ///
500 /// - If the index's value is not known at compile time, emit code to:
501 ///
502 /// - restrict its value (for [`BoundsCheckPolicy::Restrict`]), or
503 ///
504 /// - check whether it's in bounds (for
505 /// [`BoundsCheckPolicy::ReadZeroSkipWrite`]).
506 ///
507 /// Return a [`BoundsCheckResult`] indicating how the index should be
508 /// consumed. See that type's documentation for details.
509 pub(super) fn write_bounds_check(
510 &mut self,
511 base: Handle<crate::Expression>,
512 mut index: GuardedIndex,
513 block: &mut Block,
514 ) -> Result<BoundsCheckResult, Error> {
515 // If the value of `index` is known at compile time, find it now.
516 index.try_resolve_to_constant(&self.ir_function.expressions, self.ir_module);
517
518 let policy = self.writer.bounds_check_policies.choose_policy(
519 base,
520 &self.ir_module.types,
521 self.fun_info,
522 );
523
524 Ok(match policy {
525 BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
526 BoundsCheckPolicy::ReadZeroSkipWrite => {
527 self.write_index_comparison(base, index, block)?
528 }
529 BoundsCheckPolicy::Unchecked => match index {
530 GuardedIndex::Known(value) => BoundsCheckResult::KnownInBounds(value),
531 GuardedIndex::Expression(expr) => BoundsCheckResult::Computed(self.cached[expr]),
532 },
533 })
534 }
535
536 /// Emit code to subscript a vector by value with a computed index.
537 ///
538 /// Return the id of the element value.
539 pub(super) fn write_vector_access(
540 &mut self,
541 expr_handle: Handle<crate::Expression>,
542 base: Handle<crate::Expression>,
543 index: Handle<crate::Expression>,
544 block: &mut Block,
545 ) -> Result<Word, Error> {
546 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
547
548 let base_id = self.cached[base];
549 let index = GuardedIndex::Expression(index);
550
551 let result_id = match self.write_bounds_check(base, index, block)? {
552 BoundsCheckResult::KnownInBounds(known_index) => {
553 let result_id = self.gen_id();
554 block.body.push(Instruction::composite_extract(
555 result_type_id,
556 result_id,
557 base_id,
558 &[known_index],
559 ));
560 result_id
561 }
562 BoundsCheckResult::Computed(computed_index_id) => {
563 let result_id = self.gen_id();
564 block.body.push(Instruction::vector_extract_dynamic(
565 result_type_id,
566 result_id,
567 base_id,
568 computed_index_id,
569 ));
570 result_id
571 }
572 BoundsCheckResult::Conditional {
573 condition_id,
574 index_id,
575 } => {
576 // Run-time bounds checks were required. Emit
577 // conditional load.
578 self.write_conditional_indexed_load(
579 result_type_id,
580 condition_id,
581 block,
582 |id_gen, block| {
583 // The in-bounds path. Generate the access.
584 let element_id = id_gen.next();
585 block.body.push(Instruction::vector_extract_dynamic(
586 result_type_id,
587 element_id,
588 base_id,
589 index_id,
590 ));
591 element_id
592 },
593 )
594 }
595 };
596
597 Ok(result_id)
598 }
599}