naga/back/spv/
image.rs

1/*!
2Generating SPIR-V for image operations.
3*/
4
5use spirv::Word;
6
7use super::{
8    selection::{MergeTuple, Selection},
9    Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
10};
11use crate::arena::Handle;
12
13/// Information about a vector of coordinates.
14///
15/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch`
16/// supply the array index for arrayed images as an additional component at
17/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry
18/// the array index as a separate field.
19///
20/// In the process of generating code to compute the combined vector, we also
21/// produce SPIR-V types and vector lengths that are useful elsewhere. This
22/// struct gathers that information into one place, with standard names.
23struct ImageCoordinates {
24    /// The SPIR-V id of the combined coordinate/index vector value.
25    ///
26    /// Note: when indexing a non-arrayed 1D image, this will be a scalar.
27    value_id: Word,
28
29    /// The SPIR-V id of the type of `value`.
30    type_id: Word,
31
32    /// The number of components in `value`, if it is a vector, or `None` if it
33    /// is a scalar.
34    size: Option<crate::VectorSize>,
35}
36
37/// A trait for image access (load or store) code generators.
38///
39/// Types implementing this trait hold information about an `ImageStore` or
40/// `ImageLoad` operation that is not affected by the bounds check policy. The
41/// `generate` method emits code for the access, given the results of bounds
42/// checking.
43///
44/// The [`image`] bounds checks policy affects access coordinates, level of
45/// detail, and sample index, but never the image id, result type (if any), or
46/// the specific SPIR-V instruction used. Types that implement this trait gather
47/// together the latter category, so we don't have to plumb them through the
48/// bounds-checking code.
49///
50/// [`image`]: crate::proc::BoundsCheckPolicies::index
51trait Access {
52    /// The Rust type that represents SPIR-V values and types for this access.
53    ///
54    /// For operations like loads, this is `Word`. For operations like stores,
55    /// this is `()`.
56    ///
57    /// For `ReadZeroSkipWrite`, this will be the type of the selection
58    /// construct that performs the bounds checks, so it must implement
59    /// `MergeTuple`.
60    type Output: MergeTuple + Copy + Clone;
61
62    /// Write an image access to `block`.
63    ///
64    /// Access the texel at `coordinates_id`. The optional `level_id` indicates
65    /// the level of detail, and `sample_id` is the index of the sample to
66    /// access in a multisampled texel.
67    ///
68    /// This method assumes that `coordinates_id` has already had the image array
69    /// index, if any, folded in, as done by `write_image_coordinates`.
70    ///
71    /// Return the value id produced by the instruction, if any.
72    ///
73    /// Use `id_gen` to generate SPIR-V ids as necessary.
74    fn generate(
75        &self,
76        id_gen: &mut IdGenerator,
77        coordinates_id: Word,
78        level_id: Option<Word>,
79        sample_id: Option<Word>,
80        block: &mut Block,
81    ) -> Self::Output;
82
83    /// Return the SPIR-V type of the value produced by the code written by
84    /// `generate`. If the access does not produce a value, `Self::Output`
85    /// should be `()`.
86    fn result_type(&self) -> Self::Output;
87
88    /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds
89    /// access under the `ReadZeroSkipWrite` policy. If the access does not
90    /// produce a value, `Self::Output` should be `()`.
91    fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output;
92}
93
94/// Texel access information for an [`ImageLoad`] expression.
95///
96/// [`ImageLoad`]: crate::Expression::ImageLoad
97struct Load {
98    /// The specific opcode we'll use to perform the fetch. Storage images
99    /// require `OpImageRead`, while sampled images require `OpImageFetch`.
100    opcode: spirv::Op,
101
102    /// The type id produced by the actual image access instruction.
103    type_id: Word,
104
105    /// The id of the image being accessed.
106    image_id: Word,
107}
108
109impl Load {
110    fn from_image_expr(
111        ctx: &mut BlockContext<'_>,
112        image_id: Word,
113        image_class: crate::ImageClass,
114        result_type_id: Word,
115    ) -> Result<Load, Error> {
116        let opcode = match image_class {
117            crate::ImageClass::Storage { .. } => spirv::Op::ImageRead,
118            crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => {
119                spirv::Op::ImageFetch
120            }
121            crate::ImageClass::External => unimplemented!(),
122        };
123
124        // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32>
125        // values. Most of the time, we can just use `result_type_id` for
126        // this. The exception is that `Expression::ImageLoad` from a depth
127        // image produces a scalar `f32`, so in that case we need to find
128        // the right SPIR-V type for the access instruction here.
129        let type_id = match image_class {
130            crate::ImageClass::Depth { .. } => ctx.get_numeric_type_id(NumericType::Vector {
131                size: crate::VectorSize::Quad,
132                scalar: crate::Scalar::F32,
133            }),
134            _ => result_type_id,
135        };
136
137        Ok(Load {
138            opcode,
139            type_id,
140            image_id,
141        })
142    }
143}
144
145impl Access for Load {
146    type Output = Word;
147
148    /// Write an instruction to access a given texel of this image.
149    fn generate(
150        &self,
151        id_gen: &mut IdGenerator,
152        coordinates_id: Word,
153        level_id: Option<Word>,
154        sample_id: Option<Word>,
155        block: &mut Block,
156    ) -> Word {
157        let texel_id = id_gen.next();
158        let mut instruction = Instruction::image_fetch_or_read(
159            self.opcode,
160            self.type_id,
161            texel_id,
162            self.image_id,
163            coordinates_id,
164        );
165
166        match (level_id, sample_id) {
167            (None, None) => {}
168            (Some(level_id), None) => {
169                instruction.add_operand(spirv::ImageOperands::LOD.bits());
170                instruction.add_operand(level_id);
171            }
172            (None, Some(sample_id)) => {
173                instruction.add_operand(spirv::ImageOperands::SAMPLE.bits());
174                instruction.add_operand(sample_id);
175            }
176            // There's no such thing as a multi-sampled mipmap.
177            (Some(_), Some(_)) => unreachable!(),
178        }
179
180        block.body.push(instruction);
181
182        texel_id
183    }
184
185    fn result_type(&self) -> Word {
186        self.type_id
187    }
188
189    fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word {
190        ctx.writer.get_constant_null(self.type_id)
191    }
192}
193
194/// Texel access information for a [`Store`] statement.
195///
196/// [`Store`]: crate::Statement::Store
197struct Store {
198    /// The id of the image being written to.
199    image_id: Word,
200
201    /// The value we're going to write to the texel.
202    value_id: Word,
203}
204
205impl Access for Store {
206    /// Stores don't generate any value.
207    type Output = ();
208
209    fn generate(
210        &self,
211        _id_gen: &mut IdGenerator,
212        coordinates_id: Word,
213        _level_id: Option<Word>,
214        _sample_id: Option<Word>,
215        block: &mut Block,
216    ) {
217        block.body.push(Instruction::image_write(
218            self.image_id,
219            coordinates_id,
220            self.value_id,
221        ));
222    }
223
224    /// Stores don't generate any value, so this just returns `()`.
225    fn result_type(&self) {}
226
227    /// Stores don't generate any value, so this just returns `()`.
228    fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {}
229}
230
231impl BlockContext<'_> {
232    /// Extend image coordinates with an array index, if necessary.
233    ///
234    /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array
235    /// index as a separate operand from the coordinates, SPIR-V image access
236    /// instructions include the array index in the `coordinates` operand. This
237    /// function builds a SPIR-V coordinate vector from a Naga coordinate vector
238    /// and array index, if one is supplied, and returns a `ImageCoordinates`
239    /// struct describing what it built.
240    ///
241    /// If `array_index` is `Some(expr)`, then this function constructs a new
242    /// vector that is `coordinates` with `array_index` concatenated onto the
243    /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on.
244    ///
245    /// If `array_index` is `None`, then the return value uses `coordinates`
246    /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be
247    /// a scalar value.
248    ///
249    /// If needed, this function generates code to convert the array index,
250    /// always an integer scalar, to match the component type of `coordinates`.
251    /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and
252    /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample`
253    /// and SPIR-V's `OpImageSample...` instructions all take floating-point
254    /// coordinate vectors.
255    ///
256    /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad
257    /// [`ImageSample`]: crate::Expression::ImageSample
258    fn write_image_coordinates(
259        &mut self,
260        coordinates: Handle<crate::Expression>,
261        array_index: Option<Handle<crate::Expression>>,
262        block: &mut Block,
263    ) -> Result<ImageCoordinates, Error> {
264        use crate::TypeInner as Ti;
265        use crate::VectorSize as Vs;
266
267        let coordinates_id = self.cached[coordinates];
268        let ty = &self.fun_info[coordinates].ty;
269        let inner_ty = ty.inner_with(&self.ir_module.types);
270
271        // If there's no array index, the image coordinates are exactly the
272        // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
273        let array_index = match array_index {
274            None => {
275                let value_id = coordinates_id;
276                let type_id = self.get_expression_type_id(ty);
277                let size = match *inner_ty {
278                    Ti::Scalar { .. } => None,
279                    Ti::Vector { size, .. } => Some(size),
280                    _ => return Err(Error::Validation("coordinate type")),
281                };
282                return Ok(ImageCoordinates {
283                    value_id,
284                    type_id,
285                    size,
286                });
287            }
288            Some(ix) => ix,
289        };
290
291        // Find the component type of `coordinates`, and figure out the size the
292        // combined coordinate vector will have.
293        let (component_scalar, size) = match *inner_ty {
294            Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Vs::Bi),
295            Ti::Vector {
296                scalar: scalar @ crate::Scalar { width: 4, .. },
297                size: Vs::Bi,
298            } => (scalar, Vs::Tri),
299            Ti::Vector {
300                scalar: scalar @ crate::Scalar { width: 4, .. },
301                size: Vs::Tri,
302            } => (scalar, Vs::Quad),
303            Ti::Vector { size: Vs::Quad, .. } => {
304                return Err(Error::Validation("extending vec4 coordinate"));
305            }
306            ref other => {
307                log::error!("wrong coordinate type {other:?}");
308                return Err(Error::Validation("coordinate type"));
309            }
310        };
311
312        // Convert the index to the coordinate component type, if necessary.
313        let array_index_id = self.cached[array_index];
314        let ty = &self.fun_info[array_index].ty;
315        let inner_ty = ty.inner_with(&self.ir_module.types);
316        let array_index_scalar = match *inner_ty {
317            Ti::Scalar(
318                scalar @ crate::Scalar {
319                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
320                    width: 4,
321                },
322            ) => scalar,
323            _ => unreachable!("we only allow i32 and u32"),
324        };
325        let cast = match (component_scalar.kind, array_index_scalar.kind) {
326            (crate::ScalarKind::Sint, crate::ScalarKind::Sint)
327            | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None,
328            (crate::ScalarKind::Sint, crate::ScalarKind::Uint)
329            | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast),
330            (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF),
331            (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF),
332            (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"),
333            (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => {
334                unreachable!("we don't allow bool or float for array index")
335            }
336            (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _)
337            | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => {
338                unreachable!("abstract types should never reach backends")
339            }
340        };
341        let reconciled_array_index_id = if let Some(cast) = cast {
342            let component_ty_id = self.get_numeric_type_id(NumericType::Scalar(component_scalar));
343            let reconciled_id = self.gen_id();
344            block.body.push(Instruction::unary(
345                cast,
346                component_ty_id,
347                reconciled_id,
348                array_index_id,
349            ));
350            reconciled_id
351        } else {
352            array_index_id
353        };
354
355        // Find the SPIR-V type for the combined coordinates/index vector.
356        let type_id = self.get_numeric_type_id(NumericType::Vector {
357            size,
358            scalar: component_scalar,
359        });
360
361        // Schmear the coordinates and index together.
362        let value_id = self.gen_id();
363        block.body.push(Instruction::composite_construct(
364            type_id,
365            value_id,
366            &[coordinates_id, reconciled_array_index_id],
367        ));
368        Ok(ImageCoordinates {
369            value_id,
370            type_id,
371            size: Some(size),
372        })
373    }
374
375    pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
376        let id = match self.ir_function.expressions[expr_handle] {
377            crate::Expression::GlobalVariable(handle) => {
378                self.writer.global_variables[handle].handle_id
379            }
380            crate::Expression::FunctionArgument(i) => {
381                self.function.parameters[i as usize].handle_id
382            }
383            crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
384                self.cached[expr_handle]
385            }
386            ref other => unreachable!("Unexpected image expression {:?}", other),
387        };
388
389        if id == 0 {
390            unreachable!(
391                "Image expression {:?} doesn't have a handle ID",
392                expr_handle
393            );
394        }
395
396        id
397    }
398
399    /// Generate a vector or scalar 'one' for arithmetic on `coordinates`.
400    ///
401    /// If `coordinates` is a scalar, return a scalar one. Otherwise, return
402    /// a vector of ones.
403    fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> {
404        let one = self.get_scope_constant(1);
405        match coordinates.size {
406            None => Ok(one),
407            Some(vector_size) => {
408                let ones = [one; 4];
409                let id = self.gen_id();
410                Instruction::constant_composite(
411                    coordinates.type_id,
412                    id,
413                    &ones[..vector_size as usize],
414                )
415                .to_words(&mut self.writer.logical_layout.declarations);
416                Ok(id)
417            }
418        }
419    }
420
421    /// Generate code to restrict `input` to fall between zero and one less than
422    /// `size_id`.
423    ///
424    /// Both must be 32-bit scalar integer values, whose type is given by
425    /// `type_id`. The computed value is also of type `type_id`.
426    fn restrict_scalar(
427        &mut self,
428        type_id: Word,
429        input_id: Word,
430        size_id: Word,
431        block: &mut Block,
432    ) -> Result<Word, Error> {
433        let i32_one_id = self.get_scope_constant(1);
434
435        // Subtract one from `size` to get the largest valid value.
436        let limit_id = self.gen_id();
437        block.body.push(Instruction::binary(
438            spirv::Op::ISub,
439            type_id,
440            limit_id,
441            size_id,
442            i32_one_id,
443        ));
444
445        // Use an unsigned minimum, to handle both positive out-of-range values
446        // and negative values in a single instruction: negative values of
447        // `input_id` get treated as very large positive values.
448        let restricted_id = self.gen_id();
449        block.body.push(Instruction::ext_inst(
450            self.writer.gl450_ext_inst_id,
451            spirv::GLOp::UMin,
452            type_id,
453            restricted_id,
454            &[input_id, limit_id],
455        ));
456
457        Ok(restricted_id)
458    }
459
460    /// Write instructions to query the size of an image.
461    ///
462    /// This takes care of selecting the right instruction depending on whether
463    /// a level of detail parameter is present.
464    fn write_coordinate_bounds(
465        &mut self,
466        type_id: Word,
467        image_id: Word,
468        level_id: Option<Word>,
469        block: &mut Block,
470    ) -> Word {
471        let coordinate_bounds_id = self.gen_id();
472        match level_id {
473            Some(level_id) => {
474                // A level of detail was provided, so fetch the image size for
475                // that level.
476                let mut inst = Instruction::image_query(
477                    spirv::Op::ImageQuerySizeLod,
478                    type_id,
479                    coordinate_bounds_id,
480                    image_id,
481                );
482                inst.add_operand(level_id);
483                block.body.push(inst);
484            }
485            _ => {
486                // No level of detail was given.
487                block.body.push(Instruction::image_query(
488                    spirv::Op::ImageQuerySize,
489                    type_id,
490                    coordinate_bounds_id,
491                    image_id,
492                ));
493            }
494        }
495
496        coordinate_bounds_id
497    }
498
499    /// Write code to restrict coordinates for an image reference.
500    ///
501    /// First, clamp the level of detail or sample index to fall within bounds.
502    /// Then, obtain the image size, possibly using the clamped level of detail.
503    /// Finally, use an unsigned minimum instruction to force all coordinates
504    /// into range.
505    ///
506    /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate
507    /// vector (including the array index, if any), `LEVEL` is an optional level
508    /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to
509    /// be in-bounds for `image_id`.
510    ///
511    /// The result is usually a vector, but it is a scalar when indexing
512    /// non-arrayed 1D images.
513    fn write_restricted_coordinates(
514        &mut self,
515        image_id: Word,
516        coordinates: ImageCoordinates,
517        level_id: Option<Word>,
518        sample_id: Option<Word>,
519        block: &mut Block,
520    ) -> Result<(Word, Option<Word>, Option<Word>), Error> {
521        self.writer.require_any(
522            "the `Restrict` image bounds check policy",
523            &[spirv::Capability::ImageQuery],
524        )?;
525
526        let i32_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
527
528        // If `level` is `Some`, clamp it to fall within bounds. This must
529        // happen first, because we'll use it to query the image size for
530        // clamping the actual coordinates.
531        let level_id = level_id
532            .map(|level_id| {
533                // Find the number of mipmap levels in this image.
534                let num_levels_id = self.gen_id();
535                block.body.push(Instruction::image_query(
536                    spirv::Op::ImageQueryLevels,
537                    i32_type_id,
538                    num_levels_id,
539                    image_id,
540                ));
541
542                self.restrict_scalar(i32_type_id, level_id, num_levels_id, block)
543            })
544            .transpose()?;
545
546        // If `sample_id` is `Some`, clamp it to fall within bounds.
547        let sample_id = sample_id
548            .map(|sample_id| {
549                // Find the number of samples per texel.
550                let num_samples_id = self.gen_id();
551                block.body.push(Instruction::image_query(
552                    spirv::Op::ImageQuerySamples,
553                    i32_type_id,
554                    num_samples_id,
555                    image_id,
556                ));
557
558                self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block)
559            })
560            .transpose()?;
561
562        // Obtain the image bounds, including the array element count.
563        let coordinate_bounds_id =
564            self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block);
565
566        // Compute maximum valid values from the bounds.
567        let ones = self.write_coordinate_one(&coordinates)?;
568        let coordinate_limit_id = self.gen_id();
569        block.body.push(Instruction::binary(
570            spirv::Op::ISub,
571            coordinates.type_id,
572            coordinate_limit_id,
573            coordinate_bounds_id,
574            ones,
575        ));
576
577        // Restrict the coordinates to fall within those bounds.
578        //
579        // Use an unsigned minimum, to handle both positive out-of-range values
580        // and negative values in a single instruction: negative values of
581        // `coordinates` get treated as very large positive values.
582        let restricted_coordinates_id = self.gen_id();
583        block.body.push(Instruction::ext_inst(
584            self.writer.gl450_ext_inst_id,
585            spirv::GLOp::UMin,
586            coordinates.type_id,
587            restricted_coordinates_id,
588            &[coordinates.value_id, coordinate_limit_id],
589        ));
590
591        Ok((restricted_coordinates_id, level_id, sample_id))
592    }
593
594    fn write_conditional_image_access<A: Access>(
595        &mut self,
596        image_id: Word,
597        coordinates: ImageCoordinates,
598        level_id: Option<Word>,
599        sample_id: Option<Word>,
600        block: &mut Block,
601        access: &A,
602    ) -> Result<A::Output, Error> {
603        self.writer.require_any(
604            "the `ReadZeroSkipWrite` image bounds check policy",
605            &[spirv::Capability::ImageQuery],
606        )?;
607
608        let bool_type_id = self.writer.get_bool_type_id();
609        let i32_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
610
611        let null_id = access.out_of_bounds_value(self);
612
613        let mut selection = Selection::start(block, access.result_type());
614
615        // If `level_id` is `Some`, check whether it is within bounds. This must
616        // happen first, because we'll be supplying this as an argument when we
617        // query the image size.
618        if let Some(level_id) = level_id {
619            // Find the number of mipmap levels in this image.
620            let num_levels_id = self.gen_id();
621            selection.block().body.push(Instruction::image_query(
622                spirv::Op::ImageQueryLevels,
623                i32_type_id,
624                num_levels_id,
625                image_id,
626            ));
627
628            let lod_cond_id = self.gen_id();
629            selection.block().body.push(Instruction::binary(
630                spirv::Op::ULessThan,
631                bool_type_id,
632                lod_cond_id,
633                level_id,
634                num_levels_id,
635            ));
636
637            selection.if_true(self, lod_cond_id, null_id);
638        }
639
640        // If `sample_id` is `Some`, check whether it is in bounds.
641        if let Some(sample_id) = sample_id {
642            // Find the number of samples per texel.
643            let num_samples_id = self.gen_id();
644            selection.block().body.push(Instruction::image_query(
645                spirv::Op::ImageQuerySamples,
646                i32_type_id,
647                num_samples_id,
648                image_id,
649            ));
650
651            let samples_cond_id = self.gen_id();
652            selection.block().body.push(Instruction::binary(
653                spirv::Op::ULessThan,
654                bool_type_id,
655                samples_cond_id,
656                sample_id,
657                num_samples_id,
658            ));
659
660            selection.if_true(self, samples_cond_id, null_id);
661        }
662
663        // Obtain the image bounds, including any array element count.
664        let coordinate_bounds_id = self.write_coordinate_bounds(
665            coordinates.type_id,
666            image_id,
667            level_id,
668            selection.block(),
669        );
670
671        // Compare the coordinates against the bounds.
672        let coords_numeric_type = match coordinates.size {
673            Some(size) => NumericType::Vector {
674                size,
675                scalar: crate::Scalar::BOOL,
676            },
677            None => NumericType::Scalar(crate::Scalar::BOOL),
678        };
679        let coords_bool_type_id = self.get_numeric_type_id(coords_numeric_type);
680        let coords_conds_id = self.gen_id();
681        selection.block().body.push(Instruction::binary(
682            spirv::Op::ULessThan,
683            coords_bool_type_id,
684            coords_conds_id,
685            coordinates.value_id,
686            coordinate_bounds_id,
687        ));
688
689        // If the comparison above was a vector comparison, then we need to
690        // check that all components of the comparison are true.
691        let coords_cond_id = if coords_bool_type_id != bool_type_id {
692            let id = self.gen_id();
693            selection.block().body.push(Instruction::relational(
694                spirv::Op::All,
695                bool_type_id,
696                id,
697                coords_conds_id,
698            ));
699            id
700        } else {
701            coords_conds_id
702        };
703
704        selection.if_true(self, coords_cond_id, null_id);
705
706        // All conditions are met. We can carry out the access.
707        let texel_id = access.generate(
708            &mut self.writer.id_gen,
709            coordinates.value_id,
710            level_id,
711            sample_id,
712            selection.block(),
713        );
714
715        // This, then, is the value of the 'true' branch.
716        Ok(selection.finish(self, texel_id))
717    }
718
719    /// Generate code for an `ImageLoad` expression.
720    ///
721    /// The arguments are the components of an `Expression::ImageLoad` variant.
722    #[allow(clippy::too_many_arguments)]
723    pub(super) fn write_image_load(
724        &mut self,
725        result_type_id: Word,
726        image: Handle<crate::Expression>,
727        coordinate: Handle<crate::Expression>,
728        array_index: Option<Handle<crate::Expression>>,
729        level: Option<Handle<crate::Expression>>,
730        sample: Option<Handle<crate::Expression>>,
731        block: &mut Block,
732    ) -> Result<Word, Error> {
733        let image_id = self.get_handle_id(image);
734        let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types);
735        let image_class = match *image_type {
736            crate::TypeInner::Image { class, .. } => class,
737            _ => return Err(Error::Validation("image type")),
738        };
739
740        let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?;
741        let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
742
743        let level_id = level.map(|expr| self.cached[expr]);
744        let sample_id = sample.map(|expr| self.cached[expr]);
745
746        // Perform the access, according to the bounds check policy.
747        let access_id = match self.writer.bounds_check_policies.image_load {
748            crate::proc::BoundsCheckPolicy::Restrict => {
749                let (coords, level_id, sample_id) = self.write_restricted_coordinates(
750                    image_id,
751                    coordinates,
752                    level_id,
753                    sample_id,
754                    block,
755                )?;
756                access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block)
757            }
758            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self
759                .write_conditional_image_access(
760                    image_id,
761                    coordinates,
762                    level_id,
763                    sample_id,
764                    block,
765                    &access,
766                )?,
767            crate::proc::BoundsCheckPolicy::Unchecked => access.generate(
768                &mut self.writer.id_gen,
769                coordinates.value_id,
770                level_id,
771                sample_id,
772                block,
773            ),
774        };
775
776        // For depth images, `ImageLoad` expressions produce a single f32,
777        // whereas the SPIR-V instructions always produce a vec4. So we may have
778        // to pull out the component we need.
779        let result_id = if result_type_id == access.result_type() {
780            // The instruction produced the type we expected. We can use
781            // its result as-is.
782            access_id
783        } else {
784            // For `ImageClass::Depth` images, SPIR-V gave us four components,
785            // but we only want the first one.
786            let component_id = self.gen_id();
787            block.body.push(Instruction::composite_extract(
788                result_type_id,
789                component_id,
790                access_id,
791                &[0],
792            ));
793            component_id
794        };
795
796        Ok(result_id)
797    }
798
    /// Generate code for an `ImageSample` expression.
    ///
    /// The arguments are the components of an `Expression::ImageSample` variant.
    ///
    /// Emits an `OpSampledImage` combining `image` and `sampler`, then a single
    /// sampling instruction built with `Instruction::image_gather` (when `gather`
    /// is set) or `Instruction::image_sample` (otherwise). `level` selects the
    /// explicit/implicit LOD form and its image operand. Returns the id of the
    /// value of type `result_type_id`.
    ///
    /// For depth images sampled without `depth_ref` and without `gather`, the
    /// instruction is emitted with a `vec4<f32>` result type, and component 0 is
    /// then extracted with `OpCompositeExtract`.
    ///
    /// When `clamp_to_edge` is set, the coordinates are first clamped with
    /// `NClamp` to `[0.5 / size, 1 - 0.5 / size]`, where `size` is level 0's
    /// size from `OpImageQuerySizeLod` converted to float. That is half a texel
    /// in from each edge.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Validation`] if `clamp_to_edge` is set and `level` is
    /// not `SampleLevel::Zero`. Propagates the error from `require_any` if the
    /// `ImageQuery` capability needed for clamping is unavailable.
    ///
    /// # Panics
    ///
    /// Panics if `image`'s type is not a type handle, or if the `Exact` LOD of
    /// a depth image is not a 32-bit signed or unsigned integer scalar.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn write_image_sample(
        &mut self,
        result_type_id: Word,
        image: Handle<crate::Expression>,
        sampler: Handle<crate::Expression>,
        gather: Option<crate::SwizzleComponent>,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        offset: Option<Handle<crate::Expression>>,
        level: crate::SampleLevel,
        depth_ref: Option<Handle<crate::Expression>>,
        clamp_to_edge: bool,
        block: &mut Block,
    ) -> Result<Word, Error> {
        use super::instructions::SampleLod;
        // image
        let image_id = self.get_handle_id(image);
        let image_type = self.fun_info[image].ty.handle().unwrap();
        // SPIR-V doesn't know about our `Depth` class, and it returns
        // `vec4<f32>`, so we need to grab the first component out of it.
        // Comparison sampling (`depth_ref`) and gathers already produce the
        // type Naga expects, so they need no extraction.
        let needs_sub_access = match self.ir_module.types[image_type].inner {
            crate::TypeInner::Image {
                class: crate::ImageClass::Depth { .. },
                ..
            } => depth_ref.is_none() && gather.is_none(),
            _ => false,
        };
        let sample_result_type_id = if needs_sub_access {
            self.get_numeric_type_id(NumericType::Vector {
                size: crate::VectorSize::Quad,
                scalar: crate::Scalar::F32,
            })
        } else {
            result_type_id
        };

        // OpTypeSampledImage
        let image_type_id = self.get_handle_type_id(image_type);
        let sampled_image_type_id =
            self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));

        let sampler_id = self.get_handle_id(sampler);

        let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
        let coordinates_id = if clamp_to_edge {
            self.writer.require_any(
                "clamp sample coordinates to edge",
                &[spirv::Capability::ImageQuery],
            )?;

            // clamp_to_edge can only be used with Level 0, and no array offset, offset,
            // depth_ref or gather. This should have been caught by validation. Rather
            // than entirely duplicate validation code here just ensure the level is
            // zero, as we rely on that to query the texture size in order to calculate
            // the clamped coordinates.
            // NOTE(review): everything below works on `vec2<f32>` values, so this
            // path presumes 2D, non-arrayed, f32 coordinates. Validation is
            // expected to guarantee that; confirm if the IR rules change.
            if level != crate::SampleLevel::Zero {
                return Err(Error::Validation(
                    "ImageSample::clamp_to_edge requires SampleLevel::Zero",
                ));
            }

            // Query the size of level 0 of the texture.
            let image_size_id = self.gen_id();
            let vec2u_type_id = self.writer.get_vec2u_type_id();
            let const_zero_uint_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
            let mut query_inst = Instruction::image_query(
                spirv::Op::ImageQuerySizeLod,
                vec2u_type_id,
                image_size_id,
                image_id,
            );
            query_inst.add_operand(const_zero_uint_id);
            block.body.push(query_inst);

            // Convert the integer size to float so it can divide the margins.
            let image_size_f_id = self.gen_id();
            let vec2f_type_id = self.writer.get_vec2f_type_id();
            block.body.push(Instruction::unary(
                spirv::Op::ConvertUToF,
                vec2f_type_id,
                image_size_f_id,
                image_size_id,
            ));

            // Calculate the top-left and bottom-right margins to clamp to, i.e.
            // half a texel in from each side.
            let const_0_5_f32_id = self.writer.get_constant_scalar(crate::Literal::F32(0.5));
            let const_0_5_vec2f_id = self.writer.get_constant_composite(
                LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: crate::VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                })),
                &[const_0_5_f32_id, const_0_5_f32_id],
            );

            // margin_left = 0.5 / size
            let margin_left_id = self.gen_id();
            block.body.push(Instruction::binary(
                spirv::Op::FDiv,
                vec2f_type_id,
                margin_left_id,
                const_0_5_vec2f_id,
                image_size_f_id,
            ));

            let const_1_f32_id = self.writer.get_constant_scalar(crate::Literal::F32(1.0));
            let const_1_vec2f_id = self.writer.get_constant_composite(
                LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: crate::VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                })),
                &[const_1_f32_id, const_1_f32_id],
            );

            // margin_right = 1.0 - margin_left
            let margin_right_id = self.gen_id();
            block.body.push(Instruction::binary(
                spirv::Op::FSub,
                vec2f_type_id,
                margin_right_id,
                const_1_vec2f_id,
                margin_left_id,
            ));

            // Clamp the coords to the calculated margins
            let clamped_coords_id = self.gen_id();
            block.body.push(Instruction::ext_inst(
                self.writer.gl450_ext_inst_id,
                spirv::GLOp::NClamp,
                vec2f_type_id,
                clamped_coords_id,
                &[coordinates.value_id, margin_left_id, margin_right_id],
            ));

            clamped_coords_id
        } else {
            coordinates.value_id
        };

        let sampled_image_id = self.gen_id();
        block.body.push(Instruction::sampled_image(
            sampled_image_type_id,
            sampled_image_id,
            image_id,
            sampler_id,
        ));
        let id = self.gen_id();

        let depth_id = depth_ref.map(|handle| self.cached[handle]);
        // SPIR-V requires the ids that follow the image-operands mask to appear
        // in increasing bit order: LOD/Bias/Grad come before ConstOffset. The
        // arms below append their LOD/Bias/Grad ids, and the offset id is
        // appended after the match.
        let mut mask = spirv::ImageOperands::empty();
        mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some());

        let mut main_instruction = match (level, gather) {
            // Gathers ignore `level`; the gathered component is passed as an
            // index constant.
            (_, Some(component)) => {
                let component_id = self.get_index_constant(component as u32);
                let mut inst = Instruction::image_gather(
                    sample_result_type_id,
                    id,
                    sampled_image_id,
                    coordinates_id,
                    component_id,
                    depth_id,
                );
                if !mask.is_empty() {
                    inst.add_operand(mask.bits());
                }
                inst
            }
            // Level zero is expressed as an explicit LOD of 0.0.
            (crate::SampleLevel::Zero, None) => {
                let mut inst = Instruction::image_sample(
                    sample_result_type_id,
                    id,
                    SampleLod::Explicit,
                    sampled_image_id,
                    coordinates_id,
                    depth_id,
                );

                let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0));

                mask |= spirv::ImageOperands::LOD;
                inst.add_operand(mask.bits());
                inst.add_operand(zero_id);

                inst
            }
            (crate::SampleLevel::Auto, None) => {
                let mut inst = Instruction::image_sample(
                    sample_result_type_id,
                    id,
                    SampleLod::Implicit,
                    sampled_image_id,
                    coordinates_id,
                    depth_id,
                );
                if !mask.is_empty() {
                    inst.add_operand(mask.bits());
                }
                inst
            }
            (crate::SampleLevel::Exact(lod_handle), None) => {
                let mut inst = Instruction::image_sample(
                    sample_result_type_id,
                    id,
                    SampleLod::Explicit,
                    sampled_image_id,
                    coordinates_id,
                    depth_id,
                );

                let mut lod_id = self.cached[lod_handle];
                // SPIR-V expects the LOD to be a float for all image classes.
                // lod_id, however, will be an integer for depth images,
                // therefore we must do a conversion.
                if matches!(
                    self.ir_module.types[image_type].inner,
                    crate::TypeInner::Image {
                        class: crate::ImageClass::Depth { .. },
                        ..
                    }
                ) {
                    let lod_f32_id = self.gen_id();
                    let f32_type_id =
                        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32));
                    // Only 32-bit integer LODs are expected here; anything else panics.
                    let convert_op = match *self.fun_info[lod_handle]
                        .ty
                        .inner_with(&self.ir_module.types)
                    {
                        crate::TypeInner::Scalar(crate::Scalar {
                            kind: crate::ScalarKind::Uint,
                            width: 4,
                        }) => spirv::Op::ConvertUToF,
                        crate::TypeInner::Scalar(crate::Scalar {
                            kind: crate::ScalarKind::Sint,
                            width: 4,
                        }) => spirv::Op::ConvertSToF,
                        _ => unreachable!(),
                    };
                    block.body.push(Instruction::unary(
                        convert_op,
                        f32_type_id,
                        lod_f32_id,
                        lod_id,
                    ));
                    lod_id = lod_f32_id;
                }
                mask |= spirv::ImageOperands::LOD;
                inst.add_operand(mask.bits());
                inst.add_operand(lod_id);

                inst
            }
            (crate::SampleLevel::Bias(bias_handle), None) => {
                let mut inst = Instruction::image_sample(
                    sample_result_type_id,
                    id,
                    SampleLod::Implicit,
                    sampled_image_id,
                    coordinates_id,
                    depth_id,
                );

                let bias_id = self.cached[bias_handle];
                mask |= spirv::ImageOperands::BIAS;
                inst.add_operand(mask.bits());
                inst.add_operand(bias_id);

                inst
            }
            (crate::SampleLevel::Gradient { x, y }, None) => {
                let mut inst = Instruction::image_sample(
                    sample_result_type_id,
                    id,
                    SampleLod::Explicit,
                    sampled_image_id,
                    coordinates_id,
                    depth_id,
                );

                let x_id = self.cached[x];
                let y_id = self.cached[y];
                mask |= spirv::ImageOperands::GRAD;
                inst.add_operand(mask.bits());
                inst.add_operand(x_id);
                inst.add_operand(y_id);

                inst
            }
        };

        // ConstOffset is the highest-numbered operand used here, so it goes last.
        if let Some(offset_const) = offset {
            let offset_id = self.cached[offset_const];
            main_instruction.add_operand(offset_id);
        }

        block.body.push(main_instruction);

        // Narrow the `vec4<f32>` depth sample down to Naga's scalar result.
        let id = if needs_sub_access {
            let sub_id = self.gen_id();
            block.body.push(Instruction::composite_extract(
                result_type_id,
                sub_id,
                id,
                &[0],
            ));
            sub_id
        } else {
            id
        };

        Ok(id)
    }
1112
1113    /// Generate code for an `ImageQuery` expression.
1114    ///
1115    /// The arguments are the components of an `Expression::ImageQuery` variant.
1116    pub(super) fn write_image_query(
1117        &mut self,
1118        result_type_id: Word,
1119        image: Handle<crate::Expression>,
1120        query: crate::ImageQuery,
1121        block: &mut Block,
1122    ) -> Result<Word, Error> {
1123        use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq};
1124
1125        let image_id = self.get_handle_id(image);
1126        let image_type = self.fun_info[image].ty.handle().unwrap();
1127        let (dim, arrayed, class) = match self.ir_module.types[image_type].inner {
1128            crate::TypeInner::Image {
1129                dim,
1130                arrayed,
1131                class,
1132            } => (dim, arrayed, class),
1133            _ => {
1134                return Err(Error::Validation("image type"));
1135            }
1136        };
1137
1138        self.writer
1139            .require_any("image queries", &[spirv::Capability::ImageQuery])?;
1140
1141        let id = match query {
1142            Iq::Size { level } => {
1143                let dim_coords = match dim {
1144                    Id::D1 => 1,
1145                    Id::D2 | Id::Cube => 2,
1146                    Id::D3 => 3,
1147                };
1148                let array_coords = usize::from(arrayed);
1149                let vector_size = match dim_coords + array_coords {
1150                    2 => Some(crate::VectorSize::Bi),
1151                    3 => Some(crate::VectorSize::Tri),
1152                    4 => Some(crate::VectorSize::Quad),
1153                    _ => None,
1154                };
1155                let vector_numeric_type = match vector_size {
1156                    Some(size) => NumericType::Vector {
1157                        size,
1158                        scalar: crate::Scalar::U32,
1159                    },
1160                    None => NumericType::Scalar(crate::Scalar::U32),
1161                };
1162
1163                let extended_size_type_id = self.get_numeric_type_id(vector_numeric_type);
1164
1165                let (query_op, level_id) = match class {
1166                    Ic::Sampled { multi: true, .. }
1167                    | Ic::Depth { multi: true }
1168                    | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None),
1169                    _ => {
1170                        let level_id = match level {
1171                            Some(expr) => self.cached[expr],
1172                            None => self.get_index_constant(0),
1173                        };
1174                        (spirv::Op::ImageQuerySizeLod, Some(level_id))
1175                    }
1176                };
1177
1178                // The ID of the vector returned by SPIR-V, which contains the dimensions
1179                // as well as the layer count.
1180                let id_extended = self.gen_id();
1181                let mut inst = Instruction::image_query(
1182                    query_op,
1183                    extended_size_type_id,
1184                    id_extended,
1185                    image_id,
1186                );
1187                if let Some(expr_id) = level_id {
1188                    inst.add_operand(expr_id);
1189                }
1190                block.body.push(inst);
1191
1192                if result_type_id != extended_size_type_id {
1193                    let id = self.gen_id();
1194                    let components = match dim {
1195                        // always pick the first component, and duplicate it for all 3 dimensions
1196                        Id::Cube => &[0u32, 0][..],
1197                        _ => &[0u32, 1, 2, 3][..dim_coords],
1198                    };
1199                    block.body.push(Instruction::vector_shuffle(
1200                        result_type_id,
1201                        id,
1202                        id_extended,
1203                        id_extended,
1204                        components,
1205                    ));
1206
1207                    id
1208                } else {
1209                    id_extended
1210                }
1211            }
1212            Iq::NumLevels => {
1213                let query_id = self.gen_id();
1214                block.body.push(Instruction::image_query(
1215                    spirv::Op::ImageQueryLevels,
1216                    result_type_id,
1217                    query_id,
1218                    image_id,
1219                ));
1220
1221                query_id
1222            }
1223            Iq::NumLayers => {
1224                let vec_size = match dim {
1225                    Id::D1 => crate::VectorSize::Bi,
1226                    Id::D2 | Id::Cube => crate::VectorSize::Tri,
1227                    Id::D3 => crate::VectorSize::Quad,
1228                };
1229                let extended_size_type_id = self.get_numeric_type_id(NumericType::Vector {
1230                    size: vec_size,
1231                    scalar: crate::Scalar::U32,
1232                });
1233                let id_extended = self.gen_id();
1234                let mut inst = Instruction::image_query(
1235                    spirv::Op::ImageQuerySizeLod,
1236                    extended_size_type_id,
1237                    id_extended,
1238                    image_id,
1239                );
1240                inst.add_operand(self.get_index_constant(0));
1241                block.body.push(inst);
1242
1243                let extract_id = self.gen_id();
1244                block.body.push(Instruction::composite_extract(
1245                    result_type_id,
1246                    extract_id,
1247                    id_extended,
1248                    &[vec_size as u32 - 1],
1249                ));
1250
1251                extract_id
1252            }
1253            Iq::NumSamples => {
1254                let query_id = self.gen_id();
1255                block.body.push(Instruction::image_query(
1256                    spirv::Op::ImageQuerySamples,
1257                    result_type_id,
1258                    query_id,
1259                    image_id,
1260                ));
1261
1262                query_id
1263            }
1264        };
1265
1266        Ok(id)
1267    }
1268
1269    pub(super) fn write_image_store(
1270        &mut self,
1271        image: Handle<crate::Expression>,
1272        coordinate: Handle<crate::Expression>,
1273        array_index: Option<Handle<crate::Expression>>,
1274        value: Handle<crate::Expression>,
1275        block: &mut Block,
1276    ) -> Result<(), Error> {
1277        let image_id = self.get_handle_id(image);
1278        let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
1279        let value_id = self.cached[value];
1280
1281        let write = Store { image_id, value_id };
1282
1283        match *self.fun_info[image].ty.inner_with(&self.ir_module.types) {
1284            crate::TypeInner::Image {
1285                class:
1286                    crate::ImageClass::Storage {
1287                        format: crate::StorageFormat::Bgra8Unorm,
1288                        ..
1289                    },
1290                ..
1291            } => self.writer.require_any(
1292                "Bgra8Unorm storage write",
1293                &[spirv::Capability::StorageImageWriteWithoutFormat],
1294            )?,
1295            _ => {}
1296        }
1297
1298        write.generate(
1299            &mut self.writer.id_gen,
1300            coordinates.value_id,
1301            None,
1302            None,
1303            block,
1304        );
1305
1306        Ok(())
1307    }
1308
1309    pub(super) fn write_image_atomic(
1310        &mut self,
1311        image: Handle<crate::Expression>,
1312        coordinate: Handle<crate::Expression>,
1313        array_index: Option<Handle<crate::Expression>>,
1314        fun: crate::AtomicFunction,
1315        value: Handle<crate::Expression>,
1316        block: &mut Block,
1317    ) -> Result<(), Error> {
1318        let image_id = match self.ir_function.originating_global(image) {
1319            Some(handle) => self.writer.global_variables[handle].var_id,
1320            _ => return Err(Error::Validation("Unexpected image type")),
1321        };
1322        let crate::TypeInner::Image { class, .. } =
1323            *self.fun_info[image].ty.inner_with(&self.ir_module.types)
1324        else {
1325            return Err(Error::Validation("Invalid image type"));
1326        };
1327        let crate::ImageClass::Storage { format, .. } = class else {
1328            return Err(Error::Validation("Invalid image class"));
1329        };
1330        let scalar = format.into();
1331        let scalar_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
1332        let pointer_type_id = self.get_pointer_type_id(scalar_type_id, spirv::StorageClass::Image);
1333        let signed = scalar.kind == crate::ScalarKind::Sint;
1334        if scalar.width == 8 {
1335            self.writer
1336                .require_any("64 bit image atomics", &[spirv::Capability::Int64Atomics])?;
1337        }
1338        let pointer_id = self.gen_id();
1339        let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
1340        let sample_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
1341        block.body.push(Instruction::image_texel_pointer(
1342            pointer_type_id,
1343            pointer_id,
1344            image_id,
1345            coordinates.value_id,
1346            sample_id,
1347        ));
1348
1349        let op = match fun {
1350            crate::AtomicFunction::Add => spirv::Op::AtomicIAdd,
1351            crate::AtomicFunction::Subtract => spirv::Op::AtomicISub,
1352            crate::AtomicFunction::And => spirv::Op::AtomicAnd,
1353            crate::AtomicFunction::ExclusiveOr => spirv::Op::AtomicXor,
1354            crate::AtomicFunction::InclusiveOr => spirv::Op::AtomicOr,
1355            crate::AtomicFunction::Min if signed => spirv::Op::AtomicSMin,
1356            crate::AtomicFunction::Min => spirv::Op::AtomicUMin,
1357            crate::AtomicFunction::Max if signed => spirv::Op::AtomicSMax,
1358            crate::AtomicFunction::Max => spirv::Op::AtomicUMax,
1359            crate::AtomicFunction::Exchange { .. } => {
1360                return Err(Error::Validation("Exchange atomics are not supported yet"))
1361            }
1362        };
1363        let result_type_id = self.get_expression_type_id(&self.fun_info[value].ty);
1364        let id = self.gen_id();
1365        let space = crate::AddressSpace::Handle;
1366        let (semantics, scope) = space.to_spirv_semantics_and_scope();
1367        let scope_constant_id = self.get_scope_constant(scope as u32);
1368        let semantics_id = self.get_index_constant(semantics.bits());
1369        let value_id = self.cached[value];
1370
1371        block.body.push(Instruction::image_atomic(
1372            op,
1373            result_type_id,
1374            id,
1375            pointer_id,
1376            scope_constant_id,
1377            semantics_id,
1378            value_id,
1379        ));
1380
1381        Ok(())
1382    }
1383}